Что такое Zero Redundancy Optimizer (ZeRO)
Zero Redundancy Optimizer (ZeRO) — это набор техник, направленных на устранение дублирования данных в процессе обучения больших нейронных сетей на кластере GPU. Традиционный Data Parallel (DP) копирует все параметры модели, градиенты и состояния оптимизатора на каждый графический процессор, что быстро приводит к нехватке памяти при масштабных архитектурах. ZeRO разбивает эти три компонента (параметры, градиенты, состояния оптимизатора) между устройствами, оставляя каждый из них только в том месте, где он действительно нужен. В результате потребление видеопамяти снижается почти в три раза, а обучение становится экономически эффективнее.
Принцип работы ZeRO
ZeRO реализует три уровня шардирования:
-
Stage 1 – шардирование градиентов. После обратного прохода градиенты распределяются между GPU, каждый из которых хранит лишь часть полной градиентной матрицы. При обновлении параметров каждый процесс собирает только свою долю градиента, а остальные части остаются на своих устройствах.
-
Stage 2 – шардирование состояний оптимизатора. Большинство современных оптимизаторов (Adam, AdamW) хранят два вспомогательных тензора (момент первого и второго порядка) для каждого параметра. ZeRO распределяет эти тензоры аналогично градиентам, устраняя их дублирование.
-
Stage 3 – шардирование самих параметров. На этом этапе параметры модели также разбиваются на части и хранятся только на тех GPU, где они необходимы в текущий момент. При необходимости параметр «запрашивается» через коммуникацию NCCL, а после обновления возвращается в своё шардированное представление.
Коммуникация между процессами происходит в виде асинхронных all‑reduce и all‑gather операций, что минимизирует простой GPU и поддерживает высокую пропускную способность.
Реализация ZeRO «с нуля»
Для понимания механизма можно построить упрощённый прототип на PyTorch без использования готовых библиотек:
import torch
import torch.distributed as dist
def shard_tensor(tensor, rank, world_size):
"""Разбивает тензор вдоль первой оси на world_size частей и возвращает часть для текущего ранка."""
chunk_size = tensor.size(0) // world_size
start = rank * chunk_size
end = start + chunk_size
return tensor[start:end].contiguous()
def all_gather_shard(shard, rank, world_size):
"""Собирает полную копию тензора из шардов всех ранков."""
full = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(full, shard)
return torch.cat(full, dim=0)
# Пример использования в тренировочном цикле
model = MyModel().to(rank) # каждый GPU хранит только свою часть параметров
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for data, target in loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
# Шардируем градиенты
for p in model.parameters():
if p.grad is not None:
grad_shard = shard_tensor(p.grad, rank, world_size)
dist.all_reduce(grad_shard, op=dist.ReduceOp.SUM)
p.grad = grad_shard # теперь каждый процесс хранит только свою часть градиента
optimizer.step()
Этот скелет иллюстрирует базовую идею: каждый процесс хранит лишь часть градиентов и параметров, а синхронизация происходит только над нужными фрагментами. В реальном ZeRO добавляются буферизация, динамическое масштабирование и поддержка сложных оптимизаторов.
Как использовать ZeRO в PyTorch
В продакшн‑сценариях рекомендуется пользоваться библиотекой DeepSpeed, где ZeRO интегрирован и оптимизирован:
from deepspeed import DeepSpeedEngine, DeepSpeedConfig
ds_config = {
"train_batch_size": 32,
"zero_optimization": {
"stage": 3, # полное шардирование
"offload_param": {"device": "cpu"},
"offload_optimizer": {"device": "cpu"}
},
"fp16": {"enabled": True}
}
engine, _, _, _ = DeepSpeedEngine(
model=model,
model_parameters=model.parameters(),
config=DeepSpeedConfig(ds_config)
)
for batch in data_loader:
loss = engine(batch)
engine.backward(loss)
engine.step()
DeepSpeed автоматически управляет распределением параметров, градиентов и состояний оптимизатора, а также обеспечивает эффективный обмен данными через NCCL и NVLink.
Ограничения и практические рекомендации
- Баланс нагрузки: при шардировании параметров важно, чтобы размер каждой части был примерно одинаковым. Иначе один GPU может стать узким местом.
- Коммуникационные накладные расходы: в Stage 3 количество all‑gather операций возрастает, поэтому рекомендуется использовать высокоскоростные соединения (NVLink, InfiniBand).
- Отладка: из‑за распределённого характера ошибок отладка становится сложнее. Полезно включать
torch.distributed.debugи логировать размеры шардов. - Смешанная точность: сочетание ZeRO с FP16/BF16 часто приводит к дополнительному уменьшению потребления памяти без потери качества модели.
Fully Sharded Data Parallel (FSDP)
Fully Sharded Data Parallel (FSDP) — ещё один подход к масштабированию обучения, реализованный в PyTorch ≥ 1.12. В отличие от ZeRO, FSDP полностью шардирует модель, градиенты и состояния оптимизатора внутри единого API torch.distributed.fsdp.FullyShardedDataParallel.
Принцип работы FSDP
- Шардирование параметров происходит в момент обертывания слоя в
FSDP. Каждый процесс хранит только часть параметров, а остальные выгружаются в CPU или в системную память. - Промежуточные активации сохраняются в виде «шардированных» тензоров, а при обратном проходе происходит «re‑compute» (повторный форвард) для восстановления необходимых активаций без их хранения.
- Градиенты собираются через
all‑reduceтолько над теми шардированными фрагментами, которые находятся на текущем устройстве, что сокращает объём передаваемых данных.
Реализация FSDP в PyTorch
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{dist.get_rank()}")
model = MyLargeModel().to(device)
fsdp_model = FSDP(model,
mixed_precision=True,
sharding_strategy="FULL_SHARD",
cpu_offload=False)
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-4)
for batch in loader:
optimizer.zero_grad()
output = fsdp_model(batch)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
FSDP автоматически управляет жизненным циклом шардов, включая их выгрузку и загрузку, а также поддерживает «checkpointing» для снижения памяти, затрачиваемой на активации.
Интеграция и совместимость
- Модульность: FSDP можно применять к отдельным подмоделям, что удобно при работе с гибридными архитектурами (например, трансформер + CNN).
- Совместимость с DeepSpeed: в некоторых проектах комбинируют FSDP (для шардирования модели) и DeepSpeed ZeRO (для оптимизации состояний оптимизатора), получая гибкую схему распределения.
- Смешанная точность: включение
mixed_precision=Trueпереводит параметры и градиенты в FP16/ BF16, дополнительно экономя память.
Сравнение ZeZero и FSDP
| Аспект | ZeRO (DeepSpeed) | FSDP (PyTorch) |
|---|---|---|
| Уровень шардирования | 3 стадии, гибко настраиваемые | Полное шардирование (FULL_SHARD) |
| Поддержка оптимизаторов | Шардирует состояния всех популярных оптимизаторов | Шардирует только те, которые находятся в параметрах модели |
| API | Требует отдельный конфиг и обёртку DeepSpeedEngine | Прямой torch.nn.Module → FSDP |
| Коммуникация | NCCL + CPU offload, оптимизированные all‑reduce | NCCL all‑reduce + гибкий checkpointing |
| Память | Возможность выгрузки параметров/оптимизатора в CPU | Выгрузка параметров, активаций через ре‑compute |
| Экосистема | Поддержка ZeRO‑3, ZeRO‑Offload, DP‑Zero гибрид | Интегрировано в ядро PyTorch, поддержка TorchElastic |
Выбор между ZeRO и FSDP зависит от конкретных ограничений: если требуется максимальная гибкость в управлении состояниями оптимизатора и возможность offload на CPU, предпочтительнее DeepSpeed ZeRO. Если же нужен нативный PyTorch‑стек без внешних зависимостей и удобная работа с большими трансформерами, FSDP будет более естественным решением.
Выбор подхода для масштабирования
При планировании масштабирования моделей следует учитывать:
- Размер модели: выше 10 ГБ параметров — обе технологии помогают, но FSDP часто выигрывает за счёт более агрессивного шардирования.
- Аппаратная топология: кластеры с NVLink/InfiniBand лучше подходят под ZeRO‑3, где количество all‑gather операций критично.
- Скорость разработки: если проект уже использует DeepSpeed, переход на ZeRO будет менее затратным; для чистого PyTorch‑экосистемы FSDP проще интегрировать.
- Экономика: Offload в CPU (ZeRO) может позволить обучать на GPU‑меньшего объёма, но добавит задержку из‑за PCIe‑трафика.
Обе технологии продолжают развиваться, вводя новые уровни шардирования и улучшения коммуникаций. Их грамотное применение позволяет обучать модели, которые ранее были недоступны из‑за ограничений видеопамяти, открывая путь к более масштабным исследованиям в области искусственного интеллекта.