Почему один GPU часто не хватает
Современные нейросети быстро растут по параметрам: от сотен миллионов до нескольких миллиардов весов. При обучении на типичных датасетах (ImageNet, COCO, большие текстовые корпуса) размер батча, необходимый для стабильного градиентного спуска, может превышать 1 000 примеров. На одной видеокарте такой объём не помещается в видеопамять: даже топ‑модели с 24 ГБ VRAM не способны держать в памяти более 64‑128 изображений с разрешением 224×224 px.
Решения две: уменьшить размер батча и увеличить количество итераций, либо распределить вычисления между несколькими GPU. Первый подход ухудшает статистическую эффективность обучения, второй требует корректной синхронизации градиентов. В PyTorch для этого существует два фундаментальных инструмента – градиентный аккумулятор (gradient accumulation) и параллелизм данных (data parallelism). Оба могут быть реализованы «с нуля», без готовых обёрток, что даёт полное понимание происходящего под капотом.
Принцип градиентного аккумулятора
Градиентный аккумулятор позволяет имитировать большой батч, разбивая его на несколько микробатчей, каждый из которых помещается в память GPU. После обратного прохода по каждому микробатчу градиенты не обнуляются, а суммируются. Когда обработано заданное количество микробатчей (например, 8), происходит один шаг оптимизатора – фактически выполнен один шаг обучения с батчем в 8 × микробатч‑size.
Ключевые детали реализации:
- Отключить автоматическое обнуление градиентов после
loss.backward(). - Отслеживать счётчик микробатчей, после которого делаем
optimizer.step()и только потом вызываемoptimizer.zero_grad(). - При использовании нескольких GPU градиенты должны быть синхронизированы (см. ниже).
Пример кода:
# model – любой nn.Module, optimizer – выбранный оптимизатор
accum_steps = 8 # сколько микробатчей собрать в один большой батч
model.train()
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
outputs = model(inputs) # forward
loss = criterion(outputs, targets) / accum_steps # масштабируем loss
loss.backward() # градиенты накапливаются
if (i + 1) % accum_steps == 0: # пора делать шаг
optimizer.step()
optimizer.zero_grad()
Обратите внимание на деление loss на accum_steps. Это сохраняет корректный масштаб градиентов, аналогичный обучению с полным батчем.
Параллелизм данных в PyTorch
Параллелизм данных (DP) – классический способ распределения работы: каждый GPU получает свою часть входного батча, проводит прямой и обратный проход, после чего градиенты агрегируются (обычно суммируются) и обновляются совместно.
В PyTorch есть готовый torch.nn.DataParallel, но он реализован как single‑process wrapper, что приводит к лишним копиям модели и неэффективному использованию GPU. Более гибким решением является ручная реализация DP:
def scatter_batch(batch, devices):
"""Разбивает батч на части по количеству GPU."""
inputs, targets = batch
inputs_split = torch.chunk(inputs, len(devices), dim=0)
targets_split = torch.chunk(targets, len(devices), dim=0)
return list(zip(inputs_split, targets_split))
def parallel_step(model, batch, device):
"""Выполняет forward+backward на отдельном GPU."""
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
return loss.item()
Синхронизация градиентов производится через torch.nn.utils.clip_grad_norm_ и torch.distributed.all_reduce, но в простом случае достаточно вызвать model.module (если модель уже обёрнута в DistributedDataParallel) или выполнить суммирование вручную:
def average_gradients(model):
"""Суммирует градиенты по всем GPU и делит на их количество."""
world_size = torch.cuda.device_count()
for param in model.parameters():
if param.grad is None:
continue
# all_reduce суммирует градиенты по всем процессам
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.SUM)
param.grad /= world_size
Для однопроцессного сценария можно просто собрать градиенты после каждого микробатча, используя torch.cuda.synchronize() и torch.nn.functional.reduce_sum на каждом параметре.
Совмещение аккумулятора и параллелизма
Объединить оба механизма довольно просто: каждый GPU обрабатывает свой микробатч, градиенты суммируются локально, а затем агрегируются глобально. После получения нужного количества микробатчей (по каждому GPU) делаем один глобальный шаг оптимизатора.
Алгоритм:
- Разбить большой батч на
N_devices × accum_stepsмикробатчей. - На каждом устройстве выполнить
accum_stepsобратных проходов без обнуления градиентов. - После завершения локального аккумулятора вызвать
average_gradients(илиtorch.distributed.all_reduce). - Сделать
optimizer.step()один раз, затемoptimizer.zero_grad()на всех GPU.
Код‑скелет:
devices = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
model = MyModel().to(devices[0])
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
accum_steps = 4
world_size = len(devices)
model.train()
optimizer.zero_grad()
for batch_idx, batch in enumerate(dataloader):
# batch -> (inputs, targets)
micro_batches = scatter_batch(batch, devices)
for step in range(accum_steps):
# каждый GPU получает свой микробатч
loss_val = parallel_step(model, micro_batches[step % world_size], devices[step % world_size])
# масштабируем loss внутри parallel_step, если нужно
# после accum_steps микробатчей – глобальная синхронизация градиентов
average_gradients(model)
optimizer.step()
optimizer.zero_grad()
Такой подход сохраняет эффективность больших батчей, одновременно используя всю доступную видеопамять.
Практический пример: обучение ResNet на CIFAR‑10
Ниже представлена минимальная рабочая программа, демонстрирующая совместное использование градиентного аккумулятора и параллелизма данных.
import torch, torch.nn as nn, torch.optim as optim
import torchvision.datasets as datasets, torchvision.transforms as T
import torchvision.models as models
import torch.multiprocessing as mp
import torch.distributed as dist
def train(rank, world_size):
dist.init_process_group(backend='nccl', init_method='env://',
world_size=world_size, rank=rank)
device = torch.device(f'cuda:{rank}')
torch.cuda.set_device(device)
# Датасет
transform = T.Compose([T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize((0.5,)*3, (0.5,)*3)])
train_set = datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
sampler = torch.utils.data.distributed.DistributedSampler(train_set,
num_replicas=world_size,
rank=rank)
loader = torch.utils.data.DataLoader(train_set, batch_size=64,
sampler=sampler, num_workers=2)
# Модель
model = models.resnet18(num_classes=10).to(device)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
accum_steps = 4
model.train()
optimizer.zero_grad()
for epoch in range(10):
sampler.set_epoch(epoch) # перемешивание
for i, (x, y) in enumerate(loader):
x, y = x.to(device), y.to(device)
outputs = model(x)
loss = criterion(outputs, y) / accum_steps
loss.backward()
if (i + 1) % accum_steps == 0:
# градиенты уже агрегированы DDP, делаем шаг
optimizer.step()
optimizer.zero_grad()
if rank == 0:
print(f'Epoch {epoch} completed')
if __name__ == '__main__':
world_size = torch.cuda.device_count()
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
Ключевые моменты:
DistributedSamplerгарантирует, что каждый процесс получает уникальную часть датасета.DistributedDataParallelавтоматически суммирует градиенты между процессами после каждогоloss.backward(). Поэтому в коде достаточно делитьlossнаaccum_steps.accum_stepsзадаёт размер виртуального батча: приbatch_size=64иaccum_steps=4эффективный батч составляет 256 образцов, что достаточно для стабильного обучения ResNet‑18 на CIFAR‑10.
Запуск этой программы на 2‑4 GPU позволяет достичь почти линейного ускорения без потери качества модели.
Выводы
Градиентный аккумулятор и параллелизм данных – два взаимодополняющих инструмента, позволяющих решить проблему ограниченной видеопамяти при обучении современных нейросетей. Реализовав их «с нуля», разработчик получает полный контроль над масштабированием, может тонко настраивать частоту синхронизаций и легко интегрировать дополнительные техники (AMP, градиентный клиппинг, кастомные оптимизаторы).
В PyTorch их совместное использование реализуется через DistributedDataParallel + деление потерь на количество аккумуляций, либо через ручную агрегацию градиентов и собственный цикл optimizer.step(). Пример с ResNet‑18 на CIFAR‑10 демонстрирует, как за один цикл кода получить эффективность обучения, сравнимую с использованием крупных серверных GPU‑класса, но на обычных потребительских видеокартах.
Эти подходы становятся базовым строительным блоком для дальнейшего масштабирования: от нескольких GPU до целых кластеров, от небольших датасетов до терабайтов данных. Их знание – обязательный навык любого инженера по машинному обучению, работающего с PyTorch.