18 changes: 18 additions & 0 deletions torchtitan/components/ft/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.manager import (
FTManager,
has_torchft,
maybe_semi_sync_training,
)


__all__ = [
"FTManager",
"has_torchft",
"maybe_semi_sync_training",
]
13 changes: 13 additions & 0 deletions torchtitan/components/ft/diloco/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.diloco.protocol import FaultTolerantTrainSpec
from torchtitan.components.ft.diloco.utils import fragment_llm

__all__ = [
"FaultTolerantTrainSpec",
"fragment_llm",
]
19 changes: 19 additions & 0 deletions torchtitan/components/ft/diloco/protocol.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable, TypeAlias

import torch.nn as nn
from torchtitan.protocols.train_spec import TrainSpec


FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]


@dataclass
class FaultTolerantTrainSpec(TrainSpec):
fragment_fn: FragmentFunction | None = None
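
FragmentFunction is deliberately loose (Callable[..., list[nn.Module]]); the only call shape the rest of this PR relies on is fragment_fn(model, ft_config, n_layers), as used in maybe_semi_sync_training further down. A minimal sketch of a custom fragmenter under that assumption (the halving strategy and the llama-style attribute names are purely illustrative):

import torch.nn as nn

def split_in_half(model: nn.Module, ft_config, n_layers: int) -> list[nn.Module]:
    # Toy fragment_fn: embeddings plus the first half of model.layers go in
    # one fragment, the remaining layers plus norm/output in the other.
    # Every parameter must land in exactly one fragment for DiLoCo to
    # synchronize the whole model.
    mid = n_layers // 2
    first, second = nn.Module(), nn.Module()
    first.tok_embeddings = model.tok_embeddings
    first.layers = nn.ModuleList(list(model.layers)[:mid])
    second.layers = nn.ModuleList(list(model.layers)[mid:])
    second.norm = model.norm
    second.output = model.output
    return [first, second]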
130 changes: 130 additions & 0 deletions torchtitan/components/ft/diloco/utils.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torchtitan.config.job_config import FaultTolerance as FTConfig
from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part


def module_split(
model: nn.Module,
module_fqns_per_model_fragment: list[list[str]],
) -> list[nn.Module]:
"""
This API creates model fragments based on the module names specified for each fragment.
Fragments reference the original model's submodules rather than copying them, so
parameters stay shared with the full model.

Args:
model: The complete model to be split
module_fqns_per_model_fragment: List of lists, where each inner list contains the module names
that should be included in that fragment. Module names should be
dot-separated paths. Examples:
- "tok_embeddings" for token embeddings
- "layers.0", "layers.1" for specific transformer layers
- "norm" for the final normalization layer
- "output" for the output projection layer

Returns:
List of model fragments

Example usage:
module_fqns_per_model_fragment = [
["tok_embeddings", "layers.0"], # fragment 0: embeddings + first layer
["layers.1", "layers.2"], # fragment 1: middle layers
["norm", "output"] # fragment 2: final norm + output
]
"""

def _build_fragment_from_modules(
fragment_idx: int, module_names: list[str]
) -> nn.Module:
fragment_model = nn.Module()
# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}")
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
layers_to_keep = {
name.split(".", 1)[1]
for name in modules_to_keep
if name.startswith(f"{module_name}.")
}

if not layers_to_keep:
continue

# Keep only specified layers
if isinstance(module_value, nn.ModuleDict):
for layer_name in list(module_value.keys()):
if layer_name in layers_to_keep:
setattr(
fragment_model,
f"{module_name}.{layer_name}",
module_value[layer_name],
)
else:
indices_to_keep = {
int(idx) for idx in layers_to_keep if idx.isdigit()
}
new_layers = nn.ModuleList(
[
layer
for i, layer in enumerate(module_value)
if i in indices_to_keep
]
)
setattr(fragment_model, module_name, new_layers)

continue

# Handle simple module attributes (e.g., "linear", "norm")
if module_name not in modules_to_keep:
continue

setattr(fragment_model, module_name, module_value)

return fragment_model

num_fragments = len(module_fqns_per_model_fragment)
model_fragments = []

for fragment_idx in range(num_fragments):
module_names = module_fqns_per_model_fragment[fragment_idx]
model_fragment = _build_fragment_from_modules(
fragment_idx,
module_names,
)
print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}")
model_fragments.append(model_fragment)

return model_fragments


def fragment_llm(
model: nn.Module,
ft_config: FTConfig,
n_layers: int,
) -> list[nn.Module]:
assert ft_config.num_fragments > 0

module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment

input_weight = 1 # Weight for tok_embeddings
output_weight = 1 # Weight for norm + output layers

if module_fqns_per_model_fragment == []:
if ft_config.num_fragments == 1:
return [model]

module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
ft_config.num_fragments, n_layers, input_weight, output_weight
)

model_fragments = module_split(model, module_fqns_per_model_fragment)
print(f"Created {len(model_fragments)} model fragments")

return model_fragments
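
To make the splitting concrete, here is a small self-contained usage sketch of module_split; the toy model and its attribute names are stand-ins chosen to match the llama-style FQNs in the docstring above:

import torch.nn as nn

class ToyLLM(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 16, vocab: int = 100):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)

fragments = module_split(
    ToyLLM(),
    [
        ["tok_embeddings", "layers.0", "layers.1"],  # fragment 0
        ["layers.2", "layers.3", "norm", "output"],  # fragment 1
    ],
)
assert len(fragments) == 2
assert len(fragments[0].layers) == 2  # layers.0 and layers.1 survive the filter

Note that the fragments reference the original submodules rather than copies, so optimizer state built against a fragment updates the same parameter tensors as the full model.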
torchtitan/components/ft/manager.py
@@ -7,10 +7,12 @@
import importlib
from contextlib import nullcontext
from datetime import timedelta
from typing import ContextManager, Optional, TYPE_CHECKING, Union
from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist

import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
from torch.distributed.distributed_c10d import ReduceOp
from torchtitan.config.job_config import FaultTolerance as FTConfig
@@ -108,8 +110,10 @@ def loss_sync_pg(
def maybe_semi_sync_training(
ft_config: FTConfig,
ft_manager: FTManager,
model_parts: list[torch.nn.Module],
model: torch.nn.Module,
n_layers: int,
optimizer: torch.optim.Optimizer,
fragment_fn: Optional[Callable[..., list[nn.Module]]] = None,
) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
"""
If TorchFT is enabled and the config is set, use semi_sync_method
@@ -122,6 +126,11 @@
ft_manager._manager is not None
), "FTManager must be enabled to use semi-sync training."
if semi_sync_method.lower() == "diloco":
if fragment_fn:
model_parts = fragment_fn(model, ft_config, n_layers)
else:
model_parts = [model]

# Create the outer optimizer based on the inner optimizer parameters.
outer_optimizers = []
for model in model_parts:
@@ -142,10 +151,9 @@
fragment_update_alpha=ft_config.fragment_update_alpha,
)
elif semi_sync_method.lower() == "local_sgd":
assert len(model_parts) == 1
return local_sgd.LocalSGD(
manager=ft_manager._manager,
model=model_parts[0],
model=model,
optimizer=optimizer,
sync_every=ft_config.sync_steps,
)
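
The signature change moves fragmentation from the caller into the context manager: callers now pass the whole model plus an optional fragment_fn, and the DiLoCo branch derives model_parts internally, while LocalSGD keeps operating on the single model. A sketch of a call site under the new signature (job_config, ft_manager, model_args, optimizers, data_iterator, and train_step are stand-ins; the real call site is the train.py hunk at the end of this diff):

from torchtitan.components.ft import maybe_semi_sync_training
from torchtitan.components.ft.diloco import fragment_llm

with maybe_semi_sync_training(
    job_config.fault_tolerance,
    ft_manager=ft_manager,
    model=model,                   # the unsplit model
    n_layers=model_args.n_layers,  # drives automatic fragmentation
    optimizer=optimizers,
    fragment_fn=fragment_llm,      # None would treat the model as one fragment
):
    for batch in data_iterator:
        train_step(batch)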
16 changes: 16 additions & 0 deletions torchtitan/config/job_config.py
@@ -686,6 +686,22 @@ class FaultTolerance:
This is only used when "semi_sync_method" is set.
"""

module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
"""
Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
Each inner list represents one model fragment and contains the module names that belong to that fragment.
e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
will create 3 chunks: the first containing tok_embeddings and layers.0,
the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
"""

num_fragments: int = 1
"""
Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
The model is split into this many fragments automatically, provided its train spec is a
FaultTolerantTrainSpec that supplies a fragment_fn.
"""


@dataclass
class Experimental:
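
The two new fields offer two ways to drive fragmentation. A sketch of both, assuming (as with the rest of this config) that every FaultTolerance field carries a default so partial construction works:

from torchtitan.config.job_config import FaultTolerance

# Automatic: fragment_llm derives a balanced split from num_fragments.
auto_cfg = FaultTolerance(semi_sync_method="diloco", num_fragments=2)

# Manual: pin module FQNs to fragments. When this list is non-empty,
# fragment_llm uses it directly; num_fragments only has to stay positive.
manual_cfg = FaultTolerance(
    semi_sync_method="diloco",
    module_fqns_per_model_fragment=[
        ["tok_embeddings", "layers.0"],
        ["layers.1", "norm", "output"],
    ],
)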
49 changes: 49 additions & 0 deletions torchtitan/models/llama3_ft/__init__.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.ft.diloco import FaultTolerantTrainSpec, fragment_llm
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.train_spec import register_train_spec
from ..llama3 import (
llama3_configs,
Llama3StateDictAdapter,
parallelize_llama,
pipeline_llama,
Transformer,
TransformerModelArgs,
)

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]


register_train_spec(
FaultTolerantTrainSpec(
name="llama3_ft",
model_cls=Transformer,
model_args=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
fragment_fn=fragment_llm,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
state_dict_adapter=Llama3StateDictAdapter,
)
)
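
With the spec registered, the fault-tolerant variant is selected like any other model name (e.g. --model.name llama3_ft in the job config). A sketch of the lookup, assuming the usual get_train_spec counterpart to register_train_spec:

from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_ft")
assert spec.fragment_fn is fragment_llm  # plain llama3 specs carry no fragment_fn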
4 changes: 2 additions & 2 deletions torchtitan/protocols/train_spec.py
Expand Up @@ -6,7 +6,7 @@

from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
from typing import Mapping, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -43,7 +43,7 @@
class TrainSpec:
name: str
model_cls: type[ModelProtocol]
model_args: dict[str, BaseModelArgs]
model_args: Mapping[str, BaseModelArgs]
parallelize_fn: ParallelizeFunction
pipelining_fn: PipeliningFunction | None
build_optimizers_fn: OptimizersBuilder
14 changes: 13 additions & 1 deletion torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
lr_schedulers: train_spec_module.LRSchedulersContainer
validator: train_spec_module.BaseValidator
metrics_processor: train_spec_module.MetricsProcessor
model_args: train_spec_module.BaseModelArgs

# non-swappable training components
checkpointer: CheckpointManager
@@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig):
model_args = self.train_spec.model_args[job_config.model.flavor]
# set the model args from training job configs
model_args.update_from_config(job_config)
self.model_args = model_args

logger.info(
f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
@@ -545,8 +547,18 @@ def train(self):
maybe_semi_sync_training(
job_config.fault_tolerance,
ft_manager=self.ft_manager,
model_parts=self.model_parts,
model=self.model_parts[0],
n_layers=(
self.model_args.n_layers
if hasattr(self.model_args, "n_layers")
else 0
),
optimizer=self.optimizers,
fragment_fn=(
self.train_spec.fragment_fn
if hasattr(self.train_spec, "fragment_fn")
else None
),
),
):
data_iterator = self.batch_generator(self.dataloader)
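
The hasattr guards above are duck typing: a plain TrainSpec has no fragment_fn attribute and some model args lack n_layers, so the new arguments degrade gracefully to None and 0. For reference, an equivalent and slightly terser formulation:

n_layers = getattr(self.model_args, "n_layers", 0)
fragment_fn = getattr(self.train_spec, "fragment_fn", None)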