Commit 865405e

Merge pull request #4 from huggingface/tp-fixes
Expert Parallelism + TP fixes
2 parents 581f912 + a904696 commit 865405e

File tree: 7 files changed, +369 -58 lines changed

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import _LazyModule


_import_structure = {
    "configuration_utils": [
        "DistributedConfig",
    ],
}


if TYPE_CHECKING:
    from .configuration_utils import (
        DistributedConfig,
    )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Union


@dataclass
class DistributedConfig:
    """
    Base class for distributed configs.
    """

    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc.

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a DistributedConfig instance from a dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values.

        Returns:
            `DistributedConfig`: Instance of DistributedConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        return config

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """Allows `dict(obj)` for situations where `obj` may be a dict or a DistributedConfig."""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            `str`: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Return the attributes that were not used to update the instance, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
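A minimal usage sketch of the new config class (assuming the package is exposed as `transformers.distributed`, as the lazy-module `__init__` above suggests; the values are illustrative):

from transformers.distributed import DistributedConfig  # assumed import path

# Build a config, round-trip it through a plain dict, then update it.
cfg = DistributedConfig(enable_expert_parallel=True)
print(cfg.to_json_string())  # JSON containing "enable_expert_parallel": true

restored = DistributedConfig.from_dict(cfg.to_dict())
assert restored.enable_expert_parallel is True

# `update` only consumes kwargs that match existing attributes and returns the rest.
unused = restored.update(enable_expert_parallel=False, not_a_field=1)
assert unused == {"not_a_field": 1}
assert restored.enable_expert_parallel is False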

src/transformers/integrations/tensor_parallel.py

Lines changed: 70 additions & 3 deletions
@@ -437,7 +437,11 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         # this op cannot be async, otherwise it completely breaks the outputs of models
-        torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
+        if isinstance(outputs, torch.Tensor):
+            torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM, async_op=False)
+        else:
+            # TODO: we assume we want to allreduce first element of tuple
+            torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)  # TODO: rename GatherParallel to ReduceParallel or something
         return outputs

@@ -465,6 +469,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         if to_contiguous:
             param = param.contiguous()
         param = param / device_mesh.size()  # TODO should be optionable
+        # TODO: assumes parent module will allreduce the output afterwards (e.g. rowlinear bias is IsolatedParallel and parent module is GatherParallel)
         return param

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
@@ -786,6 +791,66 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

+
+class GroupedGemmParallel(TensorParallelLayer):
+    """
+    Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.use_dtensor = False
+
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        ep_rank = rank
+        global_num_experts = empty_param.shape[0]
+        if global_num_experts % device_mesh.size() != 0:
+            raise ValueError(f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0")
+        local_num_experts = global_num_experts // device_mesh.size()
+        param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
+        if to_contiguous:
+            param = param.contiguous()
+        return param
+
+
+class RouterParallel(TensorParallelLayer):
+    """
+    Applies Expert Parallelism to the MoE router.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+        self.use_dtensor = False
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, DTensor):
+            raise NotImplementedError("RouterParallel does not support DTensor input for now")
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
+        num_local_experts = mod.num_experts // ep_size
+        router_scores, router_indices = outputs
+        router_scores = router_scores[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
+        return router_scores, router_indices
+
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # TODO: i'd like for this to be the default
+        param = param[...].to(param_casting_dtype)
+        if to_contiguous:
+            param = param.contiguous()
+        return param
+
+    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
+        # TODO: need an abstract Parallel class that is different from TensorParallelLayer
+        distribute_module(
+            module,
+            device_mesh,
+            partial(self._prepare_input_fn, None, None),
+            partial(self._prepare_output_fn, None, None),
+        )

 class ParallelInterface(GeneralInterface):
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
@@ -803,6 +868,8 @@ class ParallelInterface(GeneralInterface):
             "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
             "sequence_parallel": SequenceParallel(),
             "replicate": ReplicateParallel(),
+            "grouped_gemm": GroupedGemmParallel(),
+            "ep_router": RouterParallel(),
         }
         if is_torch_greater_or_equal("2.5") and _torch_distributed_available
         else {}
@@ -901,7 +968,7 @@ def __init__(self):
 
 def shard_and_distribute_module(
     model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
-):
+):  # TODO: rename to shard_and_distribute_param
     r"""
     Main uses cases:
     - column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
@@ -913,7 +980,7 @@ def shard_and_distribute_module(
     """
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan
-    module_to_tp = model.get_submodule(param_name)
+    module_to_tp = model.get_submodule(param_name)  # TODO: can i loop over modules?
     rank = int(rank)

     current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
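For intuition, a self-contained sketch of the sharding the two new styles perform, using toy tensors instead of a real device mesh (the shapes, `ep_size`, and the score layout are illustrative assumptions, not taken from the model):

import torch

# Toy sizes: 8 experts split across 4 EP ranks -> 2 local experts per rank.
global_num_experts, ep_size, seq_len = 8, 4, 5
gate_up_proj = torch.randn(global_num_experts, 16, 32)     # expert dimension first, as in the diff
router_scores = torch.randn(global_num_experts, seq_len)   # per-expert scores, sliced by RouterParallel

assert global_num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
local_num_experts = global_num_experts // ep_size

for ep_rank in range(ep_size):
    # GroupedGemmParallel.partition_tensor: each rank keeps a contiguous block of experts along dim 0.
    local_weights = gate_up_proj[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts]
    # RouterParallel._prepare_output_fn: keep only the scores of this rank's experts.
    local_scores = router_scores[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts]
    print(ep_rank, tuple(local_weights.shape), tuple(local_scores.shape))  # (2, 16, 32) and (2, 5)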

src/transformers/modeling_utils.py

Lines changed: 28 additions & 1 deletion
@@ -52,6 +52,7 @@
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig
+from .distributed import DistributedConfig
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.accelerate import find_tied_parameters, init_empty_weights
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
@@ -774,7 +775,7 @@ def _load_state_dict_into_meta_model(
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)

     for param_name, empty_param in state_dict.items():
-        if param_name not in expected_keys:
+        if param_name not in expected_keys:  # when loading from a checkpoint, skip params that don't exist in the modeling code
             continue

         # we need to use serialized_param_name as file pointer is untouched
@@ -2149,6 +2150,23 @@ def post_init(self):
                     raise ValueError(
                         f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
                     )
+
+        if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
+            # Loop over named modules and attach TP hooks. This is necessary for modules that have no
+            # parameters and therefore never go through the parameter-sharding path that normally attaches hooks.
+            device_mesh = self.config.device_mesh
+            for name, module in self.named_modules():
+                if not getattr(module, "_is_hooked", False):
+                    from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module
+
+                    add_tensor_parallel_hooks_to_module(
+                        model=self,
+                        module=module,
+                        tp_plan=self._tp_plan,
+                        layer_name="",  # TODO: make this optional?
+                        current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan),
+                        device_mesh=device_mesh,
+                        parameter_name=None,
+                    )
+                    module._is_hooked = True

     def dequantize(self):
         """
@@ -4445,6 +4463,7 @@ def from_pretrained(
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
         tp_size = kwargs.pop("tp_size", None)
+        distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         use_kernels = kwargs.pop("use_kernels", False)
@@ -4808,6 +4827,14 @@ def from_pretrained(
             device_map=device_map,
         )

+        if distributed_config is not None and distributed_config.enable_expert_parallel:
+            # TODO: add proper support for ep_plan independently of tp_plan
+            if config.base_model_ep_plan is None:
+                raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True")
+            config.base_model_tp_plan = config.base_model_ep_plan  # TODO: hack for now
+
+        config.device_mesh = device_mesh  # Used in post_init
+
         with ContextManagers(model_init_context):
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)
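Taken together, a hedged sketch of how these pieces might be driven from a launcher script (the checkpoint name, mesh shape, and torchrun launch are illustrative assumptions, not part of this diff):

import torch
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM
from transformers.distributed import DistributedConfig  # assumed import path for the new config

# Assumed launch: torchrun --nproc-per-node 4 run_ep.py
# One 1-D mesh over the 4 ranks; this commit reuses the TP machinery for EP, so the experts
# end up split across these devices via the "grouped_gemm" / "ep_router" styles registered above.
device_mesh = init_device_mesh("cuda", (4,))

model = AutoModelForCausalLM.from_pretrained(
    "org/some-moe-checkpoint",  # placeholder model id
    tp_plan="auto",
    device_mesh=device_mesh,  # popped from kwargs and stored on the config for post_init
    distributed_config=DistributedConfig(enable_expert_parallel=True),
    torch_dtype=torch.bfloat16,
)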

src/transformers/models/openai_moe/configuration_openai_moe.py

Lines changed: 20 additions & 4 deletions
@@ -34,22 +34,38 @@ class OpenAIMoeConfig(PretrainedConfig):
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.self_attn.sinks": "local_rowwise",
+
         "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise",
-        "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise",
+        "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise",
         "layers.*.mlp.experts.down_proj": "local_colwise",
-        "layers.*.mlp.experts.down_proj_bias": "local",
-        "layers.*.mlp.experts": "gather",
+        "layers.*.mlp.experts.down_proj_bias": "local",  # TODO: add smthg that says bias exists only once for all TPs
+        "layers.*.mlp.experts": "gather",  # TODO: same, this should mean i want to allreduce output
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
         "norm": (["hidden_states"], ["hidden_states"]),
     }
+    base_model_ep_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.sinks": "local_rowwise",
+
+        # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them
+        "layers.*.mlp.token_dispatcher": "gather",
+        "layers.*.mlp.router": "ep_router",
+        "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+        "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
+        "layers.*.mlp.experts.down_proj": "grouped_gemm",
+        "layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
+    }

     def __init__(
         self,
         num_hidden_layers: int = 36,
-        num_local_experts: int = 128,
+        num_local_experts: int = 128,  # TODO: rename to num_experts otherwise confusing with EP
         vocab_size: int = 201088,
         hidden_size: int = 2880,
         intermediate_size: int = 2880,
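The `*` in these plan keys is a layer-index wildcard resolved when parameter names are matched against the plan. A toy illustration of that kind of matching follows (an illustrative stand-in, not the actual `_get_parameter_tp_plan` implementation):

import re
from typing import Optional

ep_plan = {
    "layers.*.mlp.router": "ep_router",
    "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
}

def match_plan(param_name: str, plan: dict) -> Optional[str]:
    # Treat "*" as "any single dotted segment", e.g. a layer index.
    for pattern, style in plan.items():
        regex = "^" + re.escape(pattern).replace(r"\*", r"[^.]+") + "$"
        if re.match(regex, param_name):
            return style
    return None

print(match_plan("layers.7.mlp.router", ep_plan))                # ep_router
print(match_plan("layers.7.mlp.experts.gate_up_proj", ep_plan))  # grouped_gemm
print(match_plan("layers.7.self_attn.q_proj", ep_plan))          # None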
