# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

+import logging
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, cast

)
from torch.distributed.tensor import DTensor

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+

@dataclass
class GraphCallables:
@@ -75,6 +79,22 @@ def __init__(
            "unsharded_grads": [],
        }

+    def scale_grads(self, grad_scale_factor: int) -> None:
+        """Scale the model's gradients by `grad_scale_factor`, which should be specified in
+        coordination with the loss function used with pipelining. For loss functions that
+        perform 'mean' loss reduction, `grad_scale_factor` should be set to num_microbatches;
+        for loss functions that use 'sum' reduction, it should be set to 1.
+
+        Should only be called once per pipeline schedule step, after all backward passes have
+        completed.
+        """
+
+        # PP scales only for its own contribution (microbatches), but relies on DP to scale
+        # further for the DP degree.
+        if grad_scale_factor != 1:
+            for grad in self.state["unsharded_grads"]:
+                if grad is not None:
+                    grad.div_(grad_scale_factor)
+
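+    # Usage sketch (illustrative only, with assumed names: `stage` is an instance of this
+    # class and `schedule` is the pipeline schedule driving it, as in `stage_full_backward`
+    # below):
+    #
+    #   # 'mean' loss reduction: each microbatch backward averages over its own samples, so
+    #   # the grads accumulated across microbatches are n_microbatches times too large.
+    #   stage.scale_grads(grad_scale_factor=schedule._n_microbatches)
+    #
+    #   # 'sum' loss reduction: the accumulated sum is already the desired gradient.
+    #   stage.scale_grads(grad_scale_factor=1)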

def _run_fw_module(
    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
@@ -243,7 +263,10 @@ def stage_forward(
    composite_args = stage._retrieve_recv_activations(mb_index)

    # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
-
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)

    # See [Note: pipeline model output type]
@@ -306,6 +329,7 @@ def stage_full_backward(
    grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1

    if not backward_stage.has_backward:
+        logger.debug("Returning early for backward stage")
        return
    (
        stage_output,
@@ -320,7 +344,7 @@ def stage_full_backward(
        # HACK: until we have a loss function, we populate the tangents here manually
        bwd_kwargs = {
            "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output)],
+            "tangents": [torch.randn_like(stage_output[0])],
            "saved_intermediates": saved_intermediates,
        }
    else:
@@ -334,10 +358,14 @@ def stage_full_backward(
            "tangents": output_grads,
            "saved_intermediates": saved_intermediates,
        }
-
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
    input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs)
-
-    backward_stage.bwd_cache[backward_mb_index] = input_grads
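+    # _run_backward_microbatch may return the input grads as a list; normalize to a tuple so
+    # bwd_cache entries always have a consistent type.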
+    backward_stage.bwd_cache[backward_mb_index] = (
+        tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
+    )

    # skipping detach logic

@@ -362,9 +390,20 @@ def stage_unshard(
        stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
    }
    stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
    if stage.graph_callables.unshard is None:
        stage.state["unsharded_params"] = stage.state["sharded_params"]
-        # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here
+    else:
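+        # Run the stage's unshard graph (the fw FSDP all-gather) to materialize unsharded
+        # parameters from the sharded ones.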
+        sharded_params = list(stage.state["sharded_params"])
+        unsharded_params = _run_unshard_module(
+            stage.graph_callables.unshard,
+            stage.graph_meta,
+            sharded_params,
+        )
+        stage.state["unsharded_params"] = unsharded_params


def stage_reshard(
@@ -377,6 +416,10 @@ def stage_reshard(
        stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
    }
    stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
    stage.state["unsharded_params"].clear()


@@ -390,8 +433,19 @@ def stage_reduce_grad(
        stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
    }
    stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
    if stage.graph_callables.reduce_grad is None:
        stage.state["sharded_grads"] = stage.state["unsharded_grads"]
+    else:
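+        # Run the stage's reduce_grad graph to reduce the accumulated unsharded grads into
+        # sharded grads.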
+        sharded_grads = _run_reduce_grad_module(
+            stage.graph_callables.reduce_grad,
+            stage.graph_meta,
+            stage.state["unsharded_grads"],
+        )
+        stage.state["sharded_grads"] = sharded_grads


class GraphPPRunner:
@@ -400,6 +454,19 @@ def __init__(
        schedule: _PipelineScheduleRuntime,
    ):
        self.schedule = schedule
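+        # If the schedule's backward does not go through autograd, every stage must provide
+        # compiled backward graphs: either a full backward, or both the dI and dW splits.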
+        if not schedule._backward_requires_autograd:
+            assert all(
+                isinstance(stage, GraphPipelineStage)
+                and (
+                    stage.graph_callables.full_bw is not None
+                    or (
+                        stage.graph_callables.bw_dI is not None
+                        and stage.graph_callables.bw_dW is not None
+                    )
+                )
+                for stage in schedule._stages
+            )
+            self.schedule._has_backward = True

    def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
        sharded_params = [
@@ -415,21 +482,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
        stage.state["sharded_params"] = sharded_params
        stage.state["buffers"] = buffers
        stage.state["unsharded_grads"] = [None] * len(sharded_params)
-        # TODO (sanketpurandare)
-        # pipeline schedule runtime does not allow us to register a custom function
-        # for UNSHARD/RESHARD/REDUCE_GRAD action types yet
-        # HACK remove this once we support this
-        if stage.graph_callables.unshard is None:
-            stage.state["unsharded_params"] = stage.state["sharded_params"]

    def _accumulate_stage_grads_and_clear_states(
        self, stage: GraphPipelineStage
    ) -> None:
-        # TODO (sanketpurandare)
-        # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet
-        # HACK remove this once Ivan's PR lands
-        if stage.graph_callables.reduce_grad is None:
-            stage.state["sharded_grads"] = stage.state["unsharded_grads"]
        grads = stage.state["sharded_grads"]
        params = list(stage.submod.parameters())
        for param, grad in zip(params, grads):