Merged
6 changes: 3 additions & 3 deletions tests/unit_tests/test_train_spec.py
@@ -17,7 +17,7 @@
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
from torchtitan.models.llama3 import parallelize_llama
from torchtitan.protocols import BaseModelArgs, ModelProtocol
from torchtitan.protocols.train_spec import (
get_train_spec,
@@ -79,7 +79,7 @@ def test_register_train_spec(self):
model_cls=FakeModel,
model_args=fake_config,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
@@ -100,7 +100,7 @@ def test_optim_hook(self):
model_cls=FakeModel,
model_args=fake_config,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
pipelining_fn=None,
build_optimizers_fn=fake_build_optimizers_with_hook,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
12 changes: 0 additions & 12 deletions torchtitan/config/job_config.py
@@ -316,18 +316,6 @@ class Parallelism:
of stages. Stages per rank are inferred from split points, degree, and schedule.
"""

pipeline_parallel_split_points: list[str] = field(default_factory=list)
"""
DEPRECATED: Use module_fqns_per_model_part instead.
Specify comma-separated names of modules to use as the beginning of a split point.
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
the first containing all the layers up to layers.0,
the second containing layers.0 and up to layers.2,
the third containing layers.2 and all the remaining layers.
Note: fully-automated splitting may be enabled in the future,
but currently the split points must be specified manually.
"""

module_fqns_per_model_part: list[list[str]] | None = None
"""
Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
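For illustration (module names are placeholders in the style of the llama3 model, not taken from this PR), the deprecated split-point spec "layers.0,layers.2" for a hypothetical 4-layer model maps onto an explicit per-part FQN list roughly like this:

module_fqns_per_model_part = [
    ["tok_embeddings"],                          # everything before layers.0
    ["layers.0", "layers.1"],                    # layers.0 up to, but not including, layers.2
    ["layers.2", "layers.3", "norm", "output"],  # layers.2 and the remaining modules
]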
170 changes: 143 additions & 27 deletions torchtitan/distributed/pipeline_parallel.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy

import math
import os
from typing import Callable

@@ -22,19 +24,135 @@
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
from torchtitan.tools.logging import logger


__all__ = [
"pipeline_llm",
"build_pipeline_schedule",
"stage_ids_this_rank",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
]


def pipeline_llm(
model: nn.Module,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: torch.device,
model_args: BaseModelArgs,
parallelize_fn: ParallelizeFunction,
loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
pp_mesh = parallel_dims.world_mesh["pp"]

# Determine the number of virtual stages based on schedule type
schedule_class = get_schedule_class(
job_config.parallelism.pipeline_parallel_schedule
)
is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage
if hasattr(model_args, "n_layers"):
num_layers = model_args.n_layers
else:
raise ValueError("Model does not have n_layers attribute.")

# You can adjust these weights based on the computational cost of embeddings and output layers
# Higher weights mean these modules are treated as "heavier" in the distribution
input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers
output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers

# Calculate number of virtual stages
if layers_per_stage is not None:

# Calculate number of virtual stages needed (using ceiling division)
# This allows for unequal distribution where stages can differ by at most 1 layer
num_virtual_stages = math.ceil(
(num_layers + input_weight + output_weight) / layers_per_stage
)
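# Worked example (hypothetical numbers, not from this PR): num_layers=16,
# input_weight=1, output_weight=1, layers_per_stage=3
# -> num_virtual_stages = ceil(18 / 3) = 6; with pp=3 that is 2 stages per rank,
# which satisfies a multi-stage (interleaved) schedule.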

# Validation: check stages per rank based on schedule type
model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}"
stage_distribution_info = (
f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks"
)

if num_virtual_stages % parallel_dims.pp != 0:
raise ValueError(
f"Number of virtual stages ({num_virtual_stages}) must be divisible by "
f"pipeline parallel size ({parallel_dims.pp}). "
f"{model_config_info}. "
f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages "
f"divisible by {parallel_dims.pp}."
)

stages_per_rank = num_virtual_stages // parallel_dims.pp

if is_single_stage_schedule and stages_per_rank != 1:
raise ValueError(
f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. "
f"{model_config_info}, {stage_distribution_info}. "
f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher "
f"to achieve 1 stage per rank."
)

if not is_single_stage_schedule and stages_per_rank < 2:
raise ValueError(
f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. "
f"{model_config_info}, {stage_distribution_info}. "
f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank."
)
else:
# Fallback to default behavior when layers_per_stage is not provided
# For multi-stage schedules, default is 2 virtual stages per rank
# For single-stage schedules, default is 1 virtual stage per rank
stages_per_rank = 1 if is_single_stage_schedule else 2
num_virtual_stages = parallel_dims.pp * stages_per_rank

module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
if module_names_per_stage is None:
module_names_per_stage = generate_llm_fqn_per_model_part(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.debug(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
job_config.parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(m, parallel_dims, job_config)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage
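
# Sketch of how a trainer might consume these flags (assumed call-site logic,
# not part of this diff): only the first stage feeds token inputs and only the
# last stage receives targets/losses, e.g.
#   if has_first_stage:
#       pp_schedule.step(input_ids, target=targets, losses=losses)
#   else:
#       pp_schedule.step(target=targets, losses=losses)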


def build_pipeline_schedule(
job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
) -> _PipelineSchedule:
@@ -105,27 +223,6 @@ def build_pipeline_schedule(
return schedule


# TODO(whc) should this be a utility inside torch.pipelining?
def stage_ids_this_rank(
pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> tuple[int]:
"""Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
assert (
num_stages % pp_size == 0
), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
stages_per_rank = num_stages // pp_size
if style == "loop":
return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
)
return stage_v_pairs[pp_rank]


def generate_llm_fqn_per_model_part(
num_stages: int,
num_layers: int,
@@ -277,7 +374,7 @@ def pipeline_module_split(
]
"""
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
pp_degree = pp_mesh.size()

def _build_stage_from_modules(
stage_idx: int, module_names: list[str], num_stages: int
@@ -286,7 +383,6 @@ def _build_stage_from_modules(

# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
logger.info(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
@@ -342,7 +438,27 @@ def _build_stage_from_modules(
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
)

for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
def _get_stage_indices() -> tuple[int]:
"""
Compute the stage ids for the stages that will run on this pp rank
for either a looped or V style schedule
"""
assert (
num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
stages_per_rank = num_stages // pp_degree
if style == "loop":
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return stage_v_pairs[pp_rank]
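# Example (hypothetical sizes, not from this PR): pp_degree=4, num_stages=8
#   "loop" -> rank 0 runs stages (0, 4), rank 1 runs (1, 5), ..., rank 3 runs (3, 7)
#   "v"    -> zip(range(4), range(7, 3, -1)) gives [(0, 7), (1, 6), (2, 5), (3, 4)],
#             so rank 0 runs stages (0, 7)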

for stage_idx in _get_stage_indices():
module_names = module_names_per_stage[stage_idx]
stage, model_chunk = _build_stage_from_modules(
stage_idx,
4 changes: 2 additions & 2 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py
@@ -11,8 +11,8 @@
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.models.deepseek_v3 import deepseekv3_configs
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import TrainSpec

from .model import SimpleFSDPDeepSeekV3Model
@@ -24,7 +24,7 @@ def get_train_spec() -> TrainSpec:
model_cls=SimpleFSDPDeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llama,
pipelining_fn=pipeline_llm,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
9 changes: 6 additions & 3 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -11,13 +11,16 @@
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac
from torchtitan.models.deepseek_v3.infra.parallelize import (
apply_ac,
apply_moe_ep_tp,
apply_non_moe_tp,
)
from torchtitan.tools.logging import logger

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
model: nn.Module,
5 changes: 3 additions & 2 deletions torchtitan/experiments/simple_fsdp/llama3/__init__.py
@@ -11,7 +11,8 @@
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.llama3 import llama3_configs, pipeline_llama
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.models.llama3 import llama3_configs
from torchtitan.protocols.train_spec import TrainSpec

from .model import SimpleFSDPTransformer
@@ -23,7 +24,7 @@ def get_train_spec() -> TrainSpec:
model_cls=SimpleFSDPTransformer,
model_args=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
pipelining_fn=pipeline_llm,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
4 changes: 2 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -11,7 +11,7 @@
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.llama3.infra.pipeline import pipeline_llama
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.models.moe import MoEArgs
from torchtitan.protocols.train_spec import TrainSpec

@@ -164,7 +164,7 @@ def get_train_spec() -> TrainSpec:
model_cls=DeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llama,
pipelining_fn=pipeline_llm,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
5 changes: 2 additions & 3 deletions torchtitan/models/llama3/__init__.py
@@ -10,17 +10,16 @@
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_llama
from .infra.pipeline import pipeline_llama
from .model.args import TransformerModelArgs
from .model.model import Transformer
from .model.state_dict_adapter import Llama3StateDictAdapter

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
@@ -75,7 +74,7 @@ def get_train_spec() -> TrainSpec:
model_cls=Transformer,
model_args=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
pipelining_fn=pipeline_llm,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,