
Commit ef69776 (1 parent: f423b62)

model fragments for diloco

Summary:
- Add a configuration option that lets users specify how to partition the model.
- When this option is provided, the model must implement `FaultTolerantTrainSpec`, which defines the fragmentation function that splits the model according to the configuration.
- Determine the model fragments in the training script and pass them to the ft manager.

Test Plan: Ran Llama 3 8B with 2 fragments and a 1-step delay; each fragment is synced every 20 steps.

[Screenshot: training run]

File tree (7 files changed: +184 −8 lines changed)

torchtitan/components/ft.py
torchtitan/config/job_config.py
torchtitan/distributed/pipeline.py
torchtitan/models/llama3/__init__.py
torchtitan/models/llama3/infra/fault_tolerance.py
torchtitan/protocols/train_spec.py
torchtitan/train.py

torchtitan/components/ft.py

Lines changed: 17 additions & 4 deletions
@@ -7,13 +7,18 @@
 import importlib
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import ContextManager, Optional, TYPE_CHECKING, Union
+from typing import cast, ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.config.job_config import FaultTolerance as FTConfig
+from torchtitan.protocols.train_spec import (
+    BaseModelArgs,
+    FaultTolerantTrainSpec,
+    TrainSpec,
+)

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft

@@ -108,8 +113,10 @@ def loss_sync_pg(
 def maybe_semi_sync_training(
     ft_config: FTConfig,
     ft_manager: FTManager,
-    model_parts: list[torch.nn.Module],
+    model: torch.nn.Module,
+    model_args: BaseModelArgs,
     optimizer: torch.optim.Optimizer,
+    train_spec: TrainSpec,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method

@@ -122,6 +129,12 @@
         ft_manager._manager is not None
     ), "FTManager must be enabled to use semi-sync training."
     if semi_sync_method.lower() == "diloco":
+        train_spec = cast(FaultTolerantTrainSpec, train_spec)
+        if train_spec.fragment_fn:
+            model_parts = train_spec.fragment_fn(model, ft_config, model_args)
+        else:
+            model_parts = [model]
+
         # Create the outer optimizer based on the inner optimizer parameters.
         outer_optimizers = []
         for model in model_parts:

@@ -142,10 +155,10 @@
             fragment_update_alpha=ft_config.fragment_update_alpha,
         )
     elif semi_sync_method.lower() == "local_sgd":
-        assert len(model_parts) == 1
+        assert len(model) == 1
         return local_sgd.LocalSGD(
             manager=ft_manager._manager,
-            model=model_parts[0],
+            model=model,
             optimizer=optimizer,
             sync_every=ft_config.sync_steps,
         )
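
For orientation, a hedged sketch (not part of the commit) of the call-site shape the new signature implies, mirroring the train.py change further down; ft_manager, model_parts, model_args, optimizers, train_spec, and data_iterator are assumed to be in scope:

with maybe_semi_sync_training(
    ft_config,
    ft_manager=ft_manager,
    model=model_parts[0],  # whole model; fragmented inside when fragment_fn is set
    model_args=model_args,
    optimizer=optimizers,
    train_spec=train_spec,
):
    for batch in data_iterator:
        ...  # ordinary training steps; DiLoCo/LocalSGD runs its periodic syncs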

torchtitan/config/job_config.py

Lines changed: 12 additions & 0 deletions
@@ -686,6 +686,18 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """

+    module_names_per_model_chunk: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
+    num_fragments: int = 1


 @dataclass
 class Experimental:
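
For illustration only (not from the commit), a sketch of setting the new fields programmatically; all other FaultTolerance fields keep their defaults, and semi_sync_method is an existing field referenced elsewhere in this commit:

from torchtitan.config.job_config import FaultTolerance as FTConfig

# Explicit partition: three chunks, spelled out by module FQN.
ft_explicit = FTConfig(
    semi_sync_method="diloco",
    module_names_per_model_chunk=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2"],
        ["layers.3", "layers.4"],
    ],
)

# Derived partition: leave the chunk list empty and let the model's
# fragment function split the model into num_fragments pieces.
ft_derived = FTConfig(semi_sync_method="diloco", num_fragments=2)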

torchtitan/distributed/pipeline.py

Lines changed: 95 additions & 1 deletion
@@ -30,6 +30,7 @@
     "stage_ids_this_rank",
     "generate_llm_fqn_per_model_part",
     "pipeline_module_split",
+    "module_split",
 ]

@@ -118,7 +119,7 @@ def stage_ids_this_rank(
         stages_per_rank == 2
     ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
     stage_v_pairs = list(
-        zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
+        zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1), strict=True)
     )
     return stage_v_pairs[pp_rank]

@@ -352,3 +353,96 @@ def _build_stage_from_modules(
         models.append(model_chunk)

     return stages, models
+
+
+def module_split(
+    model: nn.Module,
+    module_names_per_stage: list[list[str]],
+) -> list[nn.Module]:
+    """
+    This API creates pipeline stages based on specified module names for each stage.
+    This method updates the model in place.
+
+    Args:
+        model: The complete model to be split
+        module_names_per_stage: List of lists, where each inner list contains the module names
+            that should be included in that stage. Module names should be dot-separated paths.
+            Examples:
+            - "tok_embeddings" for token embeddings
+            - "layers.0", "layers.1" for specific transformer layers
+            - "norm" for the final normalization layer
+            - "output" for the output projection layer
+
+    Returns:
+        List of model chunks
+
+    Example usage:
+        module_names_per_stage = [
+            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
+            ["layers.1", "layers.2"],        # Stage 1: middle layers
+            ["norm", "output"],              # Stage 2: final norm + output
+        ]
+    """
+
+    def _build_stage_from_modules(stage_idx: int, module_names: list[str]) -> nn.Module:
+        stage_model = nn.Module()
+        # Create a set of modules to keep for faster lookup
+        modules_to_keep = set(module_names)
+        print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
+        for module_name, module_value in model.named_children():
+            # Handle layer-like structures (e.g., "layers.0", "layers.1")
+            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+                layers_to_keep = {
+                    name.split(".", 1)[1]
+                    for name in modules_to_keep
+                    if name.startswith(f"{module_name}.")
+                }
+
+                if not layers_to_keep:
+                    continue
+
+                # Keep only specified layers
+                if isinstance(module_value, nn.ModuleDict):
+                    for layer_name in list(module_value.keys()):
+                        if layer_name in layers_to_keep:
+                            setattr(
+                                stage_model,
+                                f"{module_name}.{layer_name}",
+                                module_value[layer_name],
+                            )
+                else:
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = nn.ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                    setattr(stage_model, module_name, new_layers)
+
+                continue
+
+            # Handle simple module attributes (e.g., "linear", "norm")
+            if module_name not in modules_to_keep:
+                continue
+
+            setattr(stage_model, module_name, module_value)
+
+        return stage_model
+
+    num_stages = len(module_names_per_stage)
+    models = []
+
+    for stage_idx in range(num_stages):
+        module_names = module_names_per_stage[stage_idx]
+        model_chunk = _build_stage_from_modules(
+            stage_idx,
+            module_names,
+        )
+        logger.info(f"building stage_idx {stage_idx} " f"with modules {module_names}")
+        models.append(model_chunk)
+
+    return models
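
With the commit applied, module_split can be exercised standalone. A minimal sketch (not from the commit) using a toy stand-in model with llama-style attribute names; note that each fragment shares its parameters with the original model, and kept ModuleList entries are re-indexed from zero inside each fragment:

import torch.nn as nn

from torchtitan.distributed.pipeline import module_split

class ToyTransformer(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 8, vocab: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)

model = ToyTransformer()
fragments = module_split(
    model,
    [
        ["tok_embeddings", "layers.0", "layers.1"],  # fragment 0
        ["layers.2", "layers.3", "norm", "output"],  # fragment 1
    ],
)
for idx, frag in enumerate(fragments):
    print(idx, [name for name, _ in frag.named_children()])
# Expected: fragment 0 holds tok_embeddings + layers 0-1;
# fragment 1 holds layers 2-3 + norm + output.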

torchtitan/models/llama3/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,8 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from torchtitan.protocols.train_spec import FaultTolerantTrainSpec, register_train_spec
+from .infra.fault_tolerance import fragment_llama

 from .infra.parallelize import parallelize_llama
 from .infra.pipeline import pipeline_llama

@@ -71,12 +72,13 @@


 register_train_spec(
-    TrainSpec(
+    FaultTolerantTrainSpec(
         name="llama3",
         model_cls=Transformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,
         pipelining_fn=pipeline_llama,
+        fragment_fn=fragment_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
torchtitan/models/llama3/infra/fault_tolerance.py (new file, inferred from the `from .infra.fault_tolerance import fragment_llama` import above)

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file is used to setup the model for fault tolerance
+
+import torch.nn as nn
+
+from torchtitan.config.job_config import FaultTolerance as FTConfig
+from torchtitan.distributed.pipeline import (
+    generate_llm_fqn_per_model_part,
+    module_split,
+)
+
+from ..model.args import TransformerModelArgs
+
+
+def fragment_llama(
+    model: nn.Module,
+    ft_config: FTConfig,
+    model_config: TransformerModelArgs,
+) -> list[nn.Module]:
+    assert ft_config.num_fragments > 0
+
+    module_names_per_stage = ft_config.module_names_per_model_chunk
+
+    input_weight = 1  # Weight for tok_embeddings
+    output_weight = 1  # Weight for norm + output layers
+
+    if module_names_per_stage == []:
+        if ft_config.num_fragments == 1:
+            return [model]
+
+        module_names_per_stage = generate_llm_fqn_per_model_part(
+            ft_config.num_fragments, model_config.n_layers, input_weight, output_weight
+        )
+
+    model_fragments = module_split(model, module_names_per_stage)
+    print(f"Created {len(model_fragments)} model fragments")
+
+    return model_fragments
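
A hedged usage sketch (not part of the commit): calling fragment_llama with an explicit chunk spec, which takes the module_names_per_model_chunk path and never reaches generate_llm_fqn_per_model_part. The import path follows the file location above; `model` and `model_args` (a llama3 Transformer and its TransformerModelArgs) are assumed to be in scope:

from torchtitan.config.job_config import FaultTolerance as FTConfig
from torchtitan.models.llama3.infra.fault_tolerance import fragment_llama

ft_config = FTConfig(
    num_fragments=2,  # only asserted > 0 here; the explicit chunk list wins
    module_names_per_model_chunk=[
        ["tok_embeddings", "layers.0", "layers.1"],
        ["layers.2", "layers.3", "norm", "output"],
    ],
)
fragments = fragment_llama(model, ft_config, model_args)
assert len(fragments) == 2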

torchtitan/protocols/train_spec.py

Lines changed: 8 additions & 0 deletions
@@ -56,6 +56,14 @@ class TrainSpec:
     state_dict_adapter: type[StateDictAdapter] | None = None


+FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]
+
+
+@dataclass
+class FaultTolerantTrainSpec(TrainSpec):
+    fragment_fn: FragmentFunction | None = None
+
+
 _train_specs = {}
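
For other models, a hedged sketch (not part of the commit) of opting in: supply a fragment function matching the FragmentFunction alias and register a FaultTolerantTrainSpec. `my_model` and fragment_my_model are placeholders, and the remaining TrainSpec fields are elided:

from torchtitan.distributed.pipeline import module_split
from torchtitan.protocols.train_spec import FaultTolerantTrainSpec, register_train_spec

def fragment_my_model(model, ft_config, model_args):
    # Fall back to a single fragment unless the user configured explicit chunks.
    if not ft_config.module_names_per_model_chunk:
        return [model]
    return module_split(model, ft_config.module_names_per_model_chunk)

register_train_spec(
    FaultTolerantTrainSpec(
        name="my_model",  # hypothetical spec name
        fragment_fn=fragment_my_model,
        # ... model_cls, model_args, parallelize_fn, pipelining_fn, build_*_fn as usual
    )
)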

torchtitan/train.py

Lines changed: 5 additions & 1 deletion
@@ -38,6 +38,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     job_config: JobConfig
     parallel_dims: ParallelDims
     train_spec: train_spec_module.TrainSpec
+    model_args: train_spec_module.BaseModelArgs

     # swappable training components in TrainSpec
     tokenizer: train_spec_module.BaseTokenizer | None

@@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig):
         model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
         model_args.update_from_config(job_config)
+        self.model_args = model_args

         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"

@@ -529,8 +531,10 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model=self.model_parts[0],
+                model_args=self.model_args,
                 optimizer=self.optimizers,
+                train_spec=self.train_spec,
             ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
