Add support for DDP and experimental CompiledAutograd
Summary:
Address the comments in #319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: a3131b7b0a835576992dd86d010f53866da4aa9d
Pull Request resolved: #432
fegin committed Jul 9, 2024
1 parent 3fca883 commit 958cac9
Showing 6 changed files with 75 additions and 8 deletions.
1 change: 1 addition & 0 deletions estimation.py
@@ -67,6 +67,7 @@ def estimate_memory(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.experimental.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
13 changes: 13 additions & 0 deletions test_runner.py
@@ -273,6 +273,19 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 1",
"--training.data_parallel_degree 8",
"--experimental.data_parallel_type ddp",
"--experimental.enable_compiled_autograd",
]
],
"CompiledDDP",
"compiled_ddp",
ngpu=8,
),
]
return integration_tests_flavors

11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -312,6 +312,17 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--experimental.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
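
For reference, a standalone argparse sketch (not the full JobConfig/TOML machinery) of how the two new dotted options parse; the command-line values below are examples only:

```python
# Standalone sketch mirroring the add_argument calls above; this is not the
# torchtitan JobConfig itself, just plain argparse with the same option names.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.data_parallel_type",
    type=str,
    default="fsdp",
    help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
parser.add_argument(
    "--experimental.enable_compiled_autograd",
    action="store_true",
    help="Enable CompiledAutograd to compile the backward.",
)

args = parser.parse_args(
    ["--experimental.data_parallel_type", "ddp",
     "--experimental.enable_compiled_autograd"]
)
# argparse keeps the dotted option name as the dest, so read it back via getattr:
assert getattr(args, "experimental.data_parallel_type") == "ddp"
assert getattr(args, "experimental.enable_compiled_autograd") is True
```
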
3 changes: 3 additions & 0 deletions torchtitan/parallelisms/__init__.py
@@ -28,8 +28,10 @@ class ParallelDims:
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
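
A minimal, illustrative construction of ParallelDims with the new field; it assumes the dp and tp fields defined above the hunk shown here, as referenced by _validate:

```python
# Illustrative only: exercising the new dp_type handling shown above.
# Assumes ParallelDims also declares dp and tp fields (above this hunk).
from torchtitan.parallelisms import ParallelDims

dims = ParallelDims(
    dp=8,
    tp=1,
    pp=1,
    world_size=8,
    enable_loss_parallel=False,
    dp_type="DDP",  # lower-cased in __post_init__ before validation
)
assert dims.dp_type == "ddp"                          # passes the ("fsdp", "ddp") check
assert dims.dp * dims.tp * dims.pp == dims.world_size  # passes the world-size check
```
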
28 changes: 25 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,8 +12,9 @@
from typing import Dict, Tuple

import torch

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -452,7 +453,7 @@ def apply_compile(model, job_config: JobConfig):
return model


-def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
+def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism to the model. FSDP2 is used here.
"""
@@ -489,6 +490,24 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model


def apply_ddp(model, world_mesh, parallel_dims, job_config: JobConfig):
    if world_mesh.ndim > 1:
        raise RuntimeError("DDP does not support > 1D parallelism.")

    if job_config.training.compile:
        if job_config.experimental.enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")
return model


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
@@ -508,6 +527,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
model = apply_compile(model, job_config)

if parallel_dims.dp_enabled:
-        model = apply_dp(model, world_mesh, parallel_dims, job_config)
+        if parallel_dims.dp_type == "fsdp":
+            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+        else:
+            model = apply_ddp(model, world_mesh, parallel_dims, job_config)

return model
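
For a standalone view of the composable DDP entry point that apply_ddp wraps, here is a hedged sketch outside torchtitan; the mesh shape, model, and bucket size are placeholders, and it assumes 8 ranks launched via torchrun:

```python
# Standalone sketch of the composable DDP ("replicate") API used by apply_ddp.
# The model, mesh size, and bucket_cap_mb are illustrative placeholders.
import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.device_mesh import init_device_mesh

# Assumes the process group spans 8 ranks (e.g. launched with torchrun --nproc_per_node=8).
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
model = nn.Linear(1024, 1024).cuda()

# As in apply_ddp above, this dynamo setting only matters when the model is
# also wrapped with torch.compile; it routes gradient bucketing through the
# python reducer so CompiledAutograd can trace the backward.
torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"

model = replicate(model, device_mesh=mesh, bucket_cap_mb=100)
```
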
27 changes: 22 additions & 5 deletions train.py
@@ -135,6 +135,22 @@ def zero_grad(self):
return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

yield

return context


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
@@ -157,6 +173,7 @@ def main(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.experimental.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
@@ -191,9 +208,9 @@
dp_rank,
)

-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
@@ -362,7 +379,7 @@ def loss_fn(pred, labels):
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-            with loss_parallel_ctx():
+            with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
@@ -379,7 +396,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
-            with loss_parallel_ctx():
+            with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
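
A self-contained sketch of the context-composition pattern used by get_train_context above; the two context managers below are stand-ins for loss_parallel() and torch._dynamo.utils.maybe_enable_compiled_autograd(True):

```python
# Minimal sketch of conditionally stacking context managers behind one callable,
# as get_train_context does; the "fake" contexts are placeholders only.
import contextlib


@contextlib.contextmanager
def fake_loss_parallel():
    print("loss_parallel: on")
    yield
    print("loss_parallel: off")


@contextlib.contextmanager
def fake_compiled_autograd():
    print("compiled_autograd: on")
    yield
    print("compiled_autograd: off")


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(fake_loss_parallel())
            if enable_compiled_autograd:
                stack.enter_context(fake_compiled_autograd())
            yield

    return context


train_context = get_train_context(enable_loss_parallel=True, enable_compiled_autograd=False)
with train_context():
    pass  # forward / backward would run here
```
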
