@@ -588,23 +588,33 @@ def forward(self, *args):
 
 class AutoParallelPP(AutoParallel):
     def apply_placement_pp(
-        self, sharding_placement=None, generate_di_dw_split_graphs=False
+        self, sharding_placement=None, graph_passes: list[str] = []
     ) -> dict[str, Any]:
+        assert all(
+            g_pass in ["split_fsdp_collectives", "split_dI_dW"]
+            for g_pass in graph_passes
+        ), "Only split_fsdp_collectives and split_dI_dW are supported"
         sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
             sharding_placement
         )
+        num_params = len(sharded_param_dict)
+        num_buffers = len(sharded_buffer_dict)
         (
             fw_module,
             bw_module,
+            num_params_buffers,
             num_user_outputs,
             num_mutate_inputs,
             num_fw_outs_saved_for_bw,
             num_symints_saved_for_bw,
             _indices_of_inps_to_detach,
             adjusted_flat_args,
         ) = partition_joint_with_descriptors(self.joint_with_descriptors)
-
+        assert num_params_buffers == (
+            num_params + num_buffers
+        ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
         print(
+            f"num_params_buffers: {num_params_buffers}\n"
             f"num_user_outputs: {num_user_outputs}\n"
             f"num_mutate_inputs: {num_mutate_inputs}\n"
             f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
@@ -631,14 +641,71 @@ def apply_placement_pp(
                 print_output=False, include_stride=True, include_device=True
             ),
         )
-        if generate_di_dw_split_graphs:
-            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+        unshard_module: Optional[torch.fx.GraphModule] = None
+        reduce_grad_module: Optional[torch.fx.GraphModule] = None
+        if "split_fsdp_collectives" in graph_passes:
+            assert (
+                not self.reshard_after_forward
+            ), "reshard_after_forward should be False to disable FSDP all_gather in the backward pass"
+            from autoparallel._passes.split_fsdp_collectives import (
+                split_fsdp_prefetch,
+                split_fsdp_reduce_scatters_epilogue,
+            )
 
-            num_weight_gradients = (
-                self.joint_with_descriptors._aot_state.aot_config.num_params_buffers
+            unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params)
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_unshard_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: unshard_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
             )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_fwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: fw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue(
+                bw_module, num_params
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_bwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: bw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_reduce_grad_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: reduce_grad_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+
+        bw_dI_module: Optional[torch.fx.GraphModule] = None
+        bw_dW_module: Optional[torch.fx.GraphModule] = None
+        num_input_grads = 0
+        if "split_dI_dW" in graph_passes:
+            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+
             bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
-                bw_module, num_weight_gradients=num_weight_gradients
+                bw_module,
+                num_weight_gradients=num_params_buffers,
             )
             trace_structured(
                 "artifact",
@@ -669,24 +736,23 @@ def apply_placement_pp(
                 raise RuntimeError(
                     "attempted to run split dI/dW pass on a graph that has no input gradients"
                 )
-        else:
-            bw_dI_module, bw_dW_module, num_input_grads = None, None, -1
 
         graph_meta: dict[str, int] = {
             "num_mutate_inputs": num_mutate_inputs,
             "num_user_outputs": num_user_outputs,
             "num_symints_saved_for_bw": num_symints_saved_for_bw,
-            "num_weight_buffer_grads": len(sharded_param_dict)
-            + len(sharded_buffer_dict),
+            "num_params": num_params,
+            "num_buffers": num_buffers,
             "num_input_grads": num_input_grads,
         }
+
         graph_modules: dict[str, Optional[torch.fx.GraphModule]] = {
             "fw": fw_module,
             "full_bw": bw_module,
             "bw_dI": bw_dI_module,
             "bw_dW": bw_dW_module,
-            "unshard": None,
-            "reduce_grad": None,
+            "unshard": unshard_module,
+            "reduce_grad": reduce_grad_module,
         }
         self.parallel_model = AutoParallelPPModule(
             sharded_param_dict,
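
For reference, a minimal usage sketch of the new graph_passes parameter. This is illustrative only: the `autop` variable name and its construction are assumptions, not part of this diff; the accepted pass names and the reshard_after_forward requirement come from the asserts added above.

# Hypothetical caller; `autop` is an AutoParallelPP instance constructed elsewhere
# with reshard_after_forward=False (required by the split_fsdp_collectives pass).
result = autop.apply_placement_pp(
    sharding_placement=None,
    graph_passes=["split_fsdp_collectives", "split_dI_dW"],
)
# Any other pass name fails the validation assert at the top of the method:
# autop.apply_placement_pp(graph_passes=["some_other_pass"])  # AssertionError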