Skip to content

Commit

Permalink
Auto convert moe param groups (#5354)
Browse files Browse the repository at this point in the history
When using frameworks like HF Accelerate with MoE models in HF there's
an issue when DeepSpeed is creating the optimizer where we have no way
to automatically create the compatible MoE param groups. This PR detects
if no client optimizer is set and model_parameters are passed to
DeepSpeed that they are either MoE compatible or makes them MoE
compatible automatically.

This was never an issue previously since (1) MoE hasn't really been
tested outside MDS and (2) MDS manually converts the weight-decay param
groups into being MoE compatible before deepspeed.initialize.

The error that is triggered if the param groups are not MoE compatible
is triggered here:
https://github.com/microsoft/DeepSpeed/blob/cc897ecf15fdac5437fa4a2743154dc6c1749da4/deepspeed/runtime/zero/stage_1_and_2.py#L610-L612

Tagging @tohtana and @ykim362 to help review

---------

Co-authored-by: Jeff Rasley <jeff.rasley@snowflake.com>
  • Loading branch information
jeffra and sfc-gh-jrasley authored Apr 5, 2024
1 parent 731fd68 commit 42a8eaa
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
30 changes: 30 additions & 0 deletions deepspeed/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,33 @@ def split_params_into_different_moe_groups_for_optimizer(

def is_moe_param_group(param_group):
return param_group.get('moe', False)


def configure_moe_param_groups(model_parameters: List):
assert isinstance(model_parameters, list), "model_parameters must be a list"

for p in model_parameters:
# match torch.optim.Optimizer expectations,
# see: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/torch/optim/optimizer.py#L270-L272
if not isinstance(p, (torch.Tensor, dict)):
raise TypeError("param argument that would be given to the optimizer should be "
f"an iterable of Tensors or dicts, but got {type(p)}")

# peak at the first element to determine how to proceed
first = model_parameters[0]

# Case 1: model_parameters is a list of torch.nn.Parameter
# -> need to create moe compatible param groups
if isinstance(first, torch.nn.Parameter):
param_group = {'params': model_parameters, 'name': 'dense-params'}
return split_params_into_different_moe_groups_for_optimizer(param_group)

# Case 2: model_parameters is a list of param groups List[dict]
# -> moe compatible param groups might already exist, if not create them
elif isinstance(first, dict):
#there are no moe groups created
if not any(['moe' in param_group for param_group in model_parameters]):
return split_params_into_different_moe_groups_for_optimizer(model_parameters)
else:
# moe groups exist, nothing to do
return model_parameters
4 changes: 3 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
from ..moe.utils import is_moe_param
from ..moe.utils import is_moe_param, configure_moe_param_groups
from ..git_version_info import version

from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
Expand Down Expand Up @@ -1227,6 +1227,8 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is None:
if self.has_moe_layers:
model_parameters = configure_moe_param_groups(model_parameters)
basic_optimizer = self._configure_basic_optimizer(model_parameters)
log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@
from deepspeed.runtime.utils import required_torch_version


@pytest.mark.parametrize("zero_stage", [0, 1, 2])
class TestSimpleMoE(DistributedTest):
world_size = 2

def test(self, zero_stage):
if not required_torch_version(min_version=1.8):
pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
}
}
# should automatically create moe param groups in deepspeed backend
hidden_dim = 16
model = SimpleMoEModel(hidden_dim=hidden_dim, ep_size=1)
model, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model)
data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()


@pytest.mark.parametrize("ep_size", [2, 4])
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
@pytest.mark.parametrize("use_residual", [True, False])
Expand Down

0 comments on commit 42a8eaa

Please sign in to comment.