@@ -150,7 +150,6 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
150150 "F64" : torch .float64 ,
151151 "I64" : torch .int64 ,
152152 "F8_E4M3" : torch .float8_e4m3fn ,
153- "F8_E5M2" : torch .float8_e5m2 ,
154153}
155154
156155
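The dropped `F8_E5M2` entry belonged to a safetensors dtype-string to torch dtype table. A hedged sketch of how such a table can be built defensively when float8 dtypes may be missing from older torch builds (the dictionary and function names below are illustrative, not this file's):

```python
# Illustrative only: a safetensors dtype-string -> torch dtype table with a
# guard for float8 dtypes, which only exist in recent torch releases.
import torch

SAFETENSORS_TO_TORCH_DTYPE = {
    "F32": torch.float32,
    "F64": torch.float64,
    "I64": torch.int64,
}
if hasattr(torch, "float8_e4m3fn"):
    SAFETENSORS_TO_TORCH_DTYPE["F8_E4M3"] = torch.float8_e4m3fn
if hasattr(torch, "float8_e5m2"):
    SAFETENSORS_TO_TORCH_DTYPE["F8_E5M2"] = torch.float8_e5m2


def dtype_from_safetensors(key: str) -> torch.dtype:
    # Raises KeyError for dtypes this torch build cannot represent.
    return SAFETENSORS_TO_TORCH_DTYPE[key]
```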
@@ -526,43 +525,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         return param


-class ReduceFromModelParallelRegion(torch.autograd.Function):
-    """
-    All-reduce in forward pass, identity in backward pass.
-    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        if device_mesh.size() == 1:
-            return x
-        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output
-
-
-class CopyToModelParallelRegion(torch.autograd.Function):
-    """
-    Copy in forward pass, all-reduce in backward pass.
-    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        ctx.device_mesh = device_mesh
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.device_mesh.size() == 1:
-            return grad_output
-        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
-        return grad_output
-
-
 class ColwiseParallel(TensorParallelLayer):
     """
     General tensor parallel layer for transformers.
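For context on what the deleted block did: `ReduceFromModelParallelRegion` and `CopyToModelParallelRegion` are the `g` and `f` collectives from Megatron-LM (https://arxiv.org/abs/1909.08053). `f` is an identity forward with an all-reduce backward at the entry of a column-parallel block; `g` is the mirror image at the exit of the row-parallel block. A minimal sketch of how such functions bracket a two-layer MLP (assumed helper names, not code from this file; requires an initialized process group, e.g. via torchrun):

```python
import torch
import torch.distributed as dist


class _Copy(torch.autograd.Function):     # `f`: identity forward, all-reduce backward
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad


class _Reduce(torch.autograd.Function):   # `g`: all-reduce forward, identity backward
    @staticmethod
    def forward(ctx, x):
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad


def tp_mlp(x, w1_shard, w2_shard):
    # w1_shard: (hidden/tp, in) column shard; w2_shard: (out, hidden/tp) row shard.
    x = _Copy.apply(x)              # replicate activations into the TP region
    h = torch.relu(x @ w1_shard.T)  # column-parallel: each rank owns a hidden-dim slice
    y = h @ w2_shard.T              # row-parallel: per-rank partial sums
    return _Reduce.apply(y)         # sum partials -> identical full output on every rank
```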
@@ -585,8 +547,15 @@ def __init__(

     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
         # annotate module input placements/sharding with input_layouts
         input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        # transform the input layouts to the desired layouts of ColwiseParallel
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
         return input_tensor

     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@@ -595,19 +564,41 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         # weight would become Shard(1)
         if param_type == "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
         else:
+            shard = [Shard(-2)]
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)

         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = CopyToModelParallelRegion.apply(outputs, device_mesh)
-        return outputs
+        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
+        # back to local tensor
+        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
+
+
+class PackedColwiseParallel(ColwiseParallel):
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
+        parameter = parameter.to(param_casting_dtype)
+        if to_contiguous:
+            parameter = parameter.contiguous()
+        if self.use_dtensor:
+            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


 class RowwiseParallel(TensorParallelLayer):
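The replacement path above builds on DTensor: local shards are wrapped with `DTensor.from_local` and layout changes go through `redistribute`. A hedged, standalone sketch of those two primitives for a column-sharded linear weight (illustrative names only; `torch.distributed.tensor` is the public namespace on recent PyTorch, `torch.distributed._tensor` on older releases; run under torchrun with a world size that divides 8):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
tp = dist.get_world_size()

# Each rank holds one slice of an (out_features=8, in_features=4) weight along the
# output dimension; for a 2-D weight this is the Shard(-2) layout used above,
# written here as Shard(0).
local_weight = torch.randn(8 // tp, 4)
weight = DTensor.from_local(local_weight, mesh, [Shard(0)], run_check=False)

# Replicated activations times a colwise-sharded weight give an output sharded
# on the last (feature) dimension, without any communication.
x = DTensor.from_local(torch.randn(2, 4), mesh, [Replicate()], run_check=False)
y = F.linear(x, weight)
print(dist.get_rank(), y.to_local().shape)       # (2, 8 // tp)

# redistribute() inserts the collective needed to change layout (all-gather here).
y_full = y.redistribute(placements=[Replicate()])
print(dist.get_rank(), y_full.to_local().shape)  # (2, 8)
```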
@@ -644,15 +635,23 @@ def __init__(
         self.use_dtensor = use_dtensor

     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        if param_type == "bias":
-            parameter = param[:]
-        else:
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        if param_type != "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
+        else:
+            shard = [Replicate()]
+            parameter = param[:]

         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
@@ -662,13 +661,24 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
             mod.bias = None

         input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
         return input_tensor

     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh)
+        # Rowwise sharding produces partial output, depending on output layouts:
+        # 1. to replicate -> allreduce
+        # 2. to shard -> reduce_scatter
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        outputs = outputs.to_local()  # otherwise the `+=` op will gather
         if hasattr(mod, "_bias"):
             outputs += mod._bias
+        # back to local tensor if use_local_output is True
         return outputs

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
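The comment in the hunk above ("to replicate -> allreduce, to shard -> reduce_scatter") maps directly onto DTensor's `Partial()` placement: a row-parallel matmul leaves per-rank partial sums, and `redistribute` lowers the layout change to the matching collective. A hedged sketch of that behaviour (illustrative, not this file's API; run under torchrun with a world size that divides 4):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
ws = dist.get_world_size()

# Each rank contributes a partial sum, like the per-rank result of a
# row-parallel matmul before any communication.
partial = DTensor.from_local(torch.ones(4, 6), mesh, [Partial()], run_check=False)

# 1. to replicate -> all-reduce: every rank ends up with the summed tensor.
replicated = partial.redistribute(placements=[Replicate()])
assert torch.equal(replicated.to_local(), torch.full((4, 6), float(ws)))

# 2. to shard -> reduce-scatter: each rank keeps one summed slice.
sharded = partial.redistribute(placements=[Shard(0)])
print(dist.get_rank(), sharded.to_local().shape)  # (4 // ws, 6)
```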
@@ -694,21 +704,6 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         )


-class PackedColwiseParallel(ColwiseParallel):
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # NOTE(3outeille): need to be deprecated as no longer using dtensors
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
-
 class PackedRowwiseParallel(RowwiseParallel):
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
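For reference, upstream PyTorch ships the same colwise/rowwise pairing in `torch.distributed.tensor.parallel`. The sketch below uses that public API (PyTorch's own `ColwiseParallel`/`RowwiseParallel`, not the classes in this file) to parallelize a toy MLP; it is typically launched with torchrun and shown here on CPU/gloo purely for illustration:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

mlp = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
# Colwise on the first projection, rowwise on the second: the activation stays
# sharded on the hidden dimension in between, and a single all-reduce happens
# at the rowwise output.
parallelize_module(mlp, mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})

out = mlp(torch.randn(2, 16))  # plain tensor in, plain (replicated) tensor out
print(dist.get_rank(), out.shape)
```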