Commit ec935f5

model fragments for diloco
Summary:
- Add a configuration option for users to specify how they want to partition the model.
- If this is provided, the model's train spec needs to be a `FaultTolerantTrainSpec` that defines the fragmentation function used to split the model based on the configuration.
- Determine the model fragments in the training script and pass them to the FT manager.

Test Plan: Ran Llama 3 8B with 2 fragments and a 1-step delay; each fragment gets synced every 20 steps.
<img width="944" height="545" alt="image" src="https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58" />

9 files changed: +276 −7 lines changed
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Re-export fault tolerance components at the package level.
from torchtitan.components.ft.diloco import fragment_llm
from torchtitan.components.ft.manager import (
    FTManager,
    has_torchft,
    maybe_semi_sync_training,
)


__all__ = [
    "FTManager",
    "has_torchft",
    "maybe_semi_sync_training",
    "fragment_llm",
]

torchtitan/components/ft/diloco.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torchtitan.config.job_config import FaultTolerance as FTConfig
from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part


def module_split(
    model: nn.Module,
    module_fqns_per_model_fragment: list[list[str]],
) -> list[nn.Module]:
    """
    This API creates fragments based on specified module names for each fragment.
    The returned fragments reference the original model's submodules, so parameters
    are shared rather than copied.

    Args:
        model: The complete model to be split
        module_fqns_per_model_fragment: List of lists, where each inner list contains the module names
            that should be included in that fragment. Module names should be dot-separated paths.
            Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        List of model fragments

    Example usage:
        module_fqns_per_model_fragment = [
            ["tok_embeddings", "layers.0"],  # fragment 0: embeddings + first layer
            ["layers.1", "layers.2"],        # fragment 1: middle layers
            ["norm", "output"]               # fragment 2: final norm + output
        ]
    """

    def _build_fragment_from_modules(
        fragment_idx: int, module_names: list[str]
    ) -> nn.Module:
        fragment_model = nn.Module()
        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }

                if not layers_to_keep:
                    continue

                # Keep only specified layers
                if isinstance(module_value, nn.ModuleDict):
                    for layer_name in list(module_value.keys()):
                        if layer_name in layers_to_keep:
                            setattr(
                                fragment_model,
                                f"{module_name}.{layer_name}",
                                module_value[layer_name],
                            )
                else:
                    indices_to_keep = {
                        int(idx) for idx in layers_to_keep if idx.isdigit()
                    }
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(fragment_model, module_name, new_layers)

                continue

            # Handle simple module attributes (e.g., "linear", "norm")
            if module_name not in modules_to_keep:
                continue

            setattr(fragment_model, module_name, module_value)

        return fragment_model

    num_fragments = len(module_fqns_per_model_fragment)
    model_fragments = []

    for fragment_idx in range(num_fragments):
        module_names = module_fqns_per_model_fragment[fragment_idx]
        model_fragment = _build_fragment_from_modules(
            fragment_idx,
            module_names,
        )
        print(f"building fragment_idx {fragment_idx} with modules {module_names}")
        model_fragments.append(model_fragment)

    return model_fragments


def fragment_llm(
    model: nn.Module,
    ft_config: FTConfig,
    n_layers: int,
) -> list[nn.Module]:
    assert ft_config.num_fragments > 0

    module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment

    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    if module_fqns_per_model_fragment == []:
        if ft_config.num_fragments == 1:
            return [model]

        module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
            ft_config.num_fragments, n_layers, input_weight, output_weight
        )

    model_fragments = module_split(model, module_fqns_per_model_fragment)
    print(f"Created {len(model_fragments)} model fragments")

    return model_fragments

torchtitan/components/ft.py renamed to torchtitan/components/ft/manager.py

Lines changed: 12 additions & 4 deletions
@@ -7,10 +7,12 @@
 import importlib
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import ContextManager, Optional, TYPE_CHECKING, Union
+from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
+
+import torch.nn as nn
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.config.job_config import FaultTolerance as FTConfig
@@ -108,8 +110,10 @@ def loss_sync_pg(
 def maybe_semi_sync_training(
     ft_config: FTConfig,
     ft_manager: FTManager,
-    model_parts: list[torch.nn.Module],
+    model: torch.nn.Module,
+    n_layers: int,
     optimizer: torch.optim.Optimizer,
+    fragment_fn: Optional[Callable[..., list[nn.Module]]] = None,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
@@ -122,6 +126,11 @@ def maybe_semi_sync_training(
             ft_manager._manager is not None
         ), "FTManager must be enabled to use semi-sync training."
         if semi_sync_method.lower() == "diloco":
+            if fragment_fn:
+                model_parts = fragment_fn(model, ft_config, n_layers)
+            else:
+                model_parts = [model]
+
             # Create the outer optimizer based on the inner optimizer parameters.
             outer_optimizers = []
             for model in model_parts:
@@ -142,10 +151,9 @@
                 fragment_update_alpha=ft_config.fragment_update_alpha,
             )
         elif semi_sync_method.lower() == "local_sgd":
-            assert len(model_parts) == 1
            return local_sgd.LocalSGD(
                manager=ft_manager._manager,
-                model=model_parts[0],
+                model=model,
                optimizer=optimizer,
                sync_every=ft_config.sync_steps,
            )
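With the new signature, the trainer passes in the whole model plus an optional fragment function, and fragments are only materialized on the DiLoCo path. A rough sketch of a call site (mirroring the `train.py` change later in this commit; `ft_config`, `ft_manager`, `model`, `model_args`, and `optimizers` stand in for the trainer's real objects):

from torchtitan.components.ft import fragment_llm, maybe_semi_sync_training

with maybe_semi_sync_training(
    ft_config,
    ft_manager=ft_manager,
    model=model,
    n_layers=model_args.n_layers,
    optimizer=optimizers,
    # fragment_fn=None would treat the whole model as a single fragment under DiLoCo.
    fragment_fn=fragment_llm,
):
    ...  # training loop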
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Re-export the fault tolerant train spec protocol.
from torchtitan.components.ft.protocol.model import FaultTolerantTrainSpec


__all__ = [
    "FaultTolerantTrainSpec",
]
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable, TypeAlias

import torch.nn as nn
from torchtitan.protocols.train_spec import TrainSpec


FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]


@dataclass
class FaultTolerantTrainSpec(TrainSpec):
    fragment_fn: FragmentFunction | None = None
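Since `FragmentFunction` is just `Callable[..., list[nn.Module]]`, a model owner can supply any splitter that matches the `(model, ft_config, n_layers)` call made by `maybe_semi_sync_training`. A hypothetical example (the alternating split below is illustrative only, not part of this commit):

import torch.nn as nn

from torchtitan.components.ft.diloco import module_split
from torchtitan.config.job_config import FaultTolerance as FTConfig


def fragment_alternating_layers(
    model: nn.Module, ft_config: FTConfig, n_layers: int
) -> list[nn.Module]:
    # Hypothetical splitter: even layers (plus embeddings) in one fragment,
    # odd layers (plus norm and output) in the other.
    even = ["tok_embeddings"] + [f"layers.{i}" for i in range(0, n_layers, 2)]
    odd = [f"layers.{i}" for i in range(1, n_layers, 2)] + ["norm", "output"]
    return module_split(model, [even, odd])


# Wired into a spec via the new field, e.g.
# FaultTolerantTrainSpec(..., fragment_fn=fragment_alternating_layers)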

torchtitan/config/job_config.py

Lines changed: 16 additions & 0 deletions
@@ -686,6 +686,22 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """

+    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
+    Each inner list represents one model fragment and contains the module names that belong to that fragment.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 fragments: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    """
+
+    num_fragments: int = 1
+    """
+    Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
+    This is used to automatically split the model into fragments, provided that the model's
+    train spec is a FaultTolerantTrainSpec with a fragment_fn.
+    """
+

 @dataclass
 class Experimental:
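The two options cover an explicit and an automatic path: a non-empty `module_fqns_per_model_fragment` is used as-is, otherwise `fragment_llm` derives the per-fragment FQN lists from `num_fragments` via `generate_llm_fqn_per_model_part`. A minimal sketch of both forms, assuming the remaining `FaultTolerance` fields keep their defaults:

from torchtitan.config.job_config import FaultTolerance

# Explicit: pin three fragments for a hypothetical 5-layer model,
# matching the docstring example above.
ft_config = FaultTolerance(
    semi_sync_method="diloco",
    module_fqns_per_model_fragment=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2"],
        ["layers.3", "layers.4"],
    ],
)

# Automatic: leave the FQN lists empty and only request a fragment count.
ft_config_auto = FaultTolerance(semi_sync_method="diloco", num_fragments=2)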
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft import fragment_llm
from torchtitan.components.ft.protocol import FaultTolerantTrainSpec
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.train_spec import register_train_spec
from ..llama3 import (
    llama3_configs,
    Llama3StateDictAdapter,
    parallelize_llama,
    pipeline_llama,
    Transformer,
    TransformerModelArgs,
)

__all__ = [
    "parallelize_llama",
    "pipeline_llama",
    "TransformerModelArgs",
    "Transformer",
    "llama3_configs",
]


register_train_spec(
    FaultTolerantTrainSpec(
        name="llama3_ft",
        model_cls=Transformer,
        model_args=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        fragment_fn=fragment_llm,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
        state_dict_adapter=Llama3StateDictAdapter,
    )
)

torchtitan/protocols/train_spec.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@

 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeAlias
+from typing import Mapping, TypeAlias

 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -43,7 +43,7 @@
 class TrainSpec:
     name: str
     model_cls: type[ModelProtocol]
-    model_args: dict[str, BaseModelArgs]
+    model_args: Mapping[str, BaseModelArgs]
     parallelize_fn: ParallelizeFunction
     pipelining_fn: PipeliningFunction | None
     build_optimizers_fn: OptimizersBuilder
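The `dict` → `Mapping` change matters because `Mapping` is covariant in its value type: a concrete `dict[str, TransformerModelArgs]` such as `llama3_configs` now satisfies the `model_args` annotation, while a `dict[str, BaseModelArgs]` annotation would reject it under a type checker. A small self-contained illustration (the demo classes are hypothetical stand-ins):

from typing import Mapping


class BaseArgsDemo: ...  # stand-in for BaseModelArgs


class TransformerArgsDemo(BaseArgsDemo): ...  # stand-in for TransformerModelArgs


def takes_mapping(args: Mapping[str, BaseArgsDemo]) -> None: ...


def takes_dict(args: dict[str, BaseArgsDemo]) -> None: ...


configs: dict[str, TransformerArgsDemo] = {"8B": TransformerArgsDemo()}

takes_mapping(configs)  # accepted: Mapping's value type is covariant
takes_dict(configs)     # flagged by mypy: dict's value type is invariant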

torchtitan/train.py

Lines changed: 13 additions & 1 deletion
@@ -48,6 +48,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     lr_schedulers: train_spec_module.LRSchedulersContainer
     validator: train_spec_module.BaseValidator
     metrics_processor: train_spec_module.MetricsProcessor
+    model_args: train_spec_module.BaseModelArgs

     # non-swappable training components
     checkpointer: CheckpointManager
@@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig):
         model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
         model_args.update_from_config(job_config)
+        self.model_args = model_args

         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
@@ -545,8 +547,18 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model=self.model_parts[0],
+                n_layers=(
+                    self.model_args.n_layers
+                    if hasattr(self.model_args, "n_layers")
+                    else 0
+                ),
                 optimizer=self.optimizers,
+                fragment_fn=(
+                    self.train_spec.fragment_fn
+                    if hasattr(self.train_spec, "fragment_fn")
+                    else None
+                ),
             ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
