 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union, cast

 )
 from torch.distributed.tensor import DTensor

+from autoparallel.utils import DebugInterpreter
+
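+# Module-level logger; pinned to DEBUG so the runner's per-action traces below
+# are not filtered at the logger level.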
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+

 @dataclass
 class GraphCallables:
@@ -75,14 +81,39 @@ def __init__(
7581 "unsharded_grads" : [],
7682 }
7783
84+ def scale_grads (self , grad_scale_factor : int ) -> None :
85+ """Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the
86+ loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor`
87+ should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should
88+ be set to 1.
89+
90+ Should only be called once per pipeline schedule step, after all backwards passes have completed.
91+ """
92+
93+ # PP scales only for its own contribution (microbatches), but relies on DP to scale further
94+ # for DP degree.
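+        # e.g. with a 'mean'-reduction loss and N microbatches, scale_grads(N) is
+        # called once per schedule step after all backward passes have run.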
+        if grad_scale_factor != 1:
+            for grad in self.state["unsharded_grads"]:
+                if grad is not None:
+                    grad.div_(grad_scale_factor)
+

 def _run_fw_module(
-    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
+    fw_module: fx.GraphModule,
+    graph_meta: GraphMeta,
+    fw_args: list[Any],
+    numerics_logs: Optional[list[str]] = None,
 ) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
-    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
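+    # When a numerics log buffer is provided, run the forward graph under
+    # DebugInterpreter so per-node numerics logs can be collected and appended.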
+    if numerics_logs is not None:
+        debug_interpreter = DebugInterpreter(fw_module)
+        fw_outputs = debug_interpreter.boxed_run(fw_args)
+        numerics_logs += debug_interpreter.get_logs()
+    else:
+        fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
     num_tensors_for_backward = (
@@ -153,14 +184,16 @@ def _run_reduce_grad_module(
     return sharded_grads


-def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
+def _run_forward_microbatch(
+    stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
+) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
         *args,
     ]
     user_outputs, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args
+        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
     )
     return (user_outputs, saved_intermediates)

@@ -200,6 +233,7 @@ def _run_backward_microbatch(
 def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
+    numerics_logs: Optional[list[str]] = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -243,8 +277,13 @@ def stage_forward(
         composite_args = stage._retrieve_recv_activations(mb_index)

     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
-
-    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
+    output, saved_intermediates = _run_forward_microbatch(
+        stage, *composite_args, numerics_logs=numerics_logs
+    )

     # See [Note: pipeline model output type]
     output_tuple = _normalize_model_output_as_tuple(output)
@@ -306,6 +345,7 @@ def stage_full_backward(
     grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1

     if not backward_stage.has_backward:
+        logger.debug("Returning early for backward stage")
         return
     (
         stage_output,
@@ -320,7 +360,7 @@ def stage_full_backward(
         # HACK till we have loss function, we populate the tangents here manually
         bwd_kwargs = {
             "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output)],
+            "tangents": [torch.randn_like(stage_output[0])],
324364 "saved_intermediates" : saved_intermediates ,
325365 }
326366 else :
@@ -334,10 +374,14 @@ def stage_full_backward(
334374 "tangents" : output_grads ,
335375 "saved_intermediates" : saved_intermediates ,
336376 }
337-
377+ logger .debug (
378+ "GraphPPRunner running action %s" ,
379+ action ,
380+ )
338381 input_grads = _run_backward_microbatch (backward_stage , bwd_kwargs )
339-
340- backward_stage .bwd_cache [backward_mb_index ] = input_grads
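+    # bwd_cache entries are expected to be tuples of input grads per microbatch;
+    # normalize list outputs from the backward graph module accordingly.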
+    backward_stage.bwd_cache[backward_mb_index] = (
+        tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
+    )

     # skipping detach logic

@@ -362,9 +406,20 @@ def stage_unshard(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.unshard is None:
         stage.state["unsharded_params"] = stage.state["sharded_params"]
-    # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here
+    else:
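+        # Run the unshard (FSDP all-gather) graph to materialize full parameters
+        # from the stage's sharded parameters.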
+        sharded_params = list(stage.state["sharded_params"])
+        unsharded_params = _run_unshard_module(
+            stage.graph_callables.unshard,
+            stage.graph_meta,
+            sharded_params,
+        )
+        stage.state["unsharded_params"] = unsharded_params


 def stage_reshard(
@@ -377,6 +432,10 @@ def stage_reshard(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     stage.state["unsharded_params"].clear()

@@ -390,8 +449,19 @@ def stage_reduce_grad(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.reduce_grad is None:
         stage.state["sharded_grads"] = stage.state["unsharded_grads"]
+    else:
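+        # Run the reduce_grad graph to reduce the accumulated unsharded grads
+        # (e.g. a reduce-scatter across the data-parallel mesh) into sharded grads.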
+        sharded_grads = _run_reduce_grad_module(
+            stage.graph_callables.reduce_grad,
+            stage.graph_meta,
+            stage.state["unsharded_grads"],
+        )
+        stage.state["sharded_grads"] = sharded_grads


 class GraphPPRunner:
@@ -400,6 +470,19 @@ def __init__(
         schedule: _PipelineScheduleRuntime,
     ):
         self.schedule = schedule
+        if not schedule._backward_requires_autograd:
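+            # When the schedule does not rely on autograd for backward, every stage
+            # must provide graph-based backward callables: either a full backward
+            # graph, or both the input-grad (dI) and weight-grad (dW) split graphs.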
+            assert all(
+                isinstance(stage, GraphPipelineStage)
+                and (
+                    stage.graph_callables.full_bw is not None
+                    or (
+                        stage.graph_callables.bw_dI is not None
+                        and stage.graph_callables.bw_dW is not None
+                    )
+                )
+                for stage in schedule._stages
+            )
+            self.schedule._has_backward = True

     def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         sharded_params = [
@@ -415,21 +498,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         stage.state["sharded_params"] = sharded_params
         stage.state["buffers"] = buffers
         stage.state["unsharded_grads"] = [None] * len(sharded_params)
-        # TODO (sanketpurandare)
-        # pipeline schedule runtime does not allow us to register a custom function
-        # for UNSHARD/RESHARD/REDUCE_GRAD action types yet
-        # HACK remove this once we support this
-        if stage.graph_callables.unshard is None:
-            stage.state["unsharded_params"] = stage.state["sharded_params"]

     def _accumulate_stage_grads_and_clear_states(
         self, stage: GraphPipelineStage
     ) -> None:
-        # TODO (sanketpurandare)
-        # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet
-        # HACK remove this once Ivan's PR lands
-        if stage.graph_callables.reduce_grad is None:
-            stage.state["sharded_grads"] = stage.state["unsharded_grads"]
         grads = stage.state["sharded_grads"]
         params = list(stage.submod.parameters())
         for param, grad in zip(params, grads):