Update on "[BE][5/n] simplify pp vs. non-pp set up"
This PR restructures the PP vs. non-PP setup in `train.py`:
- Now there are only two main if-else blocks for PP vs. non-PP: one in the setup phase and the other in the training phase (see the sketch right after this list).
- I think the result is already clear to read or copy-paste, so it's not necessary to create separate sub-functions to hold the code.
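As a rough illustration of the resulting layout (the setup branch follows the diffs below; the training-phase branch is abbreviated and hypothetical, not the exact `train.py` code):

```python
# Setup phase: one if-else decides how the model gets parallelized.
if parallel_dims.pp_enabled:
    # PP: split the model into stages and build the pipeline schedule.
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )
else:
    # non-PP: apply SPMD parallelisms (DP/TP, AC, compile) to the whole model in place.
    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
    model_parts = [model]

# Training phase: one if-else decides how a step runs.
if parallel_dims.pp_enabled:
    # The schedule drives forward/backward across this rank's stages; the exact
    # arguments depend on whether the rank holds the first and/or last stage.
    pp_schedule.step(input_ids, target=labels)
else:
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()
```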

This PR also removes the unnecessary module returns in `parallelize_llama`, since we modify the module in-place. Note that torch.compile and AC require returning and reassigning the module. But since we apply compile and AC per block, we achieve the same effect in-place for the whole model by
```
transformer_block = compile/AC(transformer_block)
model.layers.register_module(layer_id, transformer_block)
``` 
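
For concreteness, a minimal sketch of that in-place rewrite (the helper name `apply_compile_and_ac` is illustrative, not the repository's actual function; it assumes `model.layers` is the container of transformer blocks and uses PyTorch's `checkpoint_wrapper` for AC):

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_compile_and_ac(model: torch.nn.Module) -> None:
    # Wrap each transformer block with AC, compile it, then re-register it
    # under the same name. This mutates `model` in place, so the caller's
    # reference stays valid and nothing needs to be returned.
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = checkpoint_wrapper(transformer_block)
        transformer_block = torch.compile(transformer_block)
        model.layers.register_module(layer_id, transformer_block)
```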

[ghstack-poisoned]
tianyu-l committed Aug 7, 2024
1 parent 338f183 commit f58ca70
Showing 5 changed files with 28 additions and 35 deletions.

estimation.py: 12 additions, 21 deletions

@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
         f"Building {model_name} {job_config.model.flavor} with {model_config}"
     )
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)
 
     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [whole_model]
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+
+    model.to_empty(device="cuda")
     if not active_fake_mode():
-        whole_model.init_weights()
+        model.init_weights()
+    model.train()
 
     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
 
-    for model in model_parts:
-        model.train()
     logger.info(f"Vocab size: {model_config.vocab_size}")
     # Create a dummy batch instead of loading from a dataset
     batch = (
@@ -165,24 +157,23 @@ def loss_fn(pred, labels):
             device="cuda",
         ),
     )
-    fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
+    fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
     fsdp_memtracker.track_inputs(batch)
 
     with fsdp_memtracker:
         for iter_idx in range(2):
             input_ids, labels = batch
             # train step
             with train_context():
-                pred = whole_model(input_ids)
+                pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 del pred
                 loss.backward()
 
             # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
             # sync float8 amaxes and scales
             float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step

torchtitan/parallelisms/__init__.py: 0 additions, 2 deletions

@@ -8,11 +8,9 @@
 from torchtitan.parallelisms.parallel_dims import ParallelDims
 from torchtitan.parallelisms.parallelize_llama import parallelize_llama
 from torchtitan.parallelisms.pipeline_llama import pipeline_llama
-from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
 
 
 __all__ = [
-    "build_pipeline_schedule",
     "models_parallelize_fns",
     "models_pipelining_fns",
     "ParallelDims",

torchtitan/parallelisms/pipeline_llama.py: 13 additions, 5 deletions

@@ -7,7 +7,7 @@
 # This file applies the PT-D pipeline parallelism to the Llama model.
 
 import copy
-from typing import Union
+from typing import Callable, Union
 
 import torch
 import torch.nn as nn
@@ -18,7 +18,10 @@
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
+from torchtitan.parallelisms.pipelining_utils import (
+    build_pipeline_schedule,
+    stage_ids_this_rank,
+)
 
 
 DeviceType = Union[int, str, torch.device]
@@ -31,6 +34,7 @@ def pipeline_llama(
     job_config: JobConfig,
     device: DeviceType,
     model_config: ModelArgs,
+    loss_fn: Callable[..., torch.Tensor],
 ):
     split_mode = job_config.experimental.pipeline_parallel_split_mode
     valid_split_modes = ("manual", "tracer")
@@ -39,14 +43,18 @@
             f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
         )
     if split_mode == "manual":
-        return pipeline_llama_manual(
+        stages, models = pipeline_llama_manual(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
-        return pipeline_llama_tracer(
+        stages, models = pipeline_llama_tracer(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
 
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    return pp_schedule, models
+
 
 def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
@@ -218,4 +226,4 @@ def pipeline_llama_tracer(
                 group=pp_mesh.get_group(),
             )
         )
-    return (stages, models)
+    return stages, models

torchtitan/parallelisms/pipelining_utils.py: 1 addition, 1 deletion

@@ -14,7 +14,7 @@
 from torchtitan.logging import logger
 
 
-def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
+def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False
 
     if job_config.experimental.pipeline_parallel_schedule == "1f1b":

train.py: 2 additions, 6 deletions

@@ -22,7 +22,6 @@
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import (
-    build_pipeline_schedule,
     models_parallelize_fns,
     models_pipelining_fns,
     ParallelDims,
@@ -143,11 +142,8 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        stages, model_parts = models_pipelining_fns[model_name](
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-        pp_schedule = build_pipeline_schedule(
-            job_config, parallel_dims, stages, loss_fn
+        pp_schedule, model_parts = models_pipelining_fns[model_name](
+            model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
 
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
