Commit b042f26

Fixing backward not being called, gradient scaling and enabling backward with torch.no_grad() (#237)
1 parent 36535ef commit b042f26

File tree

3 files changed: +91, -225 lines

autoparallel/graph_pp_runner.py

Lines changed: 75 additions & 19 deletions
@@ -3,8 +3,9 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, cast, Optional, Union
 
 import torch
 import torch.fx as fx
@@ -15,11 +16,14 @@
     _wait_batch_p2p,
 )
 from torch.distributed.pipelining.stage import (
-    PipelineStage,
     _normalize_model_output_as_tuple,
+    PipelineStage,
 )
 from torch.distributed.tensor import DTensor
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 
 @dataclass
 class GraphCallables:
@@ -75,6 +79,22 @@ def __init__(
             "unsharded_grads": [],
         }
 
+    def scale_grads(self, grad_scale_factor: int) -> None:
+        """Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the
+        loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor`
+        should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should
+        be set to 1.
+
+        Should only be called once per pipeline schedule step, after all backwards passes have completed.
+        """
+
+        # PP scales only for its own contribution (microbatches), but relies on DP to scale further
+        # for DP degree.
+        if grad_scale_factor != 1:
+            for grad in self.state["unsharded_grads"]:
+                if grad is not None:
+                    grad.div_(grad_scale_factor)
+
 
 def _run_fw_module(
     fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
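Why dividing by the microbatch count is the right scaling for a 'mean'-reduction loss: each microbatch loss is a mean over its own samples, so summing per-microbatch gradients and dividing once by num_microbatches reproduces the full-batch mean-loss gradient (DP scaling is applied separately, as the comment notes). A minimal standalone sketch of that arithmetic, using plain tensors rather than the pipeline classes (all names below are illustrative):

import torch

torch.manual_seed(0)
n_microbatches, mb_size, dim = 4, 2, 3
w = torch.randn(dim, requires_grad=True)
x = torch.randn(n_microbatches * mb_size, dim)
y = torch.randn(n_microbatches * mb_size)

def mse(pred, target):
    return ((pred - target) ** 2).mean()  # 'mean' reduction

# Full-batch reference gradient.
mse(x @ w, y).backward()
ref = w.grad.clone()
w.grad = None

# Microbatched run: accumulate per-microbatch grads, then scale once,
# mirroring scale_grads(grad_scale_factor=n_microbatches).
acc = torch.zeros_like(w)
for xb, yb in zip(x.chunk(n_microbatches), y.chunk(n_microbatches)):
    loss = mse(xb @ w, yb)                  # per-microbatch 'mean' loss
    acc += torch.autograd.grad(loss, w)[0]  # accumulated like unsharded_grads
acc.div_(n_microbatches)                    # the division scale_grads applies

torch.testing.assert_close(acc, ref)        # matches the full-batch gradient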
@@ -243,7 +263,10 @@ def stage_forward(
     composite_args = stage._retrieve_recv_activations(mb_index)
 
     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
-
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
 
     # See [Note: pipeline model output type]
@@ -306,6 +329,7 @@ def stage_full_backward(
     grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
 
     if not backward_stage.has_backward:
+        logger.debug("Returning early for backward stage")
         return
     (
         stage_output,
@@ -320,7 +344,7 @@
         # HACK till we have loss function, we populate the tangents here manually
         bwd_kwargs = {
             "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output)],
+            "tangents": [torch.randn_like(stage_output[0])],
             "saved_intermediates": saved_intermediates,
         }
     else:
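The tangents change in this hunk swaps stage_output for stage_output[0]: the runner normalizes model outputs to tuples (hence the _normalize_model_output_as_tuple import), and torch.randn_like only accepts a tensor, so passing the tuple itself fails. A tiny sketch of the failure mode and the fix, assuming a single-tensor stage output (illustrative shapes):

import torch

out = torch.randn(2, 3)
stage_output = (out,)  # tuple-normalized output

try:
    torch.randn_like(stage_output)  # old code path: a tuple is not a Tensor
except TypeError as err:
    print("randn_like(tuple) raises", type(err).__name__)

tangents = [torch.randn_like(stage_output[0])]  # fixed code path
print(tangents[0].shape)  # torch.Size([2, 3])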
@@ -334,10 +358,14 @@
             "tangents": output_grads,
             "saved_intermediates": saved_intermediates,
         }
-
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs)
-
-    backward_stage.bwd_cache[backward_mb_index] = input_grads
+    backward_stage.bwd_cache[backward_mb_index] = (
+        tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
+    )
 
     # skipping detach logic
 
@@ -362,9 +390,20 @@ def stage_unshard(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.unshard is None:
         stage.state["unsharded_params"] = stage.state["sharded_params"]
-        # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here
+    else:
+        sharded_params = list(stage.state["sharded_params"])
+        unsharded_params = _run_unshard_module(
+            stage.graph_callables.unshard,
+            stage.graph_meta,
+            sharded_params,
+        )
+        stage.state["unsharded_params"] = unsharded_params
 
 
 def stage_reshard(
@@ -377,6 +416,10 @@
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     stage.state["unsharded_params"].clear()
 
 
@@ -390,8 +433,19 @@ def stage_reduce_grad(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
    }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.reduce_grad is None:
         stage.state["sharded_grads"] = stage.state["unsharded_grads"]
+    else:
+        sharded_grads = _run_reduce_grad_module(
+            stage.graph_callables.reduce_grad,
+            stage.graph_meta,
+            stage.state["unsharded_grads"],
+        )
+        stage.state["sharded_grads"] = sharded_grads
 
 
 class GraphPPRunner:
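Taken together, the three handlers above drive a per-stage state cycle: UNSHARD fills state["unsharded_params"] (through the unshard graph when present, otherwise by aliasing the sharded params), the backward pass accumulates into state["unsharded_grads"], REDUCE_GRAD produces state["sharded_grads"], and RESHARD clears the unsharded params. A single-process stand-in for that cycle, with plain functions in place of the compiled graph modules (illustrative only, not the repo's API):

import torch

def fake_unshard(sharded_params):
    # A real unshard graph would all-gather DTensor shards; cloning stands in here.
    return [p.clone() for p in sharded_params]

def fake_reduce_grad(unsharded_grads):
    # A real reduce_grad graph would reduce-scatter / all-reduce across DP ranks.
    return [g.clone() if g is not None else None for g in unsharded_grads]

state = {
    "sharded_params": [torch.randn(4), torch.randn(4)],
    "unsharded_params": [],
    "unsharded_grads": [None, None],
    "sharded_grads": [],
}

state["unsharded_params"] = fake_unshard(state["sharded_params"])    # UNSHARD
state["unsharded_grads"] = [torch.ones(4), torch.ones(4)]            # filled by FULL_BACKWARD
state["sharded_grads"] = fake_reduce_grad(state["unsharded_grads"])  # REDUCE_GRAD
state["unsharded_params"].clear()                                    # RESHARD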
@@ -400,6 +454,19 @@ def __init__(
         schedule: _PipelineScheduleRuntime,
     ):
         self.schedule = schedule
+        if not schedule._backward_requires_autograd:
+            assert all(
+                isinstance(stage, GraphPipelineStage)
+                and (
+                    stage.graph_callables.full_bw is not None
+                    or (
+                        stage.graph_callables.bw_dI is not None
+                        and stage.graph_callables.bw_dW is not None
+                    )
+                )
+                for stage in schedule._stages
+            )
+            self.schedule._has_backward = True
 
     def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         sharded_params = [
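This constructor check backs the torch.no_grad() part of the commit: when the schedule is built with backward_requires_autograd=False, every stage must supply a compiled backward graph (full_bw, or both bw_dI and bw_dW), and the schedule is then marked as having a backward pass so FULL_BACKWARD actions are actually executed. Because backward is an explicit graph call rather than an autograd traversal, the step can run under torch.no_grad(). A small standalone sketch of that idea, with a hand-written backward standing in for the compiled graph (illustrative only):

import torch

def fw(x, w):
    return x @ w

def full_bw(x, w, tangent):
    # Explicit gradient formulas: no autograd tape required.
    return tangent @ w.T, x.T @ tangent  # dL/dx, dL/dw

x, w = torch.randn(2, 3), torch.randn(3, 4)
tangent = torch.ones(2, 4)

with torch.no_grad():
    y = fw(x, w)
    dx, dw = full_bw(x, w, tangent)  # fine: just more forward-style ops

# By contrast, autograd-based backward needs grad mode enabled:
x_ag = torch.randn(2, 3, requires_grad=True)
with torch.no_grad():
    out = fw(x_ag, w)                # no graph recorded under no_grad
print(out.requires_grad)             # False -> out.sum().backward() would fail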
@@ -415,21 +482,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         stage.state["sharded_params"] = sharded_params
         stage.state["buffers"] = buffers
         stage.state["unsharded_grads"] = [None] * len(sharded_params)
-        # TODO (sanketpurandare)
-        # pipeline schedule runtime does not allow us to register a custom function
-        # for UNSHARD/RESHARD/REDUCE_GRAD action types yet
-        # HACK remove this once we support this
-        if stage.graph_callables.unshard is None:
-            stage.state["unsharded_params"] = stage.state["sharded_params"]
 
     def _accumulate_stage_grads_and_clear_states(
         self, stage: GraphPipelineStage
     ) -> None:
-        # TODO (sanketpurandare)
-        # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet
-        # HACK remove this once Ivan's PR lands
-        if stage.graph_callables.reduce_grad is None:
-            stage.state["sharded_grads"] = stage.state["unsharded_grads"]
         grads = stage.state["sharded_grads"]
         params = list(stage.submod.parameters())
         for param, grad in zip(params, grads):

examples/example_ds3_pp.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,9 @@
 from torch.distributed.pipelining.schedules import (
     FORWARD,
     FULL_BACKWARD,
+    REDUCE_GRAD,
+    RESHARD,
+    UNSHARD,
     PipelineScheduleMulti,
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -42,8 +45,15 @@
     GraphPPRunner,
     stage_forward,
     stage_full_backward,
+    stage_reduce_grad,
+    stage_reshard,
+    stage_unshard,
 )
 
+# Configure logging to show DEBUG messages
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 logger = logging.getLogger(__name__)
 
 
@@ -54,6 +64,7 @@ def build_pipeline_schedule(
     microbatch_size: int,
     local_batch_size: int,
     pipeline_parallel_degree: int,
+    backward_requires_autograd: bool = False,
 ) -> _PipelineSchedule:
     """Builds a pipeline schedule for the given configuration and stages."""
     schedule_class = get_schedule_class(pipeline_parallel_schedule)
@@ -78,6 +89,7 @@
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
+        backward_requires_autograd=backward_requires_autograd,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
@@ -427,11 +439,15 @@ def shape_inference_output_fn_last_stage():
         microbatch_size=microbatch_size,
         local_batch_size=local_batch_size,
         pipeline_parallel_degree=pp_degree,
+        backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(FORWARD, stage_forward)
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
+    schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
+    schedule.register_custom_function(RESHARD, stage_reshard)
+    schedule.register_custom_function(UNSHARD, stage_unshard)
 
     # Step 7. Register the schedule with the graph runner
 
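For reference, a hedged wiring sketch in the spirit of examples/example_ds3_pp.py: the five register_custom_function calls and the backward_requires_autograd=False flag come from this diff, while the module path autoparallel.graph_pp_runner, the helper function, and the GraphPPRunner(schedule) construction are assumptions based on the file layout and the __init__ shown above.

from torch.distributed.pipelining.schedules import (
    FORWARD,
    FULL_BACKWARD,
    REDUCE_GRAD,
    RESHARD,
    UNSHARD,
)

from autoparallel.graph_pp_runner import (  # assumed module path
    GraphPPRunner,
    stage_forward,
    stage_full_backward,
    stage_reduce_grad,
    stage_reshard,
    stage_unshard,
)

def wire_graph_pp(schedule):
    """Attach the graph-based action handlers to a schedule built with
    backward_requires_autograd=False, then hand it to the graph runner."""
    # Override the runtime's default action implementations (per this diff).
    schedule.register_custom_function(FORWARD, stage_forward)
    schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
    schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
    schedule.register_custom_function(RESHARD, stage_reshard)
    schedule.register_custom_function(UNSHARD, stage_unshard)
    # "Step 7. Register the schedule with the graph runner"
    return GraphPPRunner(schedule)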