@@ -150,6 +150,7 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     "F64": torch.float64,
     "I64": torch.int64,
     "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
 }
 
 
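The new entry adds the second 8-bit float format to the dtype-string mapping. As a quick reference, a small sketch (assuming PyTorch 2.1 or newer, where both float8 variants exist) of how the two formats differ:

import torch

# float8_e4m3fn: more mantissa bits (precision), smaller range
# float8_e5m2: more exponent bits (range), less precision
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)
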
@@ -525,6 +526,43 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         return param
 
 
+class ReduceFromModelParallelRegion(torch.autograd.Function):
+    """
+    All-reduce in the forward pass, identity in the backward pass.
+    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
+    """
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        if device_mesh.size() == 1:
+            return x
+        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # identity; None for the non-tensor device_mesh argument
+
+
+class CopyToModelParallelRegion(torch.autograd.Function):
+    """
+    Copy (identity) in the forward pass, all-reduce in the backward pass.
+    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
+    """
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        ctx.device_mesh = device_mesh
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.device_mesh.size() == 1:
+            return grad_output, None
+        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
+        return grad_output, None
+
+
 class ColwiseParallel(TensorParallelLayer):
     """
     General tensor parallel layer for transformers.
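A minimal usage sketch of how the two regions pair up in the Megatron-style column-parallel then row-parallel MLP described in the cited paper. This is illustrative only, not part of the diff: it assumes a 1-D CPU mesh launched with torchrun, and the shapes and variable names are made up.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# run with: torchrun --nproc_per_node=2 sketch.py
dist.init_process_group("gloo")
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

hidden, intermediate = 8, 16
w_col = torch.randn(intermediate // world_size, hidden, requires_grad=True)  # colwise shard, dim -2
w_row = torch.randn(hidden, intermediate // world_size, requires_grad=True)  # rowwise shard, dim -1

x = torch.randn(4, hidden, requires_grad=True)    # replicated input
h = CopyToModelParallelRegion.apply(x, mesh)      # `f`: identity now, all-reduce of the grad in backward
h = h @ w_col.t()                                 # each rank computes its slice of the intermediate
y = h @ w_row.t()                                 # partial sum of the output on each rank
y = ReduceFromModelParallelRegion.apply(y, mesh)  # `g`: all-reduce now, identity in backward
y.sum().backward()                                # x.grad is summed across ranks; weight grads stay local
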
@@ -547,15 +585,8 @@ def __init__(
 
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        # TODO: figure out dynamo support for instance method and switch this to instance method
         # annotate module input placements/sharding with input_layouts
         input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        # transform the input layouts to the desired layouts of ColwiseParallel
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
         return input_tensor
 
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@@ -564,41 +595,19 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         # weight would become Shard(1)
         if param_type == "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
-            shard = [Shard(-1)]
         else:
-            shard = [Shard(-2)]
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
 
         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
-            )
+
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
-        # back to local tensor
-        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
-
-
-class PackedColwiseParallel(ColwiseParallel):
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
+        outputs = CopyToModelParallelRegion.apply(outputs, device_mesh)
+        return outputs
 
 
 class RowwiseParallel(TensorParallelLayer):
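To make the colwise convention concrete: sharding the weight along dim -2 gives each rank a slice of the output features, so the local linear outputs are column blocks of the full result and no communication is needed in the forward pass. A plain-tensor sketch of that identity (no process group; world_size and the shapes are illustrative, and torch.chunk stands in for get_tensor_shard):

import torch

world_size, hidden, out_features = 2, 8, 6
weight = torch.randn(out_features, hidden)  # full, unsharded weight
x = torch.randn(4, hidden)                  # replicated input

# colwise: slice the weight along dim -2 (output features), one slice per rank
w_shards = weight.chunk(world_size, dim=-2)
local_outputs = [x @ w.t() for w in w_shards]  # what each rank computes locally

# concatenating the local outputs on the last dim recovers the full linear output
assert torch.allclose(torch.cat(local_outputs, dim=-1), x @ weight.t(), atol=1e-6)
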
@@ -635,23 +644,15 @@ def __init__(
         self.use_dtensor = use_dtensor
 
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
-        # means Rowwise as nn.Linear is input * weight^T + bias, where
-        # weight would become Shard(0)
-        if param_type != "bias":
-            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
-            shard = [Shard(-1)]
-        else:
-            shard = [Replicate()]
+        if param_type == "bias":
             parameter = param[:]
+        else:
+            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
 
         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
-            )
+
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
@@ -661,24 +662,14 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
             mod.bias = None
 
         input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
         return input_tensor
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # Rowwise sharding produces partial output, depending on output layouts:
-        # 1. to replicate -> allreduce
-        # 2. to shard -> reduce_scatter
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh)  # rowwise output is a partial sum; all-reduce it
         if hasattr(mod, "_bias"):
             outputs += mod._bias
-        # back to local tensor if use_local_output is True
-        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
+        return outputs
 
     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         module._distribute_module_applied = True
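The ReduceFromModelParallelRegion call above is what turns the per-rank partial sums into the final output: with the rowwise weight sharded along dim -1 (input features) and the activations sharded the same way, each local matmul is a partial sum, the sum over ranks (the all-reduce) matches the unsharded computation, and the bias stashed on mod._bias is then added once. A plain-tensor sketch of that identity (shapes are illustrative; the Python sum stands in for dist.all_reduce):

import torch

world_size, in_features, out_features = 2, 8, 6
weight = torch.randn(out_features, in_features)  # full rowwise weight
bias = torch.randn(out_features)                 # bias stays replicated
x = torch.randn(4, in_features)                  # activations sharded on the last dim

w_shards = weight.chunk(world_size, dim=-1)
x_shards = x.chunk(world_size, dim=-1)

# each rank produces a partial sum; summing them plays the role of the all-reduce
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
reduced = sum(partials) + bias

assert torch.allclose(reduced, x @ weight.t() + bias, atol=1e-5)
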
@@ -703,6 +694,21 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         )
 
 
+class PackedColwiseParallel(ColwiseParallel):
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # NOTE(3outeille): to be deprecated, as DTensors are no longer used
+        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
+        parameter = parameter.to(param_casting_dtype)
+        if to_contiguous:
+            parameter = parameter.contiguous()
+        if self.use_dtensor:
+            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
+
+
 class PackedRowwiseParallel(RowwiseParallel):
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
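get_packed_weights exists because fused projections (for example a gate/up projection stored as one [2 * intermediate, hidden] tensor) cannot be chunked naively: a plain chunk along dim -2 would hand one rank only the gate block and another only the up block. A hedged sketch of the idea, not the actual implementation of get_packed_weights (the helper name and shapes are made up):

import torch

world_size, hidden, intermediate = 2, 8, 6
packed = torch.randn(2 * intermediate, hidden)  # [gate; up] fused along dim -2

def shard_packed_colwise(packed_weight, rank, world_size, num_blocks=2):
    # hypothetical helper: split the fused weight into its blocks, take this
    # rank's slice of each block, then re-pack so gate and up stay aligned
    blocks = packed_weight.chunk(num_blocks, dim=-2)
    local = [block.chunk(world_size, dim=-2)[rank] for block in blocks]
    return torch.cat(local, dim=-2)

for rank in range(world_size):
    local = shard_packed_colwise(packed, rank, world_size)
    # every rank ends up with a slice of both the gate and the up block
    assert local.shape == (2 * intermediate // world_size, hidden)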