
Commit 4377685

model fragments for diloco

Summary:
- Add a configuration option for users to specify how they want to partition the model.
- If this option is provided, the model needs to implement `FaultTolerantTrainSpec`, which defines the fragmentation function that splits the model based on the configuration.
- Determine the model fragments in the training script and pass them to the FT manager.

Test Plan: Ran Llama 3 8B with 2 fragments and a 1-step fragment sync delay; each fragment gets synced every 20 steps.
[Screenshot of the run: https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58]

1 parent cf30b29 commit 4377685

File tree

9 files changed: +308 −7 lines

torchtitan/components/ft/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Re-export the fault tolerance components.
from torchtitan.components.ft.diloco import fragment_llm
from torchtitan.components.ft.manager import (
    FTManager,
    has_torchft,
    maybe_semi_sync_training,
)


__all__ = [
    "FTManager",
    "has_torchft",
    "maybe_semi_sync_training",
    "fragment_llm",
]

torchtitan/components/ft/diloco.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torchtitan.config.job_config import FaultTolerance as FTConfig
from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part


def module_split(
    model: nn.Module,
    module_fqns_per_model_fragment: list[list[str]],
) -> list[nn.Module]:
    """
    This API creates fragments based on specified module names for each fragment.
    This method updates the model in place.

    Args:
        model: The complete model to be split
        module_fqns_per_model_fragment: List of lists, where each inner list contains
            the module names that should be included in that fragment. Module names
            should be dot-separated paths. Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        List of model fragments

    Example usage:
        module_fqns_per_model_fragment = [
            ["tok_embeddings", "layers.0"],  # fragment 0: embeddings + first layer
            ["layers.1", "layers.2"],        # fragment 1: middle layers
            ["norm", "output"],              # fragment 2: final norm + output
        ]
    """

    def _build_fragment_from_modules(
        fragment_idx: int, module_names: list[str]
    ) -> nn.Module:
        fragment_model = nn.Module()
        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }

                if not layers_to_keep:
                    continue

                # Keep only specified layers
                if isinstance(module_value, nn.ModuleDict):
                    for layer_name in list(module_value.keys()):
                        if layer_name in layers_to_keep:
                            setattr(
                                fragment_model,
                                f"{module_name}.{layer_name}",
                                module_value[layer_name],
                            )
                else:
                    indices_to_keep = {
                        int(idx) for idx in layers_to_keep if idx.isdigit()
                    }
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(fragment_model, module_name, new_layers)

                continue

            # Handle simple module attributes (e.g., "linear", "norm")
            if module_name not in modules_to_keep:
                continue

            setattr(fragment_model, module_name, module_value)

        return fragment_model

    num_fragments = len(module_fqns_per_model_fragment)
    model_fragments = []

    for fragment_idx in range(num_fragments):
        module_names = module_fqns_per_model_fragment[fragment_idx]
        model_fragment = _build_fragment_from_modules(
            fragment_idx,
            module_names,
        )
        print(f"building fragment_idx {fragment_idx} with modules {module_names}")
        model_fragments.append(model_fragment)

    return model_fragments


def fragment_llm(
    model: nn.Module,
    ft_config: FTConfig,
    n_layers: int,
) -> list[nn.Module]:
    assert ft_config.num_fragments > 0

    module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment

    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    if module_fqns_per_model_fragment == []:
        if ft_config.num_fragments == 1:
            return [model]

        module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
            ft_config.num_fragments, n_layers, input_weight, output_weight
        )

    model_fragments = module_split(model, module_fqns_per_model_fragment)
    print(f"Created {len(model_fragments)} model fragments")

    return model_fragments
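To make the splitting behavior concrete for review, here is a minimal sketch using a hypothetical toy module. ToyModel is not part of this change; it just mirrors the tok_embeddings/layers/norm/output layout the docstring assumes:

import torch.nn as nn

from torchtitan.components.ft.diloco import module_split


class ToyModel(nn.Module):
    # Hypothetical stand-in with the LLM-style layout described above.
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(128, 16)
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 128)


fragments = module_split(
    ToyModel(),
    [
        ["tok_embeddings", "layers.0", "layers.1"],  # fragment 0
        ["layers.2", "layers.3", "norm", "output"],  # fragment 1
    ],
)
assert len(fragments) == 2
# The fragments reference (not copy) the original submodules, so DiLoCo can
# sync the two parameter groups on independent schedules.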

torchtitan/components/ft.py renamed to torchtitan/components/ft/manager.py

Lines changed: 12 additions & 4 deletions
@@ -7,10 +7,12 @@
 import importlib
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import ContextManager, Optional, TYPE_CHECKING, Union
+from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
+
+import torch.nn as nn
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.config.job_config import FaultTolerance as FTConfig
@@ -108,8 +110,10 @@ def loss_sync_pg(
 def maybe_semi_sync_training(
     ft_config: FTConfig,
     ft_manager: FTManager,
-    model_parts: list[torch.nn.Module],
+    model: torch.nn.Module,
+    n_layers: int,
     optimizer: torch.optim.Optimizer,
+    fragment_fn: Optional[Callable[..., list[nn.Module]]] = None,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
@@ -122,6 +126,11 @@ def maybe_semi_sync_training(
         ft_manager._manager is not None
     ), "FTManager must be enabled to use semi-sync training."
     if semi_sync_method.lower() == "diloco":
+        if fragment_fn:
+            model_parts = fragment_fn(model, ft_config, n_layers)
+        else:
+            model_parts = [model]
+
         # Create the outer optimizer based on the inner optimizer parameters.
         outer_optimizers = []
@@ -142,10 +151,9 @@ def maybe_semi_sync_training(
             fragment_update_alpha=ft_config.fragment_update_alpha,
         )
     elif semi_sync_method.lower() == "local_sgd":
-        assert len(model_parts) == 1
         return local_sgd.LocalSGD(
             manager=ft_manager._manager,
-            model=model_parts[0],
+            model=model,
             optimizer=optimizer,
             sync_every=ft_config.sync_steps,
         )
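With the new signature, the call site in the training loop would look roughly like the sketch below; ft_config, ft_manager, model, optimizer, n_layers, num_steps, and train_step are placeholders assumed to come from the surrounding training script:

from torchtitan.components.ft import fragment_llm, maybe_semi_sync_training

with maybe_semi_sync_training(
    ft_config,
    ft_manager=ft_manager,
    model=model,
    n_layers=n_layers,
    optimizer=optimizer,
    fragment_fn=fragment_llm,  # leave as None to sync the whole model
):
    for step in range(num_steps):
        train_step(model, optimizer)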
torchtitan/components/ft/protocol/__init__.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Re-export the fault tolerance protocol types.
from torchtitan.components.ft.protocol.model import (
    FaultTolerantModelArgs,
    FaultTolerantTrainSpec,
)


__all__ = [
    "FaultTolerantModelArgs",
    "FaultTolerantTrainSpec",
]
torchtitan/components/ft/protocol/model.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable, TypeAlias

import torch.nn as nn
from torchtitan.protocols.train_spec import BaseModelArgs, TrainSpec


FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]


@dataclass
class FaultTolerantModelArgs(BaseModelArgs):
    n_layers: int = 0


@dataclass
class FaultTolerantTrainSpec(TrainSpec):
    fragment_fn: FragmentFunction | None = None
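Since FragmentFunction is just Callable[..., list[nn.Module]], any fragmenter with a compatible signature can be plugged into fragment_fn. A hypothetical custom fragmenter (not part of this change), assuming the same tok_embeddings/layers/norm/output naming that fragment_llm relies on:

import torch.nn as nn

from torchtitan.components.ft.diloco import module_split
from torchtitan.config.job_config import FaultTolerance as FTConfig


def fragment_by_halves(
    model: nn.Module, ft_config: FTConfig, n_layers: int
) -> list[nn.Module]:
    # Embeddings plus the first half of the layers go in fragment 0,
    # the remaining layers plus norm/output in fragment 1.
    mid = n_layers // 2
    fqns = [
        ["tok_embeddings"] + [f"layers.{i}" for i in range(mid)],
        [f"layers.{i}" for i in range(mid, n_layers)] + ["norm", "output"],
    ]
    return module_split(model, fqns)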

torchtitan/config/job_config.py

Lines changed: 16 additions & 0 deletions
@@ -686,6 +686,22 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """

+    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (fully qualified names) of the modules
+    in each model fragment. Each inner list represents one model fragment and contains
+    the module names that belong to that fragment.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    creates 3 fragments: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    """
+
+    num_fragments: int = 1
+    """
+    Number of fragments to split the model into. This is only used when "semi_sync_method"
+    is "diloco". It automatically splits the model into fragments, provided that the
+    model's train spec implements FaultTolerantTrainSpec.
+    """


 @dataclass
 class Experimental:
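The two options compose as follows. This is an illustrative construction in Python with made-up values; semi_sync_method and sync_steps are pre-existing fields of the FaultTolerance dataclass:

from torchtitan.config.job_config import FaultTolerance

# Explicit partition: a non-empty module_fqns_per_model_fragment wins.
explicit = FaultTolerance(
    semi_sync_method="diloco",
    sync_steps=20,
    module_fqns_per_model_fragment=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "layers.2", "norm", "output"],
    ],
)

# Automatic partition: leave the FQN list empty and set num_fragments, and
# fragment_llm derives a balanced split from the model's n_layers.
automatic = FaultTolerance(semi_sync_method="diloco", sync_steps=20, num_fragments=2)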
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

from torchtitan.components.ft import fragment_llm
from torchtitan.components.ft.protocol import (
    FaultTolerantModelArgs,
    FaultTolerantTrainSpec,
)
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.model import BaseModelArgs
from torchtitan.protocols.train_spec import register_train_spec
from ..llama3 import (
    llama3_configs as base_llama3_configs,
    Llama3StateDictAdapter,
    parallelize_llama,
    pipeline_llama,
    Transformer,
    TransformerModelArgs,
)

__all__ = [
    "parallelize_llama",
    "pipeline_llama",
    "TransformerModelArgs",
    "Transformer",
    "llama3_configs",
]


class FaultTolerantTransformerModelArgs(
    TransformerModelArgs, FaultTolerantModelArgs, BaseModelArgs
):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


llama3_configs = {
    key: cast(FaultTolerantTransformerModelArgs, value)
    for key, value in base_llama3_configs.items()
}


register_train_spec(
    FaultTolerantTrainSpec(
        name="llama3",
        model_cls=Transformer,
        model_args=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        fragment_fn=fragment_llm,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
        state_dict_adapter=Llama3StateDictAdapter,
    )
)
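For reviewers, the intended plumbing in the training script is roughly the sketch below; train_spec is a placeholder for whatever the spec registry returns for the model:

from torchtitan.components.ft.protocol import FaultTolerantTrainSpec

# Only fault tolerant train specs carry a fragmentation hook.
fragment_fn = (
    train_spec.fragment_fn
    if isinstance(train_spec, FaultTolerantTrainSpec)
    else None
)
# fragment_fn and the model's n_layers are then forwarded to
# maybe_semi_sync_training, which fragments the model only on the DiLoCo path.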

torchtitan/protocols/train_spec.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@

 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeAlias
+from typing import Mapping, TypeAlias

 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -43,7 +43,7 @@
 class TrainSpec:
     name: str
     model_cls: type[ModelProtocol]
-    model_args: dict[str, BaseModelArgs]
+    model_args: Mapping[str, BaseModelArgs]
     parallelize_fn: ParallelizeFunction
     pipelining_fn: PipeliningFunction | None
     build_optimizers_fn: OptimizersBuilder
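The widening from dict to Mapping is what lets the fault tolerant spec above pass a dict of FaultTolerantTransformerModelArgs where BaseModelArgs values are expected: Mapping is read-only and therefore covariant in its value type, while dict is invariant. A self-contained sketch with toy classes (not torchtitan's):

from typing import Mapping


class Base: ...


class Sub(Base): ...


def register(args: Mapping[str, Base]) -> None: ...


configs: dict[str, Sub] = {"8B": Sub()}
register(configs)  # OK: Mapping[str, Base] accepts dict[str, Sub]
# A dict[str, Base] parameter would reject configs under a type checker,
# because dict is mutable and hence invariant in its value type.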
