10 changes: 3 additions & 7 deletions autoparallel/_testing/models/dsv3.py
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# HOPs don't support buffer mutations, keep this outside
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)
self.tokens_per_expert.add_(num_tokens_per_expert) # type: ignore[operator]
return out

def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
self.shared_experts.init_weights(init_std)

with torch.device(buffer_device):
self.tokens_per_expert = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)
self.tokens_per_expert.zero_() # type: ignore[operator]
if self.load_balance_coeff is not None:
assert isinstance(self.expert_bias, torch.Tensor)
self.expert_bias = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)
self.expert_bias.zero_() # type: ignore[operator]


def has_cuda_capability(major: int, minor: int) -> bool:
116 changes: 94 additions & 22 deletions autoparallel/graph_pp_runner.py
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, cast

@@ -20,6 +21,11 @@
)
from torch.distributed.tensor import DTensor

from autoparallel.utils import DebugInterpreter

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@dataclass
class GraphCallables:
@@ -75,14 +81,39 @@ def __init__(
"unsharded_grads": [],
}

def scale_grads(self, grad_scale_factor: int) -> None:
"""Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the
loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor`
should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should
be set to 1.

Should only be called once per pipeline schedule step, after all backwards passes have completed.
"""

# PP scales only for its own contribution (microbatches), but relies on DP to scale further
# for DP degree.
if grad_scale_factor != 1:
for grad in self.state["unsharded_grads"]:
if grad is not None:
grad.div_(grad_scale_factor)
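
# Illustration only (not part of this diff): how the `grad_scale_factor` passed to
# scale_grads() is expected to be chosen. The helper below is hypothetical and exists
# purely to make the 'mean' vs 'sum' rule from the docstring concrete.
def _choose_grad_scale_factor(loss_reduction: str, n_microbatches: int) -> int:
    # 'mean' reduction: each microbatch backward already averages over its own
    # samples, so the stage divides its accumulated grads once more by the number
    # of microbatches; scaling for the data-parallel degree is left to DP.
    if loss_reduction == "mean":
        return n_microbatches
    # 'sum' reduction: the accumulated grads are already the full-batch grads.
    return 1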


def _run_fw_module(
fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
fw_module: fx.GraphModule,
graph_meta: GraphMeta,
fw_args: list[Any],
numerics_logs: Optional[list[str]] = None,
) -> tuple[Any, tuple[list[Any], list[Any]]]:
assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
fw_args
), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
if numerics_logs is not None:
debug_interpreter = DebugInterpreter(fw_module)
fw_outputs = debug_interpreter.boxed_run(fw_args)
numerics_logs += debug_interpreter.get_logs()
else:
fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)

num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
num_tensors_for_backward = (
@@ -153,14 +184,16 @@ def _run_reduce_grad_module(
return sharded_grads


def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
def _run_forward_microbatch(
stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
) -> tuple[Any, Any]:
fw_args = [
*stage.state["unsharded_params"],
*stage.state["buffers"],
*args,
]
user_outputs, saved_intermediates = _run_fw_module(
stage.graph_callables.fw, stage.graph_meta, fw_args
stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
)
return (user_outputs, saved_intermediates)

@@ -200,6 +233,7 @@ def _run_backward_microbatch(
def stage_forward(
action: _Action,
ctx: _PipelineContext,
numerics_logs: Optional[list[str]] = None,
) -> None:
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -243,8 +277,13 @@ def stage_forward(
composite_args = stage._retrieve_recv_activations(mb_index)

# stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?

output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
logger.debug(
"GraphPPRunner running action %s",
action,
)
output, saved_intermediates = _run_forward_microbatch(
stage, *composite_args, numerics_logs=numerics_logs
)

# See [Note: pipeline model output type]
output_tuple = _normalize_model_output_as_tuple(output)
@@ -306,6 +345,7 @@ def stage_full_backward(
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1

if not backward_stage.has_backward:
logger.debug("Returning early for backward stage")
return
(
stage_output,
@@ -320,7 +360,7 @@
# HACK till we have loss function, we populate the tangents here manually
bwd_kwargs = {
"stage_output": loss,
"tangents": [torch.randn_like(stage_output)],
"tangents": [torch.randn_like(stage_output[0])],
"saved_intermediates": saved_intermediates,
}
else:
@@ -334,10 +374,14 @@
"tangents": output_grads,
"saved_intermediates": saved_intermediates,
}

logger.debug(
"GraphPPRunner running action %s",
action,
)
input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs)

backward_stage.bwd_cache[backward_mb_index] = input_grads
backward_stage.bwd_cache[backward_mb_index] = (
tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
)

# skipping detach logic

@@ -362,9 +406,20 @@ def stage_unshard(
stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
}
stage = stage_index_to_stage[action.stage_index]
logger.debug(
"GraphPPRunner running action %s",
action,
)
if stage.graph_callables.unshard is None:
stage.state["unsharded_params"] = stage.state["sharded_params"]
# TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here
else:
sharded_params = list(stage.state["sharded_params"])
unsharded_params = _run_unshard_module(
stage.graph_callables.unshard,
stage.graph_meta,
sharded_params,
)
stage.state["unsharded_params"] = unsharded_params


def stage_reshard(
@@ -377,6 +432,10 @@
stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
}
stage = stage_index_to_stage[action.stage_index]
logger.debug(
"GraphPPRunner running action %s",
action,
)
stage.state["unsharded_params"].clear()


@@ -390,8 +449,19 @@ def stage_reduce_grad(
stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
}
stage = stage_index_to_stage[action.stage_index]
logger.debug(
"GraphPPRunner running action %s",
action,
)
if stage.graph_callables.reduce_grad is None:
stage.state["sharded_grads"] = stage.state["unsharded_grads"]
else:
sharded_grads = _run_reduce_grad_module(
stage.graph_callables.reduce_grad,
stage.graph_meta,
stage.state["unsharded_grads"],
)
stage.state["sharded_grads"] = sharded_grads


class GraphPPRunner:
@@ -400,6 +470,19 @@ def __init__(
schedule: _PipelineScheduleRuntime,
):
self.schedule = schedule
if not schedule._backward_requires_autograd:
assert all(
isinstance(stage, GraphPipelineStage)
and (
stage.graph_callables.full_bw is not None
or (
stage.graph_callables.bw_dI is not None
and stage.graph_callables.bw_dW is not None
)
)
for stage in schedule._stages
)
self.schedule._has_backward = True

def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
sharded_params = [
@@ -415,21 +498,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
stage.state["sharded_params"] = sharded_params
stage.state["buffers"] = buffers
stage.state["unsharded_grads"] = [None] * len(sharded_params)
# TODO (sanketpurandare)
# pipeline schedule runtime does not allow us to register a custom function
# for UNSHARD/RESHARD/REDUCE_GRAD action types yet
# HACK remove this once we support this
if stage.graph_callables.unshard is None:
stage.state["unsharded_params"] = stage.state["sharded_params"]

def _accumulate_stage_grads_and_clear_states(
self, stage: GraphPipelineStage
) -> None:
# TODO (sanketpurandare)
# We don't have a REDUCE_GRAD action yet in the ScheduleIR yet
# HACK remove this once Ivan's PR lands
if stage.graph_callables.reduce_grad is None:
stage.state["sharded_grads"] = stage.state["unsharded_grads"]
grads = stage.state["sharded_grads"]
params = list(stage.submod.parameters())
for param, grad in zip(params, grads):
67 changes: 67 additions & 0 deletions autoparallel/utils.py
@@ -3,7 +3,10 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable

import torch
import torch.utils._pytree as pytree
from torch.distributed._tensor.placement_types import Placement, TensorMeta
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh):
return torch.device("cpu")
device_handle = _get_device_handle(mesh.device_type)
return torch.device(mesh.device_type, device_handle.current_device())


# An FX graph interpreter that logs inputs and outputs of each node
# with a few exceptions for c10d ops
class DebugInterpreter(torch.fx.Interpreter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logs = []

def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
leaves, _ = pytree.tree_flatten(args)
for i, arg in enumerate(leaves):
if not isinstance(arg, torch.Tensor):
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
continue

if arg.numel() == 0:
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
continue

if arg.is_complex():
real = torch.hash_tensor(arg.real)
imag = torch.hash_tensor(arg.imag)
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
continue

self._logs.append(
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
)

def run_node(self, n: torch.fx.Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)

# reading wait_tensor inputs is undefined behavior
if "wait_tensor" not in n.name:
args, _ = self.fetch_args_kwargs_from_env(n)
self.log(n.name, args, "args")

out = super().run_node(n)

# reading functional collectives outputs before wait_tensor is undefined behavior
if "c10d" not in str(n.target):
outs = out
if isinstance(outs, torch.Tensor):
outs = [outs]
self.log(n.name, outs, "outs")

return out

def get_logs(self):
return self._logs
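
# Usage sketch (not part of this diff): run a traced callable through DebugInterpreter
# and inspect the per-node hash logs. The toy function below is hypothetical; only
# DebugInterpreter and get_logs() come from this file.
#
#     def _toy(x):
#         return (x * 2).sum()
#
#     gm = torch.fx.symbolic_trace(_toy)
#     interp = DebugInterpreter(gm)
#     interp.run(torch.ones(4))
#     for line in interp.get_logs():
#         print(line)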


# Always prints from rank 0 to rank N
def print_rank_by_rank(msg: Any):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.distributed.barrier()
for i in range(world_size):
if rank == i:
print(f"{rank=} start")
print(msg)
print(f"{rank=} done")
torch.distributed.barrier()
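
# Usage sketch (not part of this diff): dump per-rank numerics logs without interleaving
# output across ranks. `numerics_logs` is a hypothetical list that the caller threads
# through stage_forward(..., numerics_logs=...) in graph_pp_runner.py.
#
#     numerics_logs: list[str] = []
#     ...  # run the pipeline schedule, collecting DebugInterpreter logs on this rank
#     print_rank_by_rank("\n".join(numerics_logs))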