Commit 6e06a49

move PP API to model agnostic file (#1868)
We originally thought each model should have its own `pipeline.py` function. However, for most LLMs a single function suffices, and all models that need PP have been reusing the `pipeline_llama.py` originally written for llama3. (For diffusion models, the model size doesn't justify the use of PP.) This PR consolidates them into a single `pipeline_llm` and moves it to `torchtitan/distributed/pipeline_parallel.py`. We can refactor later if things change.
1 parent 248aca2 commit 6e06a49
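
In short, each model's TrainSpec now points at one shared pipelining entry point instead of a per-model one. A rough before/after sketch of the import change, mirroring the registrations in this diff:

# Before: each model imported its own pipelining function, e.g.
#     from torchtitan.models.llama3 import pipeline_llama    # pipelining_fn=pipeline_llama
# After: all PP-enabled models share the model-agnostic function
from torchtitan.distributed.pipeline_parallel import pipeline_llm    # pipelining_fn=pipeline_llm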

File tree

12 files changed: +166, -209 lines


tests/unit_tests/test_train_spec.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@
 from torchtitan.config import Optimizer as OptimizerConfig
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
+from torchtitan.models.llama3 import parallelize_llama
 from torchtitan.protocols import BaseModelArgs, ModelProtocol
 from torchtitan.protocols.train_spec import (
     get_train_spec,
@@ -79,7 +79,7 @@ def test_register_train_spec(self):
             model_cls=FakeModel,
             model_args=fake_config,
             parallelize_fn=parallelize_llama,
-            pipelining_fn=pipeline_llama,
+            pipelining_fn=None,
             build_optimizers_fn=build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
             build_dataloader_fn=build_hf_dataloader,
@@ -100,7 +100,7 @@ def test_optim_hook(self):
             model_cls=FakeModel,
             model_args=fake_config,
             parallelize_fn=parallelize_llama,
-            pipelining_fn=pipeline_llama,
+            pipelining_fn=None,
             build_optimizers_fn=fake_build_optimizers_with_hook,
             build_lr_schedulers_fn=build_lr_schedulers,
             build_dataloader_fn=build_hf_dataloader,

torchtitan/config/job_config.py

Lines changed: 0 additions & 12 deletions
@@ -316,18 +316,6 @@ class Parallelism:
     of stages. Stages per rank are inferred from split points degree, and schedule.
     """
 
-    pipeline_parallel_split_points: list[str] = field(default_factory=list)
-    """
-    DEPRECATED: Use module_fqns_per_model_part instead.
-    Specify comma-separated names of modules to use as the beginning of a split point.
-    e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
-    the first containing all the layers up to layers.0,
-    the second containing layers.0 and up to layers.2,
-    the third containing layers.2 and all the remaining layers.
-    Note: fully-automated splitting may be enabled in the future,
-    but currently the split points must be specified manually.
-    """
-
     module_fqns_per_model_part: list[list[str]] | None = None
     """
     Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
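
For illustration, a minimal sketch of what the replacement field holds, roughly equivalent to the deprecated split points "layers.0,layers.2" on a 4-layer llama-style model (module names such as tok_embeddings, norm, and output are assumed llama-style FQNs, not taken from this diff):

# Hypothetical module_fqns_per_model_part value: one inner list of module FQNs per model chunk.
module_fqns_per_model_part = [
    ["tok_embeddings"],                           # chunk 0: everything before layers.0
    ["layers.0", "layers.1"],                     # chunk 1: layers.0 up to (not including) layers.2
    ["layers.2", "layers.3", "norm", "output"],   # chunk 2: layers.2 and the remaining modules
]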

torchtitan/distributed/pipeline_parallel.py

Lines changed: 143 additions & 27 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+
+import math
 import os
 from typing import Callable
 
@@ -22,19 +24,135 @@
     ScheduleZBVZeroBubble,
 )
 
-from torchtitan.components.loss import rescale_accumulated_loss
+from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
-
 __all__ = [
+    "pipeline_llm",
     "build_pipeline_schedule",
-    "stage_ids_this_rank",
     "generate_llm_fqn_per_model_part",
     "pipeline_module_split",
 ]
 
 
+def pipeline_llm(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
+    pp_mesh = parallel_dims.world_mesh["pp"]
+
+    # Determine the number of virtual stages based on schedule type
+    schedule_class = get_schedule_class(
+        job_config.parallelism.pipeline_parallel_schedule
+    )
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+    layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage
+    if hasattr(model_args, "n_layers"):
+        num_layers = model_args.n_layers
+    else:
+        raise ValueError("Model does not have n_layers attribute.")
+
+    # You can adjust these weights based on the computational cost of embeddings and output layers
+    # Higher weights mean these modules are treated as "heavier" in the distribution
+    input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers
+    output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers
+
+    # Calculate number of virtual stages
+    if layers_per_stage is not None:
+
+        # Calculate number of virtual stages needed (using ceiling division)
+        # This allows for unequal distribution where stages can differ by at most 1 layer
+        num_virtual_stages = math.ceil(
+            (num_layers + input_weight + output_weight) / layers_per_stage
+        )
+
+        # Validation: check stages per rank based on schedule type
+        model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}"
+        stage_distribution_info = (
+            f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks"
+        )
+
+        if num_virtual_stages % parallel_dims.pp != 0:
+            raise ValueError(
+                f"Number of virtual stages ({num_virtual_stages}) must be divisible by "
+                f"pipeline parallel size ({parallel_dims.pp}). "
+                f"{model_config_info}. "
+                f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages "
+                f"divisible by {parallel_dims.pp}."
+            )
+
+        stages_per_rank = num_virtual_stages // parallel_dims.pp
+
+        if is_single_stage_schedule and stages_per_rank != 1:
+            raise ValueError(
+                f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. "
+                f"{model_config_info}, {stage_distribution_info}. "
+                f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher "
+                f"to achieve 1 stage per rank."
+            )
+
+        if not is_single_stage_schedule and stages_per_rank < 2:
+            raise ValueError(
+                f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. "
+                f"{model_config_info}, {stage_distribution_info}. "
+                f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank."
+            )
+    else:
+        # Fallback to default behavior when layers_per_stage is not provided
+        # For multi-stage schedules, default is 2 virtual stages per rank
+        # For single-stage schedules, default is 1 virtual stage per rank
+        stages_per_rank = 1 if is_single_stage_schedule else 2
+        num_virtual_stages = parallel_dims.pp * stages_per_rank
+
+    module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
+    if module_names_per_stage is None:
+        module_names_per_stage = generate_llm_fqn_per_model_part(
+            num_virtual_stages, num_layers, input_weight, output_weight
+        )
+    for i, stage_ms in enumerate(module_names_per_stage):
+        logger.debug(f"Stage {i}: {stage_ms}")
+
+    stages, model_parts = pipeline_module_split(
+        model,
+        pp_mesh,
+        job_config.parallelism.pipeline_parallel_schedule,
+        device,
+        module_names_per_stage,
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
 def build_pipeline_schedule(
     job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
 ) -> _PipelineSchedule:
@@ -105,27 +223,6 @@ def build_pipeline_schedule(
     return schedule
 
 
-# TODO(whc) should this be a utility inside torch.pipelining?
-def stage_ids_this_rank(
-    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
-) -> tuple[int]:
-    """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
-    assert (
-        num_stages % pp_size == 0
-    ), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
-    stages_per_rank = num_stages // pp_size
-    if style == "loop":
-        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
-    elif style == "v":
-        assert (
-            stages_per_rank == 2
-        ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
-        stage_v_pairs = list(
-            zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
-        )
-        return stage_v_pairs[pp_rank]
-
-
 def generate_llm_fqn_per_model_part(
     num_stages: int,
     num_layers: int,
@@ -277,7 +374,7 @@ def pipeline_module_split(
         ]
     """
     pp_rank = pp_mesh.get_local_rank()
-    pp_size = pp_mesh.size()
+    pp_degree = pp_mesh.size()
 
     def _build_stage_from_modules(
         stage_idx: int, module_names: list[str], num_stages: int
@@ -286,7 +383,6 @@ def _build_stage_from_modules(
 
         # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
-       logger.info(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
@@ -342,7 +438,27 @@ def _build_stage_from_modules(
         "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
     )
 
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
+    def _get_stage_indices() -> tuple[int]:
+        """
+        Compute the stage ids for the stages that will run on this pp rank
+        for either a looped or V style schedule
+        """
+        assert (
+            num_stages % pp_degree == 0
+        ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
+        stages_per_rank = num_stages // pp_degree
+        if style == "loop":
+            return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
+        elif style == "v":
+            assert (
+                stages_per_rank == 2
+            ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
+            stage_v_pairs = list(
+                zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
+            )
+            return stage_v_pairs[pp_rank]
+
+    for stage_idx in _get_stage_indices():
         module_names = module_names_per_stage[stage_idx]
         stage, model_chunk = _build_stage_from_modules(
             stage_idx,
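
To make the consolidated stage-count logic concrete, here is a small runnable sketch of the arithmetic that pipeline_llm and _get_stage_indices perform; the numbers are illustrative and not taken from this diff:

import math

# Illustrative inputs only (not from this diff).
num_layers, input_weight, output_weight = 16, 1, 1
layers_per_stage, pp_degree = 3, 3

# Same ceiling division as in pipeline_llm.
num_virtual_stages = math.ceil((num_layers + input_weight + output_weight) / layers_per_stage)
assert num_virtual_stages % pp_degree == 0          # 6 virtual stages over 3 PP ranks
stages_per_rank = num_virtual_stages // pp_degree   # 2 -> valid for a multi-stage (looped) schedule

# Stage-id assignment mirroring _get_stage_indices for loop and v styles.
for pp_rank in range(pp_degree):
    loop_ids = tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
    v_ids = list(zip(range(pp_degree), range(num_virtual_stages - 1, pp_degree - 1, -1)))[pp_rank]
    print(f"rank {pp_rank}: loop={loop_ids}, v={v_ids}")
# rank 0: loop=(0, 3), v=(0, 5)
# rank 1: loop=(1, 4), v=(1, 4)
# rank 2: loop=(2, 5), v=(2, 3)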

torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.distributed.pipeline_parallel import pipeline_llm
 from torchtitan.models.deepseek_v3 import deepseekv3_configs
-from torchtitan.models.llama3 import pipeline_llama
 from torchtitan.protocols.train_spec import TrainSpec
 
 from .model import SimpleFSDPDeepSeekV3Model
@@ -24,7 +24,7 @@ def get_train_spec() -> TrainSpec:
         model_cls=SimpleFSDPDeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 6 additions & 3 deletions
@@ -11,13 +11,16 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
-from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ac
+from torchtitan.models.deepseek_v3.infra.parallelize import (
+    apply_ac,
+    apply_moe_ep_tp,
+    apply_non_moe_tp,
+)
 from torchtitan.tools.logging import logger
 
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
+
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
     model: nn.Module,

torchtitan/experiments/simple_fsdp/llama3/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -11,7 +11,8 @@
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3 import llama3_configs, pipeline_llama
+from torchtitan.distributed.pipeline_parallel import pipeline_llm
+from torchtitan.models.llama3 import llama3_configs
 from torchtitan.protocols.train_spec import TrainSpec
 
 from .model import SimpleFSDPTransformer
@@ -23,7 +24,7 @@ def get_train_spec() -> TrainSpec:
         model_cls=SimpleFSDPTransformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
build_dataloader_fn=build_hf_dataloader,

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.distributed.pipeline_parallel import pipeline_llm
 from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import TrainSpec
 
@@ -164,7 +164,7 @@ def get_train_spec() -> TrainSpec:
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,

torchtitan/models/llama3/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -10,17 +10,16 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.distributed.pipeline_parallel import pipeline_llm
 from torchtitan.protocols.train_spec import TrainSpec
 
 from .infra.parallelize import parallelize_llama
-from .infra.pipeline import pipeline_llama
 from .model.args import TransformerModelArgs
 from .model.model import Transformer
 from .model.state_dict_adapter import Llama3StateDictAdapter
 
 __all__ = [
     "parallelize_llama",
-    "pipeline_llama",
     "TransformerModelArgs",
     "Transformer",
     "llama3_configs",
@@ -75,7 +74,7 @@ def get_train_spec() -> TrainSpec:
         model_cls=Transformer,
         model_args=llama3_configs,
         parallelize_fn=parallelize_llama,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llm,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,

0 commit comments
