
Commit 3701919

sanketpurandare authored and xmfan committed
Fixing backward not being called, gradient scaling and enabling backward with torch.no_grad() (#237)
stack-info: PR: #239, branch: xmfan/stack/17
1 parent 36535ef commit 3701919

File tree: 5 files changed (+205 / -241 lines)


autoparallel/_testing/models/dsv3.py

Lines changed: 3 additions & 7 deletions
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # HOPs don't support buffer mutations, keep this outside
         with torch.no_grad():
-            self.tokens_per_expert.add_(num_tokens_per_expert)
+            self.tokens_per_expert.add_(num_tokens_per_expert)  # type: ignore[operator]
         return out
 
     def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
         self.shared_experts.init_weights(init_std)
 
         with torch.device(buffer_device):
-            self.tokens_per_expert = torch.zeros(
-                self.experts.num_experts, dtype=torch.float32
-            )
+            self.tokens_per_expert.zero_()  # type: ignore[operator]
             if self.load_balance_coeff is not None:
                 assert isinstance(self.expert_bias, torch.Tensor)
-                self.expert_bias = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
+                self.expert_bias.zero_()  # type: ignore[operator]
 
 
 def has_cuda_capability(major: int, minor: int) -> bool:
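
The change above switches from reallocating `tokens_per_expert` and `expert_bias` to zeroing them in place. One plausible reading (not stated in the diff): reassigning the attribute rebinds it to a brand-new tensor, so anything that already captured the registered buffer keeps pointing at the old storage, while `zero_()` mutates the same tensor object. A minimal standalone sketch of that difference, using a hypothetical `Gate` module that is not part of this repository:

import torch

class Gate(torch.nn.Module):
    def __init__(self, num_experts: int = 4):
        super().__init__()
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts, dtype=torch.float32))

# Rebinding swaps in a new tensor object; previously captured references go stale.
m = Gate()
captured = m.tokens_per_expert
m.tokens_per_expert = torch.zeros(4, dtype=torch.float32)
assert captured is not m.tokens_per_expert

# In-place zeroing keeps the same tensor object alive.
m = Gate()
captured = m.tokens_per_expert
m.tokens_per_expert.zero_()
assert captured is m.tokens_per_expert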

autoparallel/graph_pp_runner.py

Lines changed: 94 additions & 22 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union, cast
 
@@ -20,6 +21,11 @@
 )
 from torch.distributed.tensor import DTensor
 
+from autoparallel.utils import DebugInterpreter
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 
 @dataclass
 class GraphCallables:
@@ -75,14 +81,39 @@ def __init__(
             "unsharded_grads": [],
         }
 
+    def scale_grads(self, grad_scale_factor: int) -> None:
+        """Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the
+        loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor`
+        should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should
+        be set to 1.
+
+        Should only be called once per pipeline schedule step, after all backwards passes have completed.
+        """
+
+        # PP scales only for its own contribution (microbatches), but relies on DP to scale further
+        # for DP degree.
+        if grad_scale_factor != 1:
+            for grad in self.state["unsharded_grads"]:
+                if grad is not None:
+                    grad.div_(grad_scale_factor)
+
 
 def _run_fw_module(
-    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
+    fw_module: fx.GraphModule,
+    graph_meta: GraphMeta,
+    fw_args: list[Any],
+    numerics_logs: Optional[list[str]] = None,
 ) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
-    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+    if numerics_logs is not None:
+        debug_interpreter = DebugInterpreter(fw_module)
+        fw_outputs = debug_interpreter.boxed_run(fw_args)
+        numerics_logs += debug_interpreter.get_logs()
+    else:
+        fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
 
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
     num_tensors_for_backward = (
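
The `scale_grads` docstring above ties `grad_scale_factor` to the loss reduction. A toy check of that rule (illustrative only, independent of the runner code): accumulating per-microbatch mean-loss gradients over `n_mb` microbatches over-counts the full-batch mean gradient by a factor of `n_mb`, so dividing by the number of microbatches restores it; with sum reduction the accumulated gradient already matches and the factor stays 1.

import torch

n_mb = 4
w = torch.ones(1, requires_grad=True)
x = torch.arange(8.0).reshape(n_mb, 2)  # 4 microbatches of 2 samples each

# Full-batch mean loss: a single backward gives the reference gradient.
(w * x).mean().backward()
full_grad = w.grad.clone()

# Per-microbatch mean losses: each backward() accumulates into w.grad.
w.grad = None
for mb in x:
    (w * mb).mean().backward()
accumulated = w.grad.clone()

# The accumulated gradient is n_mb times too large; dividing by the
# number of microbatches (the grad_scale_factor) recovers the reference.
assert torch.allclose(accumulated / n_mb, full_grad)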
@@ -153,14 +184,16 @@ def _run_reduce_grad_module(
     return sharded_grads
 
 
-def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
+def _run_forward_microbatch(
+    stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
+) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
         *args,
     ]
     user_outputs, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args
+        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
     )
     return (user_outputs, saved_intermediates)
 
@@ -200,6 +233,7 @@ def _run_backward_microbatch(
 def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
+    numerics_logs: Optional[list[str]] = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -243,8 +277,13 @@ def stage_forward(
     composite_args = stage._retrieve_recv_activations(mb_index)
 
     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
-
-    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
+    output, saved_intermediates = _run_forward_microbatch(
+        stage, *composite_args, numerics_logs=numerics_logs
+    )
 
     # See [Note: pipeline model output type]
     output_tuple = _normalize_model_output_as_tuple(output)
@@ -306,6 +345,7 @@ def stage_full_backward(
     grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
 
     if not backward_stage.has_backward:
+        logger.debug("Returning early for backward stage")
         return
     (
         stage_output,
@@ -320,7 +360,7 @@ def stage_full_backward(
         # HACK till we have loss function, we populate the tangents here manually
         bwd_kwargs = {
             "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output)],
+            "tangents": [torch.randn_like(stage_output[0])],
             "saved_intermediates": saved_intermediates,
         }
     else:
@@ -334,10 +374,14 @@ def stage_full_backward(
             "tangents": output_grads,
             "saved_intermediates": saved_intermediates,
         }
-
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs)
-
-    backward_stage.bwd_cache[backward_mb_index] = input_grads
+    backward_stage.bwd_cache[backward_mb_index] = (
+        tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
+    )
 
     # skipping detach logic
 
@@ -362,9 +406,20 @@ def stage_unshard(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.unshard is None:
         stage.state["unsharded_params"] = stage.state["sharded_params"]
-        # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here
+    else:
+        sharded_params = list(stage.state["sharded_params"])
+        unsharded_params = _run_unshard_module(
+            stage.graph_callables.unshard,
+            stage.graph_meta,
+            sharded_params,
+        )
+        stage.state["unsharded_params"] = unsharded_params
 
 
 def stage_reshard(
@@ -377,6 +432,10 @@ def stage_reshard(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     stage.state["unsharded_params"].clear()
 
 
@@ -390,8 +449,19 @@ def stage_reduce_grad(
         stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
     }
     stage = stage_index_to_stage[action.stage_index]
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
     if stage.graph_callables.reduce_grad is None:
         stage.state["sharded_grads"] = stage.state["unsharded_grads"]
+    else:
+        sharded_grads = _run_reduce_grad_module(
+            stage.graph_callables.reduce_grad,
+            stage.graph_meta,
+            stage.state["unsharded_grads"],
+        )
+        stage.state["sharded_grads"] = sharded_grads
 
 
 class GraphPPRunner:
@@ -400,6 +470,19 @@ def __init__(
         schedule: _PipelineScheduleRuntime,
     ):
         self.schedule = schedule
+        if not schedule._backward_requires_autograd:
+            assert all(
+                isinstance(stage, GraphPipelineStage)
+                and (
+                    stage.graph_callables.full_bw is not None
+                    or (
+                        stage.graph_callables.bw_dI is not None
+                        and stage.graph_callables.bw_dW is not None
+                    )
+                )
+                for stage in schedule._stages
+            )
+            self.schedule._has_backward = True
 
     def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         sharded_params = [
@@ -415,21 +498,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
         stage.state["sharded_params"] = sharded_params
         stage.state["buffers"] = buffers
         stage.state["unsharded_grads"] = [None] * len(sharded_params)
-        # TODO (sanketpurandare)
-        # pipeline schedule runtime does not allow us to register a custom function
-        # for UNSHARD/RESHARD/REDUCE_GRAD action types yet
-        # HACK remove this once we support this
-        if stage.graph_callables.unshard is None:
-            stage.state["unsharded_params"] = stage.state["sharded_params"]
 
     def _accumulate_stage_grads_and_clear_states(
         self, stage: GraphPipelineStage
     ) -> None:
-        # TODO (sanketpurandare)
-        # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet
-        # HACK remove this once Ivan's PR lands
-        if stage.graph_callables.reduce_grad is None:
-            stage.state["sharded_grads"] = stage.state["unsharded_grads"]
         grads = stage.state["sharded_grads"]
         params = list(stage.submod.parameters())
         for param, grad in zip(params, grads):

autoparallel/utils.py

Lines changed: 67 additions & 0 deletions
@@ -3,7 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Iterable
+
 import torch
+import torch.utils._pytree as pytree
 from torch.distributed._tensor.placement_types import Placement, TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh):
         return torch.device("cpu")
     device_handle = _get_device_handle(mesh.device_type)
     return torch.device(mesh.device_type, device_handle.current_device())
+
+
+# An FX graph interpreter that logs inputs and outputs of each node
+# with a few exceptions for c10d ops
+class DebugInterpreter(torch.fx.Interpreter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._logs = []
+
+    def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
+        leaves, _ = pytree.tree_flatten(args)
+        for i, arg in enumerate(leaves):
+            if not isinstance(arg, torch.Tensor):
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
+                continue
+
+            if arg.numel() == 0:
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
+                continue
+
+            if arg.is_complex():
+                real = torch.hash_tensor(arg.real)
+                imag = torch.hash_tensor(arg.imag)
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
+                continue
+
+            self._logs.append(
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+            )
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        # reading wait_tensor inputs is undefined behavior
+        if "wait_tensor" not in n.name:
+            args, _ = self.fetch_args_kwargs_from_env(n)
+            self.log(n.name, args, "args")
+
+        out = super().run_node(n)
+
+        # reading functional collectives outputs before wait_tensor is undefined behavior
+        if "c10d" not in str(n.target):
+            outs = out
+            if isinstance(outs, torch.Tensor):
+                outs = [outs]
+            self.log(n.name, outs, "outs")
+
+        return out
+
+    def get_logs(self):
+        return self._logs
+
+
+# Always prints from rank 0 to rank N
+def print_rank_by_rank(msg: Any):
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    torch.distributed.barrier()
+    for i in range(world_size):
+        if rank == i:
+            print(f"{rank=} start")
+            print(msg)
+            print(f"{rank=} done")
+        torch.distributed.barrier()
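
A hypothetical usage sketch of the new helpers (not part of the diff): trace a toy module, run it through `DebugInterpreter`, and dump the per-node logs rank by rank. It assumes a PyTorch build that provides `torch.hash_tensor`, which `DebugInterpreter.log` relies on, and an initialized process group for `print_rank_by_rank`; the `Toy` module is made up for illustration.

import torch

from autoparallel.utils import DebugInterpreter, print_rank_by_rank


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# Trace to an fx.GraphModule and interpret it node by node, logging a hash
# of every tensor input/output along the way.
gm = torch.fx.symbolic_trace(Toy())
interp = DebugInterpreter(gm)
out = interp.run(torch.randn(4, 4))

# Inside the pipeline runner the logs are appended to `numerics_logs`; standalone,
# they can be printed in rank order to compare numerics across ranks.
if torch.distributed.is_initialized():
    print_rank_by_rank("\n".join(interp.get_logs()))
else:
    print("\n".join(interp.get_logs()))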
