
Commit be87993

model fragments for diloco
Test Plan:

```bash
$ LOGLEVEL=INFO NCCL_DEBUG_SUBSYS=ALL NCCL_DEBUG=WARN TORCH_CPP_LOG_LEVEL=INFO \
  CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_train.sh \
  --fault_tolerance.enable --fault_tolerance.group_size=1 --fault_tolerance.replica_id=0 \
  --training.local_batch_size=2 --fault_tolerance.sync_steps=10 \
  --fault_tolerance.semi_sync_method=diloco --parallelism.data_parallel_shard_degree=2 \
  --fault_tolerance.num_fragments=2
```
1 parent: 9d1f27d

6 files changed, +179 -4 lines changed
torchtitan/config_manager.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -691,6 +691,18 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """
 
+    module_names_per_model_chunk: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
+    num_fragments: int = 1
+
 
 @dataclass
 class Experimental:
```
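To make the relationship between the two new options concrete, here is a small illustration in plain Python (the chunk spec is the one from the docstring above; the selection behavior mirrors how the fragmenting code below consumes it, outside of torchtitan's config plumbing):

```python
# Illustration only: each inner list of module_names_per_model_chunk is one
# DiLoCo model fragment; num_fragments is consulted only when the list is empty.
module_names_per_model_chunk = [
    ["tok_embeddings", "layers.0"],
    ["layers.1", "layers.2"],
    ["layers.3", "layers.4"],
]
num_fragments = 1  # ignored here because explicit chunks are given

n_fragments = (
    len(module_names_per_model_chunk)
    if module_names_per_model_chunk
    else num_fragments
)
print(f"{n_fragments} fragments")  # -> 3 fragments
for idx, names in enumerate(module_names_per_model_chunk):
    print(f"fragment {idx}: {names}")
```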

torchtitan/distributed/pipeline.py

Lines changed: 93 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@
     "stage_ids_this_rank",
     "generate_module_names_per_stage",
     "pipeline_module_split",
+    "module_split",
 ]
 
 
```
```diff
@@ -333,3 +334,95 @@ def _build_stage_from_modules(
         models.append(model_chunk)
 
     return stages, models
+
+
+def module_split(
+    model: nn.Module,
+    module_names_per_stage: list[list[str]],
+) -> list[nn.Module]:
+    """
+    This API creates model chunks based on specified module names for each chunk.
+
+    The returned chunks hold references to the original model's modules, so
+    parameters are shared with ``model`` rather than copied.
+
+    Args:
+        model: The complete model to be split
+        module_names_per_stage: List of lists, where each inner list contains the
+            module names that should be included in that chunk. Module names are
+            dot-separated paths. Examples:
+            - "tok_embeddings" for token embeddings
+            - "layers.0", "layers.1" for specific transformer layers
+            - "norm" for the final normalization layer
+            - "output" for the output projection layer
+
+    Returns:
+        List of model chunks
+
+    Example usage:
+        module_names_per_stage = [
+            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
+            ["layers.1", "layers.2"],        # Stage 1: middle layers
+            ["norm", "output"],              # Stage 2: final norm + output
+        ]
+    """
+
+    def _build_stage_from_modules(stage_idx: int, module_names: list[str]) -> nn.Module:
+        stage_model = nn.Module()
+        # Create a set of modules to keep for faster lookup
+        modules_to_keep = set(module_names)
+        logger.info(f"Stage {stage_idx}: modules to keep: {modules_to_keep}")
+        for module_name, module_value in model.named_children():
+            # Handle layer-like structures (e.g., "layers.0", "layers.1")
+            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+                layers_to_keep = {
+                    name.split(".", 1)[1]
+                    for name in modules_to_keep
+                    if name.startswith(f"{module_name}.")
+                }
+
+                if not layers_to_keep:
+                    continue
+
+                # Keep only the specified layers, referencing the originals
+                if isinstance(module_value, nn.ModuleDict):
+                    new_layers = nn.ModuleDict(
+                        {
+                            layer_name: layer
+                            for layer_name, layer in module_value.items()
+                            if layer_name in layers_to_keep
+                        }
+                    )
+                else:
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = nn.ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                setattr(stage_model, module_name, new_layers)
+
+                continue
+
+            # Handle simple module attributes (e.g., "linear", "norm")
+            if module_name not in modules_to_keep:
+                continue
+
+            setattr(stage_model, module_name, module_value)
+
+        return stage_model
+
+    num_stages = len(module_names_per_stage)
+    models = []
+
+    for stage_idx in range(num_stages):
+        module_names = module_names_per_stage[stage_idx]
+        model_chunk = _build_stage_from_modules(stage_idx, module_names)
+        logger.info(f"building stage_idx {stage_idx} with modules {module_names}")
+        models.append(model_chunk)
+
+    return models
```
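A minimal usage sketch of `module_split`, assuming a toy model whose attribute layout mimics the llama3 `Transformer` (`tok_embeddings`, a `layers` `nn.ModuleDict`, `norm`, `output`); the toy model is invented here for illustration, only `module_split` comes from this diff. The chunks reference the original modules, so no parameters are copied:

```python
import torch.nn as nn

from torchtitan.distributed.pipeline import module_split


class ToyModel(nn.Module):
    # Stand-in for the llama3 Transformer layout, for illustration only.
    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, 16)
        self.layers = nn.ModuleDict({str(i): nn.Linear(16, 16) for i in range(4)})
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 100)


model = ToyModel()
chunks = module_split(
    model,
    [
        ["tok_embeddings", "layers.0", "layers.1"],
        ["layers.2", "layers.3", "norm", "output"],
    ],
)

# Each chunk exposes only its own sub-modules, but the underlying parameters
# are the same tensors as in `model`, so fragment-wise syncing still touches
# the real weights.
assert chunks[0].layers["0"].weight is model.layers["0"].weight
assert not hasattr(chunks[1], "tok_embeddings")
```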

torchtitan/models/llama3/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -10,10 +10,11 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from torchtitan.protocols.train_spec import FaultTolerantTrainSpec, register_train_spec
 
 from .infra.parallelize import parallelize_llama
 from .infra.pipeline import pipeline_llama
+from .infra.fault_tolerance import fragment_llama
 from .model.args import TransformerModelArgs
 from .model.model import Transformer
 from .model.state_dict_adapter import Llama3StateDictAdapter
@@ -70,12 +71,13 @@
 
 
 register_train_spec(
-    TrainSpec(
+    FaultTolerantTrainSpec(
         name="llama3",
         model_cls=Transformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,
         pipelining_fn=pipeline_llama,
+        fragment_fn=fragment_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
```
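Once registered, downstream code can discover the new hook on the spec. A small hedged sketch, assuming the existing `get_train_spec` lookup in `torchtitan.protocols.train_spec` that the trainer already uses:

```python
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3")
# FaultTolerantTrainSpec subclasses TrainSpec, so existing callers are
# unaffected; only fault-tolerance-aware code looks for fragment_fn.
if getattr(spec, "fragment_fn", None) is not None:
    print("llama3 can be split into DiLoCo model fragments")
```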
torchtitan/models/llama3/infra/fault_tolerance.py

Lines changed: 47 additions & 0 deletions

```diff
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file is used to set up the model for fault tolerance
+
+import torch.nn as nn
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed.pipeline import (
+    generate_module_names_per_stage,
+    module_split,
+)
+from torchtitan.tools.logging import logger
+
+from ..model.args import TransformerModelArgs
+
+
+def fragment_llama(
+    model: nn.Module,
+    job_config: JobConfig,
+    model_config: TransformerModelArgs,
+) -> list[nn.Module]:
+    ft = job_config.fault_tolerance
+
+    assert ft.num_fragments > 0
+
+    module_names_per_stage = ft.module_names_per_model_chunk
+
+    input_weight = 1  # Weight for tok_embeddings
+    output_weight = 1  # Weight for norm + output layers
+
+    if module_names_per_stage == []:
+        if ft.num_fragments == 1:
+            return [model]
+
+        module_names_per_stage = generate_module_names_per_stage(
+            ft.num_fragments, model_config.n_layers, input_weight, output_weight
+        )
+
+    model_fragments = module_split(model, module_names_per_stage)
+    logger.info(f"Created {len(model_fragments)} model fragments")
+
+    return model_fragments
```
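When no explicit chunk spec is given, `fragment_llama` delegates the split to `generate_module_names_per_stage`, the same helper the pipeline-parallel code uses. As a rough illustration of the kind of spec that produces (the real weighting logic lives in that helper; this tiny stand-in is only indicative), splitting an 8-layer model into 2 fragments could look like:

```python
def naive_chunk_spec(num_fragments: int, n_layers: int) -> list[list[str]]:
    # Indicative stand-in, not the torchtitan helper: spread the transformer
    # layers evenly and pin embeddings to the first fragment, norm/output to
    # the last, matching input_weight = output_weight = 1 in spirit.
    chunks: list[list[str]] = [[] for _ in range(num_fragments)]
    for layer in range(n_layers):
        chunks[layer * num_fragments // n_layers].append(f"layers.{layer}")
    chunks[0].insert(0, "tok_embeddings")
    chunks[-1].extend(["norm", "output"])
    return chunks


print(naive_chunk_spec(2, 8))
# [['tok_embeddings', 'layers.0', 'layers.1', 'layers.2', 'layers.3'],
#  ['layers.4', 'layers.5', 'layers.6', 'layers.7', 'norm', 'output']]
```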

torchtitan/protocols/train_spec.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -104,6 +104,16 @@ class TrainSpec:
     state_dict_adapter: type[StateDictAdapter] | None = None
 
 
+FragmentFunction: TypeAlias = Callable[
+    ..., list[nn.Module]
+]
+
+
+@dataclass
+class FaultTolerantTrainSpec(TrainSpec):
+    fragment_fn: FragmentFunction | None = None
+
+
 _train_specs = {}
 
 
```
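For context, a hedged sketch of how another model family could opt in: any callable matching the `FragmentFunction` alias can be attached via `fragment_fn`. The toy splitter below (one fragment per top-level child) is invented for illustration; the registration is shown only in outline since the other `TrainSpec` fields are model-specific:

```python
import torch.nn as nn


def fragment_by_top_level_children(model: nn.Module, *args) -> list[nn.Module]:
    # Toy FragmentFunction: every direct child of the model becomes its own
    # fragment; parameters stay shared with the original model.
    fragments: list[nn.Module] = []
    for name, child in model.named_children():
        frag = nn.Module()
        setattr(frag, name, child)
        fragments.append(frag)
    return fragments or [model]


# register_train_spec(
#     FaultTolerantTrainSpec(
#         name="my_model",
#         ...,                                   # usual TrainSpec fields
#         fragment_fn=fragment_by_top_level_children,
#     )
# )
```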

torchtitan/train.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 import os
 import time
 from datetime import timedelta
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, cast, Generator, Iterable, Optional
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -42,6 +42,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     # swappable training components in TrainSpec
     dataloader: train_spec_module.BaseDataLoader
     model_parts: list[torch.nn.Module]
+    ft_model_parts: list[torch.nn.Module]
     loss_fn: train_spec_module.LossFunction
     optimizers: train_spec_module.OptimizersContainer
     lr_schedulers: train_spec_module.LRSchedulersContainer
@@ -256,6 +257,16 @@ def __init__(self, job_config: JobConfig):
 
         self.model_parts = [model]
 
+        ft = job_config.fault_tolerance
+
+        # Always define ft_model_parts: train() passes it to
+        # maybe_semi_sync_training even when fault tolerance is disabled.
+        self.ft_model_parts = [model]
+        if ft.enable:
+            train_spec = cast(train_spec_module.FaultTolerantTrainSpec, self.train_spec)
+            if train_spec.fragment_fn:
+                self.ft_model_parts = train_spec.fragment_fn(model, job_config, model_args)
+
         self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
 
         # initialize device memory monitor and get peak flops for MFU calculation
@@ -501,7 +512,7 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model_parts=self.ft_model_parts,
                 optimizer=self.optimizers,
             ),
         ):
```