[BE][5/n] simplify pp vs. non-pp set up #510

Merged · 3 commits · Aug 8, 2024
33 changes: 12 additions & 21 deletions estimation.py
@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
         f"Building {model_name} {job_config.model.flavor} with {model_config}"
     )
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)

     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [whole_model]
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]

-    init_device = "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
[Review comment — Contributor]
IIUC this set of changes is just to clean up estimation.py to be non-PP compatible, and estimation.py was copy-pasted from train.py, which is why it had model_chunks in the first place. This makes sense to me.

[Reply — Contributor, Author]
Yes, you are right. In general, estimation.py only supports eager FSDP and doesn't support TP/PP/compile. So let's keep this file simple.

(An illustrative guard for these unsupported modes is sketched after this file's diff.)

+    model.to_empty(device="cuda")
     if not active_fake_mode():
-        whole_model.init_weights()
+        model.init_weights()
+    model.train()

     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

-    for model in model_parts:
-        model.train()
     logger.info(f"Vocab size: {model_config.vocab_size}")
     # Create a dummy batch instead of loading from a dataset
     batch = (
@@ -165,24 +157,23 @@ def loss_fn(pred, labels):
             device="cuda",
         ),
     )
-    fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
+    fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
     fsdp_memtracker.track_inputs(batch)

     with fsdp_memtracker:
         for iter_idx in range(2):
             input_ids, labels = batch
             # train step
             with train_context():
-                pred = whole_model(input_ids)
+                pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 del pred
                 loss.backward()

             # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
             # sync float8 amaxes and scales
             float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
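The review thread on this file notes that estimation.py is meant to stay eager-FSDP-only. Purely as an illustration (this is not part of the PR; the pp_enabled/tp_enabled attribute names are assumptions, while job_config.training.compile does appear elsewhere in this diff), a guard of roughly this shape is one way such a constraint could be made explicit:

# Hypothetical helper, not from the torchtitan source: reject configurations
# that estimation.py is not written to handle, so the rest of the script can
# assume a single, eager, FSDP-wrapped model.
def assert_estimation_supported(job_config, parallel_dims) -> None:
    if getattr(parallel_dims, "pp_enabled", False):
        raise NotImplementedError("estimation.py does not support pipeline parallelism")
    if getattr(parallel_dims, "tp_enabled", False):
        raise NotImplementedError("estimation.py does not support tensor parallelism")
    if job_config.training.compile:
        raise NotImplementedError("estimation.py does not support torch.compile")

Calling such a check right after the parallel dims are constructed would fail fast instead of silently mis-estimating an unsupported setup.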
2 changes: 0 additions & 2 deletions torchtitan/parallelisms/__init__.py
@@ -8,11 +8,9 @@
 from torchtitan.parallelisms.parallel_dims import ParallelDims
 from torchtitan.parallelisms.parallelize_llama import parallelize_llama
 from torchtitan.parallelisms.pipeline_llama import pipeline_llama
-from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule


 __all__ = [
-    "build_pipeline_schedule",
     "models_parallelize_fns",
     "models_pipelining_fns",
     "ParallelDims",
22 changes: 8 additions & 14 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -51,7 +51,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        model = apply_tp(
+        apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@
         )

     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config.activation_checkpoint)
+        apply_ac(model, job_config.activation_checkpoint)

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
@@ -69,14 +69,14 @@
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        model = apply_compile(model)
+        apply_compile(model)

     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
             assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

-            model = apply_fsdp(
+            apply_fsdp(
[Review comment — Contributor]
Does this change break anything? IIRC one feature required us to keep returning the model. Maybe it was AC?

[Reply — Contributor, Author]
> one feature required us to keep returning the model.

Great point. Yes, torch.compile and AC require reassigning the model. But since we are doing per-block compile and AC, we achieve that in place for the whole model by

    transformer_block = compile/AC(transformer_block)
    model.layers.register_module(layer_id, transformer_block)

[Review comment — Contributor]
Any motivation for removing the assignment? I thought an explicit assignment does not look bad.

[Reply — Contributor, Author]
@wanchaol I actually don't know why these functions have redundant semantics, i.e. they modify a module in place but at the same time return it explicitly. I'm changing it because:

  1. Before this PR, the PP branch explicitly reassigned the returned module, but the SPMD branch didn't. I think we should use the minimum viable code to reduce confusion.
  2. IIRC @awgu had a PR which removed the reassignment for FSDP2 fully_shard, so in some sense I'm mimicking that PR.

(A runnable sketch of this per-block, in-place wrapping pattern follows this file's diff.)

                model,
                dp_mesh,
                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
-            model = apply_ddp(
+            apply_ddp(
                 model,
                 world_mesh,
                 enable_compile=job_config.training.compile,
                 enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
             )

-    return model
-

 def apply_tp(
     model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Parallelize the final linear output layer
-    model = parallelize_module(
+    parallelize_module(
         model,
         tp_mesh,
         {
@@ -192,7 +190,6 @@ def apply_tp(
         f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
-    return model


 # for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
         model.layers.register_module(layer_id, transformer_block)

     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-    return model


 def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")
-    return model


 def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
             module._load_state_dict_pre_hooks.clear()
             assert len(module._state_dict_pre_hooks) <= 1
             module._state_dict_pre_hooks.clear()

     logger.info("Applied FSDP to the model")
-    return model


 def apply_ddp(
@@ -347,7 +342,6 @@
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

     logger.info("Applied DDP to the model")
-    return model
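The review thread above turns on the fact that per-block AC and compile can be applied in place: each wrapped block is re-registered on its parent, so the top-level module object never needs to be reassigned. A minimal, self-contained sketch of that pattern — the toy Block/ToyTransformer modules below are mine, not torchtitan's; torchtitan's apply_ac/apply_compile do the analogous re-registration on the real TransformerBlocks:

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


class Block(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyTransformer(nn.Module):
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


model = ToyTransformer()

# Wrap each block with activation checkpointing and torch.compile, then
# re-register the wrapped block on its parent container. `model` itself is
# mutated in place; the caller never has to capture a return value.
for layer_id, block in model.layers.named_children():
    block = checkpoint_wrapper(block)  # AC around this block only
    block = torch.compile(block)       # per-block compile (needs a working inductor backend)
    model.layers.register_module(layer_id, block)

out = model(torch.randn(4, 16))
out.sum().backward()
print(type(model.layers["0"]).__name__)  # OptimizedModule wrapping the checkpointed block

Because register_module replaces the child under its existing key, the re-wrapping is invisible to whoever holds a reference to `model`, which is exactly why apply_ac/apply_compile/apply_fsdp no longer need to return it.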
18 changes: 13 additions & 5 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -7,7 +7,7 @@
 # This file applies the PT-D pipeline parallelism to the Llama model.

 import copy
-from typing import Union
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -18,7 +18,10 @@
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
+from torchtitan.parallelisms.pipelining_utils import (
+    build_pipeline_schedule,
+    stage_ids_this_rank,
+)


 DeviceType = Union[int, str, torch.device]
@@ -31,6 +34,7 @@ def pipeline_llama(
     job_config: JobConfig,
     device: DeviceType,
     model_config: ModelArgs,
+    loss_fn: Callable[..., torch.Tensor],
 ):
     split_mode = job_config.experimental.pipeline_parallel_split_mode
     valid_split_modes = ("manual", "tracer")
@@ -39,14 +43,18 @@
             f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
         )
     if split_mode == "manual":
-        return pipeline_llama_manual(
+        stages, models = pipeline_llama_manual(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
-        return pipeline_llama_tracer(
+        stages, models = pipeline_llama_tracer(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )

+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    return pp_schedule, models


 def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
@@ -218,4 +226,4 @@ def pipeline_llama_tracer(
                 group=pp_mesh.get_group(),
             )
         )
-    return (stages, models)
+    return stages, models
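With this change, pipeline_llama builds the schedule internally and hands back a ready-to-run schedule plus the local model parts. A hedged caller-side sketch of how a trainer's PP branch might consume the new return value — the leading parameter order and the job_config.model.name registry key are assumptions; only the trailing (job_config, device, model_config, loss_fn) parameters are visible in this diff:

# Sketch only — mirrors how a train.py-style PP branch could consume the new
# (pp_schedule, models) return value; trainer-side names are illustrative.
from torchtitan.parallelisms import models_pipelining_fns


def setup_pipeline(model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn):
    # pipeline_llama now calls build_pipeline_schedule itself, so the caller
    # receives a schedule instead of bare stages.
    pp_schedule, model_parts = models_pipelining_fns[job_config.model.name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )
    # model_parts are the stage submodules this rank owns; pp_schedule drives
    # forward/backward across stages during the training loop.
    return pp_schedule, model_parts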
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/pipelining_utils.py
@@ -14,7 +14,7 @@
 from torchtitan.logging import logger


-def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
+def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False

     if job_config.experimental.pipeline_parallel_schedule == "1f1b":