Почему распределённые вычисления становятся обязательными
Современные модели искусственного интеллекта требуют гигабайтов памяти и терафлопов вычислительной мощности. Одна видеокарта уже не может обеспечить необходимый объём ресурсов, поэтому разработчики переходят к распределённым системам, где несколько GPU работают совместно. Основным инструментом для организации такой параллельной работы в экосистеме PyTorch является пакет torch.distributed. Он позволяет реализовать два фундаментальных типа коммуникаций: точечные (point‑to‑point) и коллективные (collective) операции.
Точечные операции: send / recv и их применение
Точечные операции — это прямой обмен данными между двумя процессами (или двумя GPU). В PyTorch они реализованы через функции torch.distributed.send и torch.distributed.recv.
# процесс 0 отправляет тензор t
torch.distributed.send(tensor=t, dst=1)
# процесс 1 получает тензор в буфер recv_t
torch.distributed.recv(tensor=recv_t, src=0)
Когда использовать
- Алгоритмы с асимметричной нагрузкой – когда один узел отвечает за подготовку данных, а остальные только их потребляют.
- Пайплайнинг – последовательный проход модели, где каждый слой размещён на отдельном GPU, и результаты передаются дальше по цепочке.
- Контрольные сообщения – небольшие сигналы синхронизации, которые не требуют глобального барьера.
Точечные операции дают гибкость, но требуют от разработчика явного управления порядком отправки и приёма, иначе возможны взаимные блокировки (deadlock). Для избежания такой ситуации часто используют неблокирующие версии isend/irecv, а затем вызывают torch.distributed.wait.
Коллективные операции: синхронизация и агрегация
Коллективные операции позволяют одновременно взаимодействовать со всеми процессами в группе. Ключевые функции:
| Операция | Описание | Пример |
|---|---|---|
broadcast | Рассылает один тензор от корневого процесса всем остальным | torch.distributed.broadcast(tensor, src=0) |
all_reduce | Выполняет редукцию (sum, max, min и т.д.) и распределяет результат обратно всем | torch.distributed.all_reduce(tensor, op=ReduceOp.SUM) |
reduce | Редуцирует данные только к одному процессу | torch.distributed.reduce(tensor, dst=0) |
scatter | Делит один большой тензор на части и раздаёт их процессам | torch.distributed.scatter(tensor_list, src=0) |
gather | Сбирает части тензоров от всех процессов в один | torch.distributed.gather(tensor, dst=0) |
all_gather | Собирает тензоры со всех процессов в список, доступный каждому | torch.distributed.all_gather(tensor_list, tensor) |
barrier | Останавливает процесс до тех пор, пока все не достигнут этой точки | torch.distributed.barrier() |
Применение в обучении моделей
- Синхронный SGD – каждый GPU вычисляет градиенты локально, а
all_reduceсуммирует их, после чего каждый процесс обновляет свои параметры. Это обеспечивает идентичные весовые коэффициенты на всех устройствах. - Data Parallelism –
broadcastиспользуется для распространения начального состояния модели от главного процессора к остальным. - Model Parallelism –
scatterиgatherпозволяют разбивать большие весовые матрицы между GPU и собирать их для последующего вычисления.
Коллективные операции оптимизированы под конкретные бекенды и часто реализуются с помощью высокопроизводительных библиотек, таких как NCCL (для NVIDIA GPU) и Gloo (универсальный, поддерживает CPU и GPU).
Выбор бекенда и инициализация процесса
Для корректной работы torch.distributed необходимо задать бекенд и инициализировать процессную группу. Наиболее распространённые варианты:
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # 'nccl' для NVIDIA, 'gloo' для CPU/AMD
init_method='env://', # читаем переменные окружения (MASTER_ADDR, MASTER_PORT)
world_size=4, # количество процессов (GPU)
rank=local_rank # уникальный идентификатор текущего процесса
)
- NCCL обеспечивает наилучшую пропускную способность и низкую латентность при работе с несколькими NVIDIA GPU.
- Gloo более гибок, поддерживает как CPU, так и GPU, но обычно медленнее NCCL в чисто GPU‑сценариях.
Важно, чтобы каждый процесс запускался с правильным local_rank, часто передаваемым через переменную окружения CUDA_VISIBLE_DEVICES. В сценариях с torchrun (ранее torch.distributed.launch) это делается автоматически.
Практические рекомендации по производительности
- Минимизировать синхронные барьеры – каждый
barrierостанавливает весь кластер, поэтому используйте их только в случае реальной необходимости. - Пакетировать мелкие сообщения – вместо частых небольших
send/recvлучше собрать данные в один буфер и отправить за один раз. - Привязывать тензоры к нужным устройствам – перед вызовом коллективных функций убедитесь, что все тензоры находятся на тех же GPU, иначе будет лишняя копия через PCIe.
- Профилировать с
torch.profiler– он позволяет увидеть, какие операции занимают большую часть времени и где происходит узкое место коммуникации. - Учесть топологию узла – в многосокетных серверах связь между GPU, находящимися в разных сокетах, может иметь большую латентность; распределяйте процессные группы так, чтобы часто общающиеся GPU находились в одном сокете.
Пример полного цикла обучения с DistributedDataParallel
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def train(rank, world_size):
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
model = nn.Linear(1024, 10).cuda()
ddp_model = DDP(model, device_ids=[rank])
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
dataset = torch.utils.data.TensorDataset(
torch.randn(10000, 1024), torch.randint(0, 10, (10000,))
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
for epoch in range(5):
sampler.set_epoch(epoch) # обеспечивает перемешивание в каждом эпо
for data, target in loader:
data, target = data.cuda(rank), target.cuda(rank)
optimizer.zero_grad()
output = ddp_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == '__main__':
world_size = torch.cuda.device_count()
torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
В этом примере DistributedDataParallel автоматически использует all_reduce для синхронизации градиентов, а DistributedSampler гарантирует, что каждый процесс получает уникальную часть датасета. Такой шаблон покрывает большинство сценариев обучения на нескольких GPU и служит надёжной отправной точкой для дальнейшего экспериментирования с точечными и коллективными операциями.