Commit d3ab3b7

Hybrid Sharding in Full Distributed FT (#2415)
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
Co-authored-by: joecummings <jrcummings27@gmail.com>
1 parent 8dadbaa commit d3ab3b7

File tree (4 files changed: +127 −42 lines)

  recipes/configs/llama3_3/70B_full_multinode.yaml
  recipes/full_finetune_distributed.py
  torchtune/training/__init__.py
  torchtune/training/_distributed.py

recipes/configs/llama3_3/70B_full_multinode.yaml

Lines changed: 7 additions & 1 deletion
@@ -39,6 +39,13 @@ checkpointer:
   model_type: LLAMA3
 resume_from_checkpoint: False
 
+# Parallelism tweaks
+tensor_parallel_dim: 8 # 8-way TP
+tensor_parallel_plan:
+  _component_: torchtune.models.llama3.base_llama_tp_plan
+data_parallel_shard_dim: -1 # -1 means to infer based on other parallel dims & world size
+data_parallel_replicate_dim: 1
+
 # Fine-tuning arguments
 batch_size: 4
 epochs: 1
@@ -53,7 +60,6 @@ loss:
 max_steps_per_epoch: null
 gradient_accumulation_steps: 1 # Use to increase effective batch size
 
-
 # Training env
 device: cuda
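
The three new knobs must multiply out to the world size, and data_parallel_shard_dim: -1 asks the recipe to infer the FSDP sharding degree from whatever is left over. A minimal arithmetic sketch, assuming a hypothetical 2-node x 8-GPU job (world size 16) with the 8-way TP from this config:

# Hypothetical job size, for illustration only: 2 nodes x 8 GPUs.
world_size = 16

# Values from the config above.
tensor_parallel_dim = 8         # 8-way TP
data_parallel_replicate_dim = 1
data_parallel_shard_dim = -1    # -1 -> infer from the remaining world size

if data_parallel_shard_dim == -1:
    data_parallel_shard_dim = world_size // (
        data_parallel_replicate_dim * tensor_parallel_dim
    )

# The parallel dims must multiply out to the world size.
assert (
    data_parallel_replicate_dim * data_parallel_shard_dim * tensor_parallel_dim
    == world_size
)
print(data_parallel_shard_dim)  # 2: each TP group is FSDP-sharded across 2 ranks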

recipes/full_finetune_distributed.py

Lines changed: 41 additions & 40 deletions
@@ -15,11 +15,7 @@
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
-from torch.distributed import (
-    destroy_process_group,
-    init_device_mesh,
-    init_process_group,
-)
+from torch.distributed import destroy_process_group, init_process_group
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
@@ -146,20 +142,31 @@ def __init__(self, cfg: DictConfig) -> None:
         # Initialize distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
         self._is_rank_zero = self.rank == 0
-        self.tensor_parallel_plan = config.instantiate(
-            cfg.get("tensor_parallel_plan", None)
-        )
-        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
-        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
+        self.tp_plan = config.instantiate(cfg.get("tensor_parallel_plan", None))
+        self.tp_degree = cfg.get("tensor_parallel_dim", 1)
+        if self.tp_degree > 1 and self.tp_plan is None:
             raise ValueError(
                 "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
             )
-        if self.world_size % self.tensor_parallel_dim != 0:
-            raise ValueError(
-                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+        data_shard = cfg.get("data_parallel_shard_dim", -1)  # -1 means to infer
+        data_replicate = cfg.get("data_parallel_replicate_dim", 1)
+
+        # Set up n-d device mesh
+        self.parallel_dims = training.ParallelDims(
+            dp_replicate=data_replicate,
+            dp_shard=data_shard,
+            tp=self.tp_degree,
+            world_size=self.world_size,
+        )
+        self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
+        if self.parallel_dims.dp_enabled:
+            dp_mesh = self.world_mesh["dp"]
+            self.dp_degree, self.dp_rank = (
+                dp_mesh.size(),
+                dp_mesh.get_local_rank(),
             )
-
-        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
+        else:
+            self.dp_degree, self.dp_rank = 1, 0
 
         # Logging attributes
         self._output_dir = cfg.output_dir
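
The recipe now reads its data-parallel size and rank from the flattened "dp" mesh instead of deriving them from world_size and tensor_parallel_dim. A minimal sketch of that flattening, assuming it is launched with torchrun --nproc-per-node 4 and using CPU/gloo purely for illustration (dp_replicate=2, dp_shard=2, tp=1):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes: torchrun --nproc-per-node 4; a 2x2 data-parallel mesh, no TP.
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard"))

# build_mesh() flattens the data-parallel dims into a single "dp" dim; the
# recipe reads dp_degree and dp_rank from this flattened submesh.
mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")
dp_mesh = mesh["dp"]
print(dp_mesh.size(), dp_mesh.get_local_rank())  # 4, and this rank's index in 0..3

dist.destroy_process_group()
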
@@ -538,26 +545,18 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
 
-        device_mesh = init_device_mesh(
-            self._device.type,
-            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
-            mesh_dim_names=("dp", "tp"),
-        )
-        self.dp_size = device_mesh["dp"].size()
-        self.dp_rank = device_mesh["dp"].get_local_rank()
-
         # Apply tensor parallelism to the model
-        if self.tensor_parallel_dim > 1:
-            if self.data_parallel_dim == 1 and self.fsdp_cpu_offload:
+        if self.parallel_dims.tp_enabled:
+            if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
                 raise ValueError(
                     "Tensor parallelism is not supported with FSDP CPU offloading when data parallelism is disabled."
                 )
             # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
-            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            model = training.prepare_mha_for_tp(model, self.world_mesh["tp"])
             parallelize_module(
                 model,
-                device_mesh["tp"],
-                parallelize_plan=self.tensor_parallel_plan,
+                self.world_mesh["tp"],
+                parallelize_plan=self.tp_plan,
             )
 
         # We currently have two versions of activation checkpointing in this recipe
@@ -580,19 +579,25 @@ def _setup_model(
             )
 
         # Apply Fully Sharded Data Parallelism to the model
-        if self.data_parallel_dim > 1:
+        if self.parallel_dims.dp_shard_enabled:
             fsdp_shard_conditions = [
                 partial(
                     training.get_shard_conditions,
                     names_to_match=custom_sharded_layers,
                 )
             ]
+
+            if self.parallel_dims.dp_replicate_enabled:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+            else:
+                dp_mesh_dim_names = ("dp_shard",)
+
             training.shard_model(
                 model=model,
                 shard_conditions=fsdp_shard_conditions,
                 cpu_offload=fsdp_cpu_offload,
                 reshard_after_forward=reshard_after_forward,
-                dp_mesh=device_mesh["dp"],
+                dp_mesh=self.world_mesh[dp_mesh_dim_names],
             )
 
         with training.set_default_dtype(self._dtype), self._device:
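
Passing a 2-D ("dp_replicate", "dp_shard") submesh to training.shard_model is what turns plain FSDP into hybrid sharding: parameters are sharded along dp_shard and replicated along dp_replicate. A minimal sketch of the underlying FSDP2 behaviour, assuming PyTorch >= 2.6 for the fully_shard import path and a hypothetical 16-rank CUDA job (2-way replicate x 8-way shard):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # older releases expose this under torch.distributed._composable.fsdp

# Assumes a 16-rank torchrun job.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp_replicate", "dp_shard"))

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# With a 2-D mesh, fully_shard replicates parameters across the first dim and
# shards them across the second (hybrid sharding); a 1-D mesh gives plain FSDP.
for layer in model:
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)
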
@@ -629,7 +634,7 @@ def _setup_model(
             training.log_memory_stats(memory_stats)
 
         # synchronize before training begins
-        torch.distributed.barrier()
+        torch.distributed.barrier(device_ids=[self._device.index])
 
         return model
 
@@ -716,7 +721,7 @@ def _setup_data(
         collate_fn = _get_component_from_path(collate_fn)
 
         sampler = StatefulDistributedSampler(
-            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
+            ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0
         )
         dataloader = StatefulDataLoader(
             dataset=ds,
@@ -727,7 +732,7 @@ def _setup_data(
                     collate_fn,
                     padding_idx=self._tokenizer.pad_id,
                     ignore_idx=self._loss_fn.ignore_index,
-                    pad_to_multiple_of=self.tensor_parallel_dim,
+                    pad_to_multiple_of=self.tp_degree,
                 )
                 if not packed
                 else padded_collate_packed
@@ -811,22 +816,18 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-
-                    # We multiply by world_size to undo FSDP2 gradient normalization.
-                    current_loss = current_loss * (self.dp_size / num_tokens)
+                    current_loss = current_loss * (self.dp_degree / num_tokens)
 
                 current_loss.backward()
-
-                # Step with optimizer
+                # Optimizer step (if not fused in backward call)
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     if not self._optimizer_in_bwd:
                         # Get total number of tokens across all ranks to normalize gradients
                         torch.distributed.all_reduce(num_tokens)
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        # We multiply by world_size to undo FSDP2 gradient normalization.
-                        training.scale_grads(self._model, self.dp_size / num_tokens)
+                        training.scale_grads(self._model, self.dp_degree / num_tokens)
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
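
The scale factor keeps its old meaning under hybrid sharding: as the removed comments note, FSDP2 averages gradients across data-parallel ranks, so the recipe multiplies by dp_degree (now the size of the flattened "dp" mesh rather than a value derived from world_size) to undo that average and divides by the global token count. A small numeric sketch with made-up token counts:

# Assumed setup: dp_replicate=2, dp_shard=4, so dp_degree = 8 data-parallel ranks.
dp_degree = 8

# Made-up per-rank token counts; all_reduce(num_tokens) sums them globally.
tokens_per_rank = [503, 498, 512, 490, 505, 500, 495, 497]
num_tokens = sum(tokens_per_rank)  # 4000

# Gradients leave backward() averaged over the dp ranks (FSDP2 behaviour), so
# this scale turns them into a per-token average over the global batch.
scale = dp_degree / num_tokens
print(scale)  # 0.002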

torchtune/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     is_distributed,
     load_from_full_model_state_dict,
     load_from_full_optimizer_state_dict,
+    ParallelDims,
     prepare_mha_for_tp,
     set_torch_num_threads,
     shard_model,
@@ -91,6 +92,7 @@
     "Checkpointer",
     "update_state_dict_for_classifier",
     "ADAPTER_CONFIG",
+    "ParallelDims",
     "ADAPTER_KEY",
     "EPOCHS_KEY",
     "MAX_STEPS_KEY",

torchtune/training/_distributed.py

Lines changed: 77 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 import logging
 import os
+from dataclasses import dataclass
 from itertools import chain
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple
 
@@ -24,7 +25,7 @@
     set_optimizer_state_dict,
     StateDictOptions,
 )
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import ShardingStrategy
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.optim import Optimizer
@@ -47,6 +48,81 @@
     _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False
 
 
+@dataclass
+class ParallelDims:
+    dp_replicate: int
+    dp_shard: int
+    tp: int
+    world_size: int
+
+    def __post_init__(self):
+        self._validate()
+
+    def _validate(self):
+        dp_replicate, dp_shard, tp = (
+            self.dp_replicate,
+            self.dp_shard,
+            self.tp,
+        )
+        for d in (dp_replicate, tp):
+            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
+
+        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
+        if dp_shard < 0:
+            self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp)
+        assert dp_shard >= 1
+
+        assert dp_replicate * dp_shard * tp == self.world_size, (
+            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
+            f"tp({tp}) != WORLD_SIZE({self.world_size})"
+        )
+
+    def build_mesh(self, device_type):
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.dp_replicate, self.dp_shard, self.tp],
+            ["dp_replicate", "dp_shard", "tp"],
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+
+        names = tuple(names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submeshes here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading (no communication on this mesh)
+        dp_mesh_dim_names = []
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+        if self.dp_shard_enabled:
+            dp_mesh_dim_names.append("dp_shard")
+
+        if dp_mesh_dim_names != []:
+            mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+
+        return mesh
+
+    @property
+    def dp_enabled(self):
+        return self.dp_replicate > 1 or self.dp_shard > 1
+
+    @property
+    def dp_replicate_enabled(self):
+        return self.dp_replicate > 1
+
+    @property
+    def dp_shard_enabled(self):
+        return self.dp_shard > 1
+
+    @property
+    def tp_enabled(self):
+        return self.tp > 1
+
+
 def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
     """Helper function to convert sharding strategy strings to ShardingStrategy enum."""
     return getattr(ShardingStrategy, strategy)
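
A quick usage sketch of the new helper, which this commit exports as torchtune.training.ParallelDims; the 64-rank job size below is a made-up example:

from torchtune.training import ParallelDims

# Hypothetical 64-rank job: 8-way TP, 2-way replication, shard degree inferred.
dims = ParallelDims(dp_replicate=2, dp_shard=-1, tp=8, world_size=64)

print(dims.dp_shard)              # 4, inferred as 64 // (2 * 8)
print(dims.dp_enabled)            # True
print(dims.dp_replicate_enabled)  # True
print(dims.tp_enabled)            # True

# Under torchrun, build_mesh() would create a ("dp_replicate", "dp_shard", "tp")
# device mesh and flatten the two data-parallel dims into a "dp" submesh.
# mesh = dims.build_mesh(device_type="cuda")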
