In veScale, we provide two optimizers for Optimizer Parallel:

- `DistributedOptimizer`
- `BasicOptimizer`
`DistributedOptimizer` is a ZeRO 2+ optimizer. Similar to the original ZeRO2, it parallelizes model gradients and optimizer states along the Data Parallel dimension. Unlike ZeRO2, it further parallelizes model parameters virtually but not physically.

`DistributedOptimizer` is primarily inherited from Megatron-LM's DistributedOptimizer for its performance, and mostly due to the lack of a ZeRO2 optimizer in native PyTorch. We extend and enhance `DistributedOptimizer` with extra features:
- convert between `Tensor` and `DTensor`
- support online resharding of optimizer state
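Online resharding means that optimizer state saved under one DP world size can be re-split for a different DP world size at load time. Below is a minimal numeric sketch of the idea in plain `torch`; it is an illustration only, not the veScale resharding API.

```python
import torch

# Illustration only: Adam's `exp_avg` for one parameter, stored as a flat
# buffer that was sharded across DP=2 ranks at save time.
shard_rank0 = torch.arange(0.0, 4.0)   # rank 0 owned elements [0, 4)
shard_rank1 = torch.arange(4.0, 8.0)   # rank 1 owned elements [4, 8)

# To resume training with DP=4, the flat state is logically reassembled
# and re-split into one shard per new DP rank.
full_state = torch.cat([shard_rank0, shard_rank1])
new_shards = full_state.chunk(4)
print([s.tolist() for s in new_shards])
# [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
```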
In `DistributedOptimizer`, the model gradients and optimizer states are sharded along the Data Parallel dimension within each gradient Bucket of the Gradient Buffer (see `DDP` for more details), where each DP rank only manages its own shard of gradients, generates its own shard of optimizer states, and updates its own shard of parameters.
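As a rough illustration of this sharding (plain `torch`, not the veScale implementation): each DP rank takes one contiguous slice of the flat Gradient Buffer and allocates optimizer state only for that slice.

```python
import torch

dp_world_size = 4
# Flat Gradient Buffer holding the gradients of two parameters back to back,
# e.g., a 6-element weight followed by a 2-element bias (8 elements total).
grad_buffer = torch.randn(8)

# Each DP rank manages one contiguous shard of the buffer; note that the
# 6-element weight is split across ranks 0-2 regardless of its boundary.
shard_size = grad_buffer.numel() // dp_world_size
rank = 1  # this rank's shard covers elements [2, 4)
my_grad_shard = grad_buffer[rank * shard_size : (rank + 1) * shard_size]

# Optimizer state (e.g., Adam moments) is allocated only for this shard.
my_exp_avg = torch.zeros_like(my_grad_shard)
my_exp_avg_sq = torch.zeros_like(my_grad_shard)
```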
The flow of `DistributedOptimizer` is as follows:
- During initialization, model parameters are virtually sharded across all DP ranks, such that each DP rank owns a partial view of the original model parameters.
  - This sharding does not respect parameter boundaries, i.e., a parameter can be split into two halves that belong to two DP ranks. Therefore, a complex mapping between the Sharded Parameters and the original parameters is established, which is mostly done in the `__init__` function. Then the optimizer's `param_groups` is replaced with the Sharded Parameters.
- Receive the Reduced Gradient resulting from `ReduceScatter` per Gradient Bucket in `DDP`.
- Attach the Reduced Gradient (`main_grad` of each original parameter) to the Sharded Parameter.
- Run the actual `optimizer.step()` to generate the Optimizer State of each shard and update the Sharded Parameter with the Reduced Gradient.
- Copy the updated Sharded Parameter into a dedicated parameter buffer, ready for the `AllGather` communication that restores the full parameters (see the sketch after this list).
  - To avoid the performance overhead and memory cost of a per-parameter `AllGather`, the Gradient Buffer of `DDP` is reused as the communication buffer for the `AllGather`.
- Overlap the parameter `AllGather` with the forward computation of the next iteration to hide communication overhead, similar to how the gradient `ReduceScatter` overlaps with backward computation.
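The sketch below simulates the copy-and-`AllGather` step on a single process (plain `torch`, no real communication): every rank writes its updated shard into the reused flat buffer, and the `AllGather` makes all shards visible everywhere.

```python
import torch

# Single-process simulation of the parameter AllGather (illustration only).
dp_world_size = 4
flat_numel = 8
shard_size = flat_numel // dp_world_size

# After optimizer.step(), each "rank" holds its updated parameter shard ...
updated_shards = [torch.full((shard_size,), float(r)) for r in range(dp_world_size)]

# ... and copies it into its slice of the reused Gradient Buffer; the
# AllGather (simulated by this loop) then fills the whole buffer on every rank.
param_buffer = torch.empty(flat_numel)
for r in range(dp_world_size):
    param_buffer[r * shard_size : (r + 1) * shard_size] = updated_shards[r]

print(param_buffer)  # full updated flat parameters, visible on every rank
```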
An end-to-end example of using `DistributedOptimizer` with `DDP`:

```python
import torch
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.dmodule.api import parallelize_module
from vescale.dtensor.device_mesh import DeviceMesh

# create a torch-native model
mlp = MLP()

# create a 2-dim DeviceMesh: the first dim for data parallel, the second for tensor parallel
device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP"))

# parallelize the torch-native model into a TP model
tp_mlp = parallelize_module(mlp, device_mesh["TP"], sharding_plan)

# wrap the TP model with `DDP`
dp_tp_mlp = DDP(
    tp_mlp,
    device_mesh["DP"],
    use_distributed_optimizer=True,
)

# create DistributedOptimizer
doptim = DistributedOptimizer(
    # choose the core optimizer class
    torch.optim.Adam,
    # feed the model(s)
    models=[dp_tp_mlp],
    # choose whether to overlap the param all-gather with the next forward for speedup
    overlap_param_gather=True,  # or False
    # feed the core optimizer kwargs
    optimizer_kwargs={"lr": 0.01},
)

# train the current iteration
dp_tp_mlp(torch.rand(...)).sum().backward()
# reduce-scatter the gradients across the DP world
dp_tp_mlp.finish_grad_sync()
# update the model
doptim.step()
# prepare for the next iteration
doptim.zero_grad()
# <repeat above>
```
APIs can be found in: `<repo>/vescale/optim/distributed_optimizer.py`.

More examples can be found in: `<repo>/test/parallel/ddp_optim/test_doptimizer.py`.
`BasicOptimizer` is not a ZeRO optimizer but a simple optimizer that works like Data Parallel, i.e., it replicates parameters, gradients, and optimizer states along the Data Parallel dimension.

`BasicOptimizer` itself is nothing but a thin wrapper that wraps a given optimizer instance with utilities for veScale `DTensor`, `DModule`, and `DDP`:
- convert between `Tensor` and `DTensor`
- recover flattened gradients from `DDP`
- trigger gradient synchronization of `DModule` (e.g., for Sequence Parallel)
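A minimal usage sketch is given below. It assumes `BasicOptimizer` wraps an already-constructed core optimizer together with the veScale model(s), mirroring the `DistributedOptimizer` example above; see base_optimizer.py and test_ddp.py for the exact signature and full examples.

```python
import torch
from vescale.optim.base_optimizer import BasicOptimizer

# `tp_mlp` is the DModule produced by `parallelize_module` as in the example above.
# Assumed constructor: a core optimizer instance plus the model(s) it manages,
# so that Tensor/DTensor conversion and DModule gradient sync are handled.
core_optim = torch.optim.Adam(tp_mlp.parameters(), lr=0.01)
boptim = BasicOptimizer(core_optim, models=tp_mlp)

# training loop without DDP: parameters, gradients, and optimizer states
# stay replicated along the Data Parallel dimension (if any)
tp_mlp(torch.rand(...)).sum().backward()
boptim.step()
boptim.zero_grad()
```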
APIs can be found in: `<repo>/vescale/optim/base_optimizer.py`.

Examples can be found in: `<repo>/test/parallel/ddp_optim/test_ddp.py`.
The compatibility of the above optimizers with `DDP` is as follows:

|          | `BasicOptimizer` | `DistributedOptimizer` |
|----------|------------------|------------------------|
| `DDP`    | yes              | yes                    |
| no `DDP` | yes              | no                     |