Commit 3065a2a

model fragments for diloco (#1446)
Summary:
- Add a configuration option for users to specify how they want to partition the model.
- If this option is provided, the model needs to implement `FaultTolerantTrainSpec`, which defines the fragmentation function that splits the model based on the configuration.
- Determine the model fragments in the training script and pass them to the FT manager.

Test Plan: Ran Llama 3 8B with 2 fragments and a 1-step delay, with each fragment synced every 20 steps.

<img width="944" height="545" alt="image" src="https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58" />

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1446).
* #1516
* __->__ #1446
1 parent a204e31 commit 3065a2a

File tree

9 files changed: +272 −7 lines changed
Lines changed: 18 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.manager import (
    FTManager,
    has_torchft,
    maybe_semi_sync_training,
)


__all__ = [
    "FTManager",
    "has_torchft",
    "maybe_semi_sync_training",
]
```
Lines changed: 13 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.diloco.protocol import FaultTolerantTrainSpec
from torchtitan.components.ft.diloco.utils import fragment_llm

__all__ = [
    "FaultTolerantTrainSpec",
    "fragment_llm",
]
```
Lines changed: 19 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable, TypeAlias

import torch.nn as nn
from torchtitan.protocols.train_spec import TrainSpec


FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]


@dataclass
class FaultTolerantTrainSpec(TrainSpec):
    fragment_fn: FragmentFunction | None = None
```
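For illustration, a minimal sketch of a callable that satisfies the `FragmentFunction` alias and could be plugged into a `FaultTolerantTrainSpec` — the function name and its trivial behavior are hypothetical, not part of this commit:

```python
import torch.nn as nn

from torchtitan.config.job_config import FaultTolerance as FTConfig


def identity_fragment_fn(
    model: nn.Module, ft_config: FTConfig, n_layers: int
) -> list[nn.Module]:
    # Hypothetical fragment function: treat the whole model as a single
    # fragment. It matches the call shape fragment_fn(model, ft_config,
    # n_layers) used by maybe_semi_sync_training below, and therefore
    # conforms to the FragmentFunction alias.
    return [model]
```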
Lines changed: 130 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torchtitan.config.job_config import FaultTolerance as FTConfig
from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part


def module_split(
    model: nn.Module,
    module_fqns_per_model_fragment: list[list[str]],
) -> list[nn.Module]:
    """
    This API creates fragments based on specified module names for each fragment.
    This method updates the model in place.

    Args:
        model: The complete model to be split
        module_fqns_per_model_fragment: List of lists, where each inner list contains
            the module names that should be included in that fragment. Module names
            should be dot-separated paths. Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        List of model fragments

    Example usage:
        module_fqns_per_model_fragment = [
            ["tok_embeddings", "layers.0"],  # fragment 0: embeddings + first layer
            ["layers.1", "layers.2"],        # fragment 1: middle layers
            ["norm", "output"]               # fragment 2: final norm + output
        ]
    """

    def _build_fragment_from_modules(
        fragment_idx: int, module_names: list[str]
    ) -> nn.Module:
        fragment_model = nn.Module()
        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }

                if not layers_to_keep:
                    continue

                # Keep only specified layers
                if isinstance(module_value, nn.ModuleDict):
                    for layer_name in list(module_value.keys()):
                        if layer_name in layers_to_keep:
                            setattr(
                                fragment_model,
                                f"{module_name}.{layer_name}",
                                module_value[layer_name],
                            )
                else:
                    indices_to_keep = {
                        int(idx) for idx in layers_to_keep if idx.isdigit()
                    }
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(fragment_model, module_name, new_layers)

                continue

            # Handle simple module attributes (e.g., "linear", "norm")
            if module_name not in modules_to_keep:
                continue

            setattr(fragment_model, module_name, module_value)

        return fragment_model

    num_fragments = len(module_fqns_per_model_fragment)
    model_fragments = []

    for fragment_idx in range(num_fragments):
        module_names = module_fqns_per_model_fragment[fragment_idx]
        model_fragment = _build_fragment_from_modules(
            fragment_idx,
            module_names,
        )
        print(f"building fragment_idx {fragment_idx} with modules {module_names}")
        model_fragments.append(model_fragment)

    return model_fragments


def fragment_llm(
    model: nn.Module,
    ft_config: FTConfig,
    n_layers: int,
) -> list[nn.Module]:
    assert ft_config.num_fragments > 0

    module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment

    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    if module_fqns_per_model_fragment == []:
        if ft_config.num_fragments == 1:
            return [model]

        module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
            ft_config.num_fragments, n_layers, input_weight, output_weight
        )

    model_fragments = module_split(model, module_fqns_per_model_fragment)
    print(f"Created {len(model_fragments)} model fragments")

    return model_fragments
```
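As a usage sketch, here is `module_split` applied to a toy LLM-shaped module (the `ToyLLM` class is illustrative only, and `module_split` is assumed importable from the new utils module above); note that fragments alias the original submodules rather than copying them:

```python
import torch.nn as nn


class ToyLLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(128, 16)
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 128)


model = ToyLLM()
fragments = module_split(
    model,
    [
        ["tok_embeddings", "layers.0", "layers.1"],  # fragment 0
        ["layers.2", "layers.3", "norm", "output"],  # fragment 1
    ],
)
assert len(fragments) == 2
# The fragments hold references to the original modules, so optimizing a
# fragment's parameters updates the full model in place.
assert fragments[0].layers[0] is model.layers[0]
```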

torchtitan/components/ft.py renamed to torchtitan/components/ft/manager.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -7,10 +7,12 @@
 import importlib
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import ContextManager, Optional, TYPE_CHECKING, Union
+from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
+
+import torch.nn as nn
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.config.job_config import FaultTolerance as FTConfig
@@ -108,8 +110,10 @@ def loss_sync_pg(
 def maybe_semi_sync_training(
     ft_config: FTConfig,
     ft_manager: FTManager,
-    model_parts: list[torch.nn.Module],
+    model: torch.nn.Module,
+    n_layers: int,
     optimizer: torch.optim.Optimizer,
+    fragment_fn: Optional[Callable[..., list[nn.Module]]] = None,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
@@ -122,6 +126,11 @@ def maybe_semi_sync_training(
         ft_manager._manager is not None
     ), "FTManager must be enabled to use semi-sync training."
     if semi_sync_method.lower() == "diloco":
+        if fragment_fn:
+            model_parts = fragment_fn(model, ft_config, n_layers)
+        else:
+            model_parts = [model]
+
         # Create the outer optimizer based on the inner optimizer parameters.
         outer_optimizers = []
         for model in model_parts:
@@ -142,10 +151,9 @@ def maybe_semi_sync_training(
             fragment_update_alpha=ft_config.fragment_update_alpha,
         )
     elif semi_sync_method.lower() == "local_sgd":
-        assert len(model_parts) == 1
         return local_sgd.LocalSGD(
             manager=ft_manager._manager,
-            model=model_parts[0],
+            model=model,
             optimizer=optimizer,
             sync_every=ft_config.sync_steps,
         )
```

torchtitan/config/job_config.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -696,6 +696,22 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """

+    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
+    Each inner list represents one model fragment and contains the module names that belong to that fragment.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    """
+
+    num_fragments: int = 1
+    """
+    Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
+    This is used to automatically split the model into fragments, provided that the model
+    implements FaultTolerantTrainSpec.
+    """
+

 @dataclass
 class Experimental:
```
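As a sketch of how the new knobs combine, constructing the dataclass directly (all other `FaultTolerance` fields are assumed to keep their defaults):

```python
from torchtitan.config.job_config import FaultTolerance

# Automatic splitting: with no explicit FQNs, fragment_llm derives two
# fragments from the model's layer count, synced every 20 steps.
ft_config = FaultTolerance(
    semi_sync_method="diloco",
    num_fragments=2,
    sync_steps=20,
)

# Explicit splitting: pinning fragments to module FQNs bypasses the
# automatic generate_llm_fqn_per_model_part path in fragment_llm.
ft_config = FaultTolerance(
    semi_sync_method="diloco",
    module_fqns_per_model_fragment=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2"],
        ["layers.3", "layers.4"],
    ],
)
```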
Lines changed: 49 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.diloco import FaultTolerantTrainSpec, fragment_llm
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.train_spec import register_train_spec
from ..llama3 import (
    llama3_configs,
    Llama3StateDictAdapter,
    parallelize_llama,
    pipeline_llama,
    Transformer,
    TransformerModelArgs,
)

__all__ = [
    "parallelize_llama",
    "pipeline_llama",
    "TransformerModelArgs",
    "Transformer",
    "llama3_configs",
]


register_train_spec(
    FaultTolerantTrainSpec(
        name="llama3_ft",
        model_cls=Transformer,
        model_args=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        fragment_fn=fragment_llm,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
        state_dict_adapter=Llama3StateDictAdapter,
    )
)
```
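A sketch of how the registered spec is then resolved and its optional `fragment_fn` discovered — this assumes a `get_train_spec` lookup counterpart to `register_train_spec`; the hasattr guard mirrors the one in train.py below:

```python
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_ft")
# Plain TrainSpec instances have no fragment_fn, hence the guard.
fragment_fn = spec.fragment_fn if hasattr(spec, "fragment_fn") else None
assert fragment_fn is fragment_llm
```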

torchtitan/protocols/train_spec.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -6,7 +6,7 @@

 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeAlias
+from typing import Mapping, TypeAlias

 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -43,7 +43,7 @@
 class TrainSpec:
     name: str
     model_cls: type[ModelProtocol]
-    model_args: dict[str, BaseModelArgs]
+    model_args: Mapping[str, BaseModelArgs]
     parallelize_fn: ParallelizeFunction
     pipelining_fn: PipeliningFunction | None
     build_optimizers_fn: OptimizersBuilder
```

torchtitan/train.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -48,6 +48,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     lr_schedulers: train_spec_module.LRSchedulersContainer
     validator: train_spec_module.BaseValidator
     metrics_processor: train_spec_module.MetricsProcessor
+    model_args: train_spec_module.BaseModelArgs

     # non-swappable training components
     checkpointer: CheckpointManager
@@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig):
         model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
         model_args.update_from_config(job_config)
+        self.model_args = model_args

         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
@@ -553,8 +555,18 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model=self.model_parts[0],
+                n_layers=(
+                    self.model_args.n_layers
+                    if hasattr(self.model_args, "n_layers")
+                    else 0
+                ),
                 optimizer=self.optimizers,
+                fragment_fn=(
+                    self.train_spec.fragment_fn
+                    if hasattr(self.train_spec, "fragment_fn")
+                    else None
+                ),
             ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
```
