
Commit 321a888

model fragments for diloco
Summary:
- add a configuration option for users to provide how they want to partition the model
- if this is provided, the model needs to implement `FaultTolerantTrainSpec`, which defines the fragmentation function that splits the model based on the configuration
- determine the model fragments in the training script and pass them to the ft manager

Test Plan: Running llama3 8b with 2 fragments and 1 step delay; each fragment gets synced every 20 steps.

<img width="944" height="545" alt="image" src="https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58" />
1 parent f3e2a75 commit 321a888

6 files changed, +287 -6 lines changed

torchtitan/config/job_config.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -661,6 +661,18 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """

+    module_names_per_model_chunk: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
+    num_fragments: int = 1
+

 @dataclass
 class Experimental:
```
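As an illustration (not part of the commit), a minimal sketch of how these two fields could be populated; in practice they would come from the job config file, the import path follows the file above, and the remaining FaultTolerance fields are assumed to keep their defaults:

```python
from torchtitan.config.job_config import FaultTolerance

# Explicit chunking: three chunks covering a hypothetical 5-layer model,
# mirroring the docstring example above plus the final norm/output modules.
ft = FaultTolerance(
    num_fragments=3,
    module_names_per_model_chunk=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2"],
        ["layers.3", "layers.4", "norm", "output"],
    ],
)

# Automatic chunking: leave module_names_per_model_chunk empty and let
# num_fragments drive the split (see fragment_llama later in this commit).
ft_auto = FaultTolerance(num_fragments=2)
```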

torchtitan/distributed/pipeline.py

Lines changed: 200 additions & 2 deletions
```diff
@@ -6,6 +6,8 @@
 import os
 from typing import Callable

+from torch import nn
+
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -19,7 +21,13 @@
 from torchtitan.tools.logging import logger


-__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
+__all__ = [
+    "build_pipeline_schedule",
+    "generate_split_points",
+    "stage_ids_this_rank",
+    "generate_module_names_per_stage",
+    "module_split",
+]


 # TODO: It's unclear if this API is general enough to be used by other models.
@@ -206,6 +214,196 @@ def stage_ids_this_rank(
         stages_per_rank == 2
     ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
     stage_v_pairs = list(
-        zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
+        zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1), strict=True)
     )
     return stage_v_pairs[pp_rank]
+
+
+def generate_module_names_per_stage(
+    num_stages: int,
+    num_layers: int,
+    input_weight: int = 1,
+    output_weight: int = 1,
+) -> list[list[str]]:
+    """
+    Programmatically generates module names per stage for pipeline parallelism with weighting.
+
+    Args:
+        num_stages: Number of pipeline stages
+        num_layers: Total number of transformer layers in the model
+        input_weight: Weight for input modules (tok_embeddings) in layer calculation
+        output_weight: Weight for output modules (norm + output) in layer calculation
+
+    Returns:
+        List of lists containing module names for each stage
+
+    Example:
+        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
+        treats embeddings as 2 layers and norm+output as 2 layers for distribution
+    """
+    if num_stages < 1:
+        raise ValueError("Number of stages must be at least 1")
+
+    if num_stages == 1:
+        # Single stage gets everything
+        layer_names = [f"layers.{i}" for i in range(num_layers)]
+        return [["tok_embeddings"] + layer_names + ["norm", "output"]]
+
+    # Calculate effective layers including weights
+    num_effective_layers = num_layers + input_weight + output_weight
+
+    if num_stages > num_effective_layers:
+        raise ValueError(
+            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
+        )
+
+    # Calculate layers per stage (distribute evenly)
+    layers_per_stage = num_effective_layers // num_stages
+    extra_layers = num_effective_layers % num_stages
+
+    # Ensure each stage gets at least the weight of input/output modules
+    if layers_per_stage < max(input_weight, output_weight):
+        raise ValueError(
+            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
+        )
+
+    module_names_per_stage = []
+    current_layer = 0
+
+    for stage_idx in range(num_stages):
+        stage_modules = []
+
+        # Calculate effective layers for this stage
+        effective_layers_for_stage = layers_per_stage
+        if stage_idx < extra_layers:
+            effective_layers_for_stage += 1
+
+        # First stage: handle input modules with weighting
+        if stage_idx == 0:
+            stage_modules.append("tok_embeddings")
+            # Account for input weight in layer distribution
+            remaining_layers_for_stage = effective_layers_for_stage - input_weight
+
+            # Add transformer layers
+            for _ in range(remaining_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+        # Last stage: handle output modules with weighting
+        elif stage_idx == num_stages - 1:
+            # Account for output weight in layer distribution
+            remaining_layers_for_stage = effective_layers_for_stage - output_weight
+
+            # Add transformer layers
+            for _ in range(remaining_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+            # Add output modules
+            stage_modules.extend(["norm", "output"])
+
+        # Middle stages: only transformer layers
+        else:
+            for _ in range(effective_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+        module_names_per_stage.append(stage_modules)
+
+    return module_names_per_stage
+
+
+def module_split(
+    model: nn.Module,
+    module_names_per_stage: list[list[str]],
+) -> list[nn.Module]:
+    """
+    This API creates pipeline stages based on specified module names for each stage.
+    This method updates the model in place.
+
+    Args:
+        model: The complete model to be split
+        module_names_per_stage: List of lists, where each inner list contains the module names
+                                that should be included in that stage. Module names should be
+                                dot-separated paths. Examples:
+                                - "tok_embeddings" for token embeddings
+                                - "layers.0", "layers.1" for specific transformer layers
+                                - "norm" for the final normalization layer
+                                - "output" for the output projection layer
+
+    Returns:
+        List of model chunks
+
+    Example usage:
+        module_names_per_stage = [
+            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
+            ["layers.1", "layers.2"],        # Stage 1: middle layers
+            ["norm", "output"]               # Stage 2: final norm + output
+        ]
+    """

+    def _build_stage_from_modules(stage_idx: int, module_names: list[str]) -> nn.Module:
+        stage_model = nn.Module()
+        # Create a set of modules to keep for faster lookup
+        modules_to_keep = set(module_names)
+        print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
+        for module_name, module_value in model.named_children():
+            # Handle layer-like structures (e.g., "layers.0", "layers.1")
+            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+                layers_to_keep = {
+                    name.split(".", 1)[1]
+                    for name in modules_to_keep
+                    if name.startswith(f"{module_name}.")
+                }
+
+                if not layers_to_keep:
+                    continue
+
+                # Keep only specified layers
+                if isinstance(module_value, nn.ModuleDict):
+                    for layer_name in list(module_value.keys()):
+                        if layer_name in layers_to_keep:
+                            setattr(
+                                stage_model,
+                                f"{module_name}.{layer_name}",
+                                module_value[layer_name],
+                            )
+                else:
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = nn.ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                    setattr(stage_model, module_name, new_layers)
+
+                continue
+
+            # Handle simple module attributes (e.g., "linear", "norm")
+            if module_name not in modules_to_keep:
+                continue
+
+            setattr(stage_model, module_name, module_value)
+
+        return stage_model
+
+    num_stages = len(module_names_per_stage)
+    models = []
+
+    for stage_idx in range(num_stages):
+        module_names = module_names_per_stage[stage_idx]
+        model_chunk = _build_stage_from_modules(
+            stage_idx,
+            module_names,
+        )
+        logger.info(f"building stage_idx {stage_idx} " f"with modules {module_names}")
+        models.append(model_chunk)
+
+    return models
```
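To see what the two new helpers produce, here is a quick sanity check on a toy module (not from the commit); `ToyTransformer` is a made-up stand-in that only mimics the attribute names the splitter expects (`tok_embeddings`, `layers`, `norm`, `output`):

```python
import torch.nn as nn

from torchtitan.distributed.pipeline import generate_module_names_per_stage, module_split


class ToyTransformer(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, 100)


names = generate_module_names_per_stage(num_stages=2, num_layers=4)
# -> [["tok_embeddings", "layers.0", "layers.1"],
#     ["layers.2", "layers.3", "norm", "output"]]

fragments = module_split(ToyTransformer(), names)
# fragments[0] holds tok_embeddings plus a 2-element ModuleList (original layers 0-1).
# fragments[1] holds a 2-element ModuleList (original layers 2-3, re-indexed from 0),
# plus norm and output.
```

Note that kept ModuleList entries are re-indexed from zero inside each chunk, so fragment-local FQNs differ from those in the original model.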

torchtitan/models/llama3/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -10,7 +10,8 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from torchtitan.protocols.train_spec import FaultTolerantTrainSpec, register_train_spec
+from .infra.fault_tolerance import fragment_llama

 from .infra.parallelize import parallelize_llama
 from .infra.pipeline import pipeline_llama
@@ -71,12 +72,13 @@


 register_train_spec(
-    TrainSpec(
+    FaultTolerantTrainSpec(
         name="llama3",
         model_cls=Transformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,
         pipelining_fn=pipeline_llama,
+        fragment_fn=fragment_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
```

torchtitan/models/llama3/infra/fault_tolerance.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file is used to setup the model for fault tolerance
+
+import torch.nn as nn
+
+from torchtitan.config import JobConfig
+from torchtitan.distributed.pipeline import (
+    generate_module_names_per_stage,
+    module_split,
+)
+
+from ..model.args import TransformerModelArgs
+
+
+def fragment_llama(
+    model: nn.Module,
+    job_config: JobConfig,
+    model_config: TransformerModelArgs,
+) -> list[nn.Module]:
+    ft = job_config.fault_tolerance
+
+    assert ft.num_fragments > 0
+
+    module_names_per_stage = ft.module_names_per_model_chunk
+
+    input_weight = 1  # Weight for tok_embeddings
+    output_weight = 1  # Weight for norm + output layers
+
+    if module_names_per_stage == []:
+        if ft.num_fragments == 1:
+            return [model]
+
+        module_names_per_stage = generate_module_names_per_stage(
+            ft.num_fragments, model_config.n_layers, input_weight, output_weight
+        )
+
+    model_fragments = module_split(model, module_names_per_stage)
+    print(f"Created {len(model_fragments)} model fragments")
+
+    return model_fragments
```
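For reference, the Test Plan above (2 fragments on a Llama 3 8B model, which has 32 transformer layers) exercises the empty-config path of `fragment_llama`; a rough sketch of the split it would produce (the layer count is inferred, not stated in the commit):

```python
from torchtitan.distributed.pipeline import generate_module_names_per_stage

# Default weights (1 for tok_embeddings, 1 for norm + output), as hard-coded in fragment_llama.
names = generate_module_names_per_stage(num_stages=2, num_layers=32)
# names[0] == ["tok_embeddings", "layers.0", ..., "layers.15"]   (17 entries)
# names[1] == ["layers.16", ..., "layers.31", "norm", "output"]  (18 entries)
```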

torchtitan/protocols/train_spec.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -56,6 +56,14 @@ class TrainSpec:
     state_dict_adapter: type[StateDictAdapter] | None = None


+FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]
+
+
+@dataclass
+class FaultTolerantTrainSpec(TrainSpec):
+    fragment_fn: FragmentFunction | None = None
+
+
 _train_specs = {}

```
torchtitan/train.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 import os
 import time
 from datetime import timedelta
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, cast, Generator, Iterable, Optional

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -43,6 +43,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     tokenizer: train_spec_module.BaseTokenizer | None
     dataloader: train_spec_module.BaseDataLoader
     model_parts: list[torch.nn.Module]
+    ft_model_parts: list[torch.nn.Module]
     loss_fn: train_spec_module.LossFunction
     optimizers: train_spec_module.OptimizersContainer
     lr_schedulers: train_spec_module.LRSchedulersContainer
@@ -215,6 +216,8 @@ def __init__(self, job_config: JobConfig):
             self.loss_fn, self.gradient_accumulation_steps
         )

+        self.ft_model_parts = []
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:
@@ -261,6 +264,19 @@ def __init__(self, job_config: JobConfig):

             self.model_parts = [model]

+            ft = job_config.fault_tolerance
+
+            if ft.enable:
+                train_spec = cast(
+                    train_spec_module.FaultTolerantTrainSpec, self.train_spec
+                )
+                if train_spec.fragment_fn:
+                    self.ft_model_parts = train_spec.fragment_fn(
+                        model, job_config, model_args
+                    )
+                else:
+                    self.ft_model_parts = [model]
+
         self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

         # initialize device memory monitor and get peak flops for MFU calculation
@@ -524,7 +540,7 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model_parts=self.ft_model_parts,
                 optimizer=self.optimizers,
             ),
         ):
```
