@@ -437,7 +437,11 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         # this op cannot be async, otherwise it completely breaks the outputs of models
-        torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
+        if isinstance(outputs, torch.Tensor):
+            torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM, async_op=False)
+        else:
+            # TODO: we assume we want to allreduce first element of tuple
+            torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)  # TODO: rename GatherParallel to ReduceParallel or something
         return outputs

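# Why `async_op=False` above matters: with async_op=True, all_reduce returns a
# Work handle and the tensor is only safe to read after wait(); since the hook
# returns `outputs` right away, the reduction has to be synchronous. Minimal
# single-process sketch (gloo backend; the port is an arbitrary choice):
import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
x = torch.ones(4)

dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=False)  # x is valid immediately

work = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
work.wait()  # x only holds the reduced values once wait() returns

dist.destroy_process_group()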
@@ -465,6 +469,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         if to_contiguous:
             param = param.contiguous()
         param = param / device_mesh.size()  # TODO should be optionable
+        # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
         return param

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
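# Numeric sketch of the convention noted above (hypothetical values): each rank
# stores bias / world_size, so the parent's SUM all-reduce reconstructs the
# original bias instead of multiplying it by the world size.
import torch

world_size = 4
bias = torch.tensor([2.0, 6.0])
local = bias / world_size                                # what IsolatedParallel keeps per rank
reduced = sum(local.clone() for _ in range(world_size))  # effect of the parent's SUM all-reduce
assert torch.allclose(reduced, bias)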
@@ -786,6 +791,66 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

+class GroupedGemmParallel(TensorParallelLayer):
+    """
+    Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
+    """
+    def __init__(self):
+        super().__init__()
+        self.use_dtensor = False
+
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        ep_rank = rank
+        global_num_experts = empty_param.shape[0]
+        if global_num_experts % device_mesh.size() != 0:
+            raise ValueError(f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0")
+        local_num_experts = global_num_experts // device_mesh.size()
+        param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
+        if to_contiguous:
+            param = param.contiguous()
+        return param
+
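# Shape-level sketch of the expert slice computed in partition_tensor above,
# using hypothetical sizes (8 global experts, EP world size 4, rank 1):
import torch

global_num_experts, ep_size, ep_rank = 8, 4, 1
local_num_experts = global_num_experts // ep_size    # -> 2 experts per rank
experts = torch.randn(global_num_experts, 16, 32)    # [num_experts, in, out]
local = experts[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts]
assert local.shape == (local_num_experts, 16, 32)    # only this rank's experts are materialized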
+class RouterParallel(TensorParallelLayer):
+    """
+    Applies Expert Parallelism to the MoE router.
+    """
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+        self.use_dtensor = False
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, DTensor):
+            raise NotImplementedError("RouterParallel does not support DTensor input for now")
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
+        num_local_experts = mod.num_experts // ep_size
+        router_scores, router_indices = outputs
+        router_scores = router_scores[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
+        return router_scores, router_indices
+
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # TODO: I'd like for this to be the default
+        param = param[...].to(param_casting_dtype)
+        if to_contiguous:
+            param = param.contiguous()
+        return param
+
+
+    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
+        # TODO: need an abstract Parallel class that is different from TensorParallelLayer
+        distribute_module(
+            module,
+            device_mesh,
+            partial(self._prepare_input_fn, None, None),
+            partial(self._prepare_output_fn, None, None),
+        )
+
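# Sketch of the output slicing in RouterParallel._prepare_output_fn above, with
# hypothetical sizes and layout (8 experts, EP world size 4, rank 1, and scores
# laid out with the expert dimension first; router_indices passes through unchanged):
import torch

num_experts, ep_size, ep_rank, num_tokens = 8, 4, 1, 5
num_local_experts = num_experts // ep_size
router_scores = torch.rand(num_experts, num_tokens)
router_indices = torch.randint(num_experts, (num_tokens, 2))   # e.g. top-2 routing
local_scores = router_scores[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
assert local_scores.shape == (num_local_experts, num_tokens)   # only this rank's experts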

 class ParallelInterface(GeneralInterface):
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
@@ -803,6 +868,8 @@ class ParallelInterface(GeneralInterface):
             "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
             "sequence_parallel": SequenceParallel(),
             "replicate": ReplicateParallel(),
+            "grouped_gemm": GroupedGemmParallel(),
+            "ep_router": RouterParallel(),
         }
         if is_torch_greater_or_equal("2.5") and _torch_distributed_available
         else {}
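# Illustrative use of the two keys registered above (not taken from this PR):
# a model's tp_plan maps module-name patterns to strategy names, so a MoE block
# could be annotated as below; the module paths are hypothetical.
base_model_tp_plan = {
    "layers.*.mlp.router": "ep_router",        # handled by RouterParallel
    "layers.*.mlp.experts": "grouped_gemm",    # handled by GroupedGemmParallel
}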
@@ -901,7 +968,7 @@ def __init__(self):

 def shard_and_distribute_module(
     model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
-):
+):  # TODO: rename to shard_and_distribute_param
     r"""
     Main uses cases:
     - column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
@@ -913,7 +980,7 @@ def shard_and_distribute_module(
     """
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan
-    module_to_tp = model.get_submodule(param_name)
+    module_to_tp = model.get_submodule(param_name)  # TODO: can i loop over modules?
     rank = int(rank)

     current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
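# The rsplit above separates the owning module path from the parameter leaf
# name, which is then resolved with get_submodule; the parameter name below is
# hypothetical and only illustrates the split:
parameter_name = "model.layers.0.mlp.experts.gate_up_proj"
param_name, param_type = parameter_name.rsplit(".", 1)
assert param_name == "model.layers.0.mlp.experts" and param_type == "gate_up_proj"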