Skip to content

Commit

Permalink
Add support of DDP and experimental CompiledAutograd
Browse files Browse the repository at this point in the history
Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432
  • Loading branch information
fegin committed Jul 18, 2024
1 parent 69fe8de commit 2f989b9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 8 deletions.
1 change: 1 addition & 0 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def estimate_memory(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
Expand Down
9 changes: 9 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,15 @@ def build_test_list():
],
"FSDP2 with float8 all-gather and precomputed dynamic scales",
"fsdp2_float8_all_gather_precompute_dynamic_scales",
),
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
]
],
"DDP",
"ddp",
ngpu=4,
),
]
Expand Down
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,17 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
Expand Down
3 changes: 3 additions & 0 deletions torchtitan/parallelisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ class ParallelDims:
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
Expand All @@ -42,6 +44,7 @@ def _validate(self):
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
Expand Down
36 changes: 33 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
Expand Down Expand Up @@ -453,13 +455,15 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
return model


def apply_dp(
def apply_fsdp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply data parallelism (FSDP2) to the model."""
"""
Apply data parallelism to the model. FSDP2 is used here.
"""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
Expand Down Expand Up @@ -492,6 +496,29 @@ def apply_dp(
return model


def apply_ddp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism.")

if job_config.training.compile:
if job_config.experimental.enable_compiled_autograd:
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
)
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")
return model


def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
Expand All @@ -516,6 +543,9 @@ def parallelize_llama(
model = apply_compile(model, job_config)

if parallel_dims.dp_enabled:
model = apply_dp(model, world_mesh, parallel_dims, job_config)
if parallel_dims.dp_type == "fsdp":
model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
else:
model = apply_ddp(model, world_mesh, parallel_dims, job_config)

return model
27 changes: 22 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ def zero_grad(self):
return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

yield

return context


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
Expand All @@ -160,6 +176,7 @@ def main(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
Expand Down Expand Up @@ -194,9 +211,9 @@ def main(job_config: JobConfig):
dp_rank,
)

# loss_parallel enables dispatching to efficient loss operators
loss_parallel_ctx = (
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
Expand Down Expand Up @@ -364,7 +381,7 @@ def loss_fn(pred, labels):
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with loss_parallel_ctx():
with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
Expand All @@ -381,7 +398,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with loss_parallel_ctx():
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
Expand Down

0 comments on commit 2f989b9

Please sign in to comment.