diff --git a/torchtitan/components/ft/__init__.py b/torchtitan/components/ft/__init__.py
new file mode 100644
index 0000000000..308025d39d
--- /dev/null
+++ b/torchtitan/components/ft/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.ft.manager import (
+    FTManager,
+    has_torchft,
+    maybe_semi_sync_training,
+)
+
+
+__all__ = [
+    "FTManager",
+    "has_torchft",
+    "maybe_semi_sync_training",
+]
diff --git a/torchtitan/components/ft/diloco/__init__.py b/torchtitan/components/ft/diloco/__init__.py
new file mode 100644
index 0000000000..d99772a274
--- /dev/null
+++ b/torchtitan/components/ft/diloco/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.ft.diloco.protocol import FaultTolerantTrainSpec
+from torchtitan.components.ft.diloco.utils import fragment_llm
+
+__all__ = [
+    "FaultTolerantTrainSpec",
+    "fragment_llm",
+]
diff --git a/torchtitan/components/ft/diloco/protocol.py b/torchtitan/components/ft/diloco/protocol.py
new file mode 100644
index 0000000000..15c218ffe2
--- /dev/null
+++ b/torchtitan/components/ft/diloco/protocol.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Callable, TypeAlias
+
+import torch.nn as nn
+from torchtitan.protocols.train_spec import TrainSpec
+
+
+FragmentFunction: TypeAlias = Callable[..., list[nn.Module]]
+
+
+@dataclass
+class FaultTolerantTrainSpec(TrainSpec):
+    fragment_fn: FragmentFunction | None = None
diff --git a/torchtitan/components/ft/diloco/utils.py b/torchtitan/components/ft/diloco/utils.py
new file mode 100644
index 0000000000..f83759cff6
--- /dev/null
+++ b/torchtitan/components/ft/diloco/utils.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from torchtitan.config.job_config import FaultTolerance as FTConfig
+from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part
+
+
+def module_split(
+    model: nn.Module,
+    module_fqns_per_model_fragment: list[list[str]],
+) -> list[nn.Module]:
+    """
+    This API creates fragments based on specified module names for each fragment.
+    This method updates the model in place.
+
+    Args:
+        model: The complete model to be split
+        module_fqns_per_model_fragment: List of lists, where each inner list contains
+            the module names that should be included in that fragment. Module names
+            should be dot-separated paths. Examples:
+            - "tok_embeddings" for token embeddings
+            - "layers.0", "layers.1" for specific transformer layers
+            - "norm" for the final normalization layer
+            - "output" for the output projection layer
+
+    Returns:
+        List of model fragments
+
+    Example usage:
+        module_fqns_per_model_fragment = [
+            ["tok_embeddings", "layers.0"],  # fragment 0: embeddings + first layer
+            ["layers.1", "layers.2"],        # fragment 1: middle layers
+            ["norm", "output"]               # fragment 2: final norm + output
+        ]
+    """
+
+    def _build_fragment_from_modules(
+        fragment_idx: int, module_names: list[str]
+    ) -> nn.Module:
+        fragment_model = nn.Module()
+        # Create a set of modules to keep for faster lookup
+        modules_to_keep = set(module_names)
+        print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}")
+        for module_name, module_value in model.named_children():
+            # Handle layer-like structures (e.g., "layers.0", "layers.1")
+            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+                layers_to_keep = {
+                    name.split(".", 1)[1]
+                    for name in modules_to_keep
+                    if name.startswith(f"{module_name}.")
+                }
+
+                if not layers_to_keep:
+                    continue
+
+                # Keep only specified layers
+                if isinstance(module_value, nn.ModuleDict):
+                    for layer_name in list(module_value.keys()):
+                        if layer_name in layers_to_keep:
+                            setattr(
+                                fragment_model,
+                                f"{module_name}.{layer_name}",
+                                module_value[layer_name],
+                            )
+                else:
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = nn.ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                    setattr(fragment_model, module_name, new_layers)
+
+                continue
+
+            # Handle simple module attributes (e.g., "linear", "norm")
+            if module_name not in modules_to_keep:
+                continue
+
+            setattr(fragment_model, module_name, module_value)
+
+        return fragment_model
+
+    num_fragments = len(module_fqns_per_model_fragment)
+    model_fragments = []
+
+    for fragment_idx in range(num_fragments):
+        module_names = module_fqns_per_model_fragment[fragment_idx]
+        model_fragment = _build_fragment_from_modules(
+            fragment_idx,
+            module_names,
+        )
+        print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}")
+        model_fragments.append(model_fragment)
+
+    return model_fragments
+
+
+def fragment_llm(
+    model: nn.Module,
+    ft_config: FTConfig,
+    n_layers: int,
+) -> list[nn.Module]:
+    assert ft_config.num_fragments > 0
+
+    module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment
+
+    input_weight = 1  # Weight for tok_embeddings
+    output_weight = 1  # Weight for norm + output layers
+
+    if module_fqns_per_model_fragment == []:
+        if ft_config.num_fragments == 1:
+            return [model]
+
+        module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
+            ft_config.num_fragments, n_layers, input_weight, output_weight
+        )
+
+    model_fragments = module_split(model, module_fqns_per_model_fragment)
+    print(f"Created {len(model_fragments)} model fragments")
+
+    return model_fragments
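For orientation only, not part of the patch: a minimal sketch of how module_split behaves on a hypothetical toy module whose top-level children mirror the llama3 Transformer (tok_embeddings, layers, norm, output). The ToyLLM class and the chosen fragment layout are illustrative assumptions.

import torch.nn as nn

from torchtitan.components.ft.diloco.utils import module_split


class ToyLLM(nn.Module):
    # Hypothetical stand-in for illustration; only the child module names matter here.
    def __init__(self, vocab: int = 128, dim: int = 16, n_layers: int = 4) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)


fragments = module_split(
    ToyLLM(),
    [
        ["tok_embeddings", "layers.0", "layers.1"],  # fragment 0: embeddings + first half
        ["layers.2", "layers.3", "norm", "output"],  # fragment 1: second half + head
    ],
)
# One nn.Module per inner FQN list; fragments share the original submodules/parameters.
assert len(fragments) == 2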
diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft/manager.py
similarity index 93%
rename from torchtitan/components/ft.py
rename to torchtitan/components/ft/manager.py
index 76f2da3ae5..1a33222c1e 100644
--- a/torchtitan/components/ft.py
+++ b/torchtitan/components/ft/manager.py
@@ -7,10 +7,12 @@
 import importlib
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import ContextManager, Optional, TYPE_CHECKING, Union
+from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
+
+import torch.nn as nn
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.config.job_config import FaultTolerance as FTConfig
@@ -108,8 +110,10 @@ def loss_sync_pg(
 def maybe_semi_sync_training(
     ft_config: FTConfig,
     ft_manager: FTManager,
-    model_parts: list[torch.nn.Module],
+    model: torch.nn.Module,
+    n_layers: int,
     optimizer: torch.optim.Optimizer,
+    fragment_fn: Optional[Callable[..., list[nn.Module]]] = None,
 ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
@@ -122,6 +126,11 @@ def maybe_semi_sync_training(
             ft_manager._manager is not None
         ), "FTManager must be enabled to use semi-sync training."
         if semi_sync_method.lower() == "diloco":
+            if fragment_fn:
+                model_parts = fragment_fn(model, ft_config, n_layers)
+            else:
+                model_parts = [model]
+
             # Create the outer optimizer based on the inner optimizer parameters.
             outer_optimizers = []
             for model in model_parts:
@@ -142,10 +151,9 @@ def maybe_semi_sync_training(
                 fragment_update_alpha=ft_config.fragment_update_alpha,
             )
         elif semi_sync_method.lower() == "local_sgd":
-            assert len(model_parts) == 1
             return local_sgd.LocalSGD(
                 manager=ft_manager._manager,
-                model=model_parts[0],
+                model=model,
                 optimizer=optimizer,
                 sync_every=ft_config.sync_steps,
             )
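The fragment_fn hook added above is any callable that maybe_semi_sync_training can invoke as fragment_fn(model, ft_config, n_layers) to obtain the DiLoCo fragments. A hedged sketch of a custom implementation built on the module_split helper; the name fragment_in_halves and the half-split policy are illustrative, not part of the patch.

import torch.nn as nn

from torchtitan.components.ft.diloco.utils import module_split
from torchtitan.config.job_config import FaultTolerance as FTConfig


def fragment_in_halves(model: nn.Module, ft_config: FTConfig, n_layers: int) -> list[nn.Module]:
    # ft_config is accepted to match the expected call signature; this sketch ignores it.
    # Split a llama-style model into two fragments at the midpoint of its layers.
    mid = max(1, n_layers // 2)
    module_fqns = [
        ["tok_embeddings"] + [f"layers.{i}" for i in range(mid)],
        [f"layers.{i}" for i in range(mid, n_layers)] + ["norm", "output"],
    ]
    return module_split(model, module_fqns)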
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 5255de3da4..881470348d 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -686,6 +686,22 @@ class FaultTolerance:
     This is only used when "semi_sync_method" is set.
     """
 
+    module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list)
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment.
+    Each inner list represents one model fragment and contains the module names that belong to that fragment.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    """
+
+    num_fragments: int = 1
+    """
+    Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco".
+    This is used to automatically split the model into fragments provided that the model
+    implements FaultTolerantTrainSpec
+    """
+
 
 @dataclass
 class Experimental:
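These fields surface in the TOML job config. A sketch of how a DiLoCo run might be configured, assuming the section name fault_tolerance mirrors the dataclass and that the existing torchft enablement options (replica group settings, etc.) are set as usual; only fields visible in this diff are shown and the values are placeholders.

[model]
name = "llama3_ft"  # selects the FaultTolerantTrainSpec registered in torchtitan/models/llama3_ft below

[fault_tolerance]
semi_sync_method = "diloco"
sync_steps = 20
num_fragments = 2
# An explicit mapping takes precedence over num_fragments when non-empty:
# module_fqns_per_model_fragment = [["tok_embeddings", "layers.0"], ["layers.1", "norm", "output"]]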
diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py
new file mode 100644
index 0000000000..1dc277051b
--- /dev/null
+++ b/torchtitan/models/llama3_ft/__init__.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.ft.diloco import FaultTolerantTrainSpec, fragment_llm
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.components.validate import build_validator
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.protocols.train_spec import register_train_spec
+from ..llama3 import (
+    llama3_configs,
+    Llama3StateDictAdapter,
+    parallelize_llama,
+    pipeline_llama,
+    Transformer,
+    TransformerModelArgs,
+)
+
+__all__ = [
+    "parallelize_llama",
+    "pipeline_llama",
+    "TransformerModelArgs",
+    "Transformer",
+    "llama3_configs",
+]
+
+
+register_train_spec(
+    FaultTolerantTrainSpec(
+        name="llama3_ft",
+        model_cls=Transformer,
+        model_args=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        fragment_fn=fragment_llm,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+        build_validator_fn=build_validator,
+        state_dict_adapter=Llama3StateDictAdapter,
+    )
+)
diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py
index 8a782f8b42..fc1ed1b279 100644
--- a/torchtitan/protocols/train_spec.py
+++ b/torchtitan/protocols/train_spec.py
@@ -6,7 +6,7 @@
 
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeAlias
+from typing import Mapping, TypeAlias
 
 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule
@@ -43,7 +43,7 @@ class TrainSpec:
     name: str
     model_cls: type[ModelProtocol]
-    model_args: dict[str, BaseModelArgs]
+    model_args: Mapping[str, BaseModelArgs]
     parallelize_fn: ParallelizeFunction
     pipelining_fn: PipeliningFunction | None
     build_optimizers_fn: OptimizersBuilder
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 369c409a81..96f52749fe 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -48,6 +48,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     lr_schedulers: train_spec_module.LRSchedulersContainer
     validator: train_spec_module.BaseValidator
     metrics_processor: train_spec_module.MetricsProcessor
+    model_args: train_spec_module.BaseModelArgs
 
     # non-swappable training components
     checkpointer: CheckpointManager
@@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig):
         model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
         model_args.update_from_config(job_config)
+        self.model_args = model_args
 
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
@@ -545,8 +547,18 @@ def train(self):
             maybe_semi_sync_training(
                 job_config.fault_tolerance,
                 ft_manager=self.ft_manager,
-                model_parts=self.model_parts,
+                model=self.model_parts[0],
+                n_layers=(
+                    self.model_args.n_layers
+                    if hasattr(self.model_args, "n_layers")
+                    else 0
+                ),
                 optimizer=self.optimizers,
+                fragment_fn=(
+                    self.train_spec.fragment_fn
+                    if hasattr(self.train_spec, "fragment_fn")
+                    else None
+                ),
             ),
         ):
            data_iterator = self.batch_generator(self.dataloader)