 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+
+import math
 import os
 from typing import Callable

     ScheduleZBVZeroBubble,
 )

-from torchtitan.components.loss import rescale_accumulated_loss
+from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger

-
 __all__ = [
+    "pipeline_llm",
     "build_pipeline_schedule",
-    "stage_ids_this_rank",
     "generate_llm_fqn_per_model_part",
     "pipeline_module_split",
 ]


+def pipeline_llm(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
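+    """
+    Split the model into pipeline stages for this PP rank, apply the SPMD-style
+    parallelisms via parallelize_fn to each model chunk, and build the pipeline
+    schedule.
+
+    Returns the schedule, the local model parts, and whether this rank holds the
+    first and/or last stage (used by the train loop to decide whether to pass in
+    input_ids and labels).
+    """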
+    pp_mesh = parallel_dims.world_mesh["pp"]
+
+    # Determine the number of virtual stages based on schedule type
+    schedule_class = get_schedule_class(
+        job_config.parallelism.pipeline_parallel_schedule
+    )
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+    layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage
+    if hasattr(model_args, "n_layers"):
+        num_layers = model_args.n_layers
+    else:
+        raise ValueError("Model does not have n_layers attribute.")
+
+    # You can adjust these weights based on the computational cost of embeddings and output layers
+    # Higher weights mean these modules are treated as "heavier" in the distribution
+    input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers
+    output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers
+
+    # Calculate number of virtual stages
+    if layers_per_stage is not None:
+        # Calculate number of virtual stages needed (using ceiling division)
+        # This allows for unequal distribution where stages can differ by at most 1 layer
+        num_virtual_stages = math.ceil(
+            (num_layers + input_weight + output_weight) / layers_per_stage
+        )
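+        # e.g., 30 layers with input_weight = 1, output_weight = 1, and 4 layers per stage
+        # gives ceil((30 + 1 + 1) / 4) = 8 virtual stages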
+
+        # Validation: check stages per rank based on schedule type
+        model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}"
+        stage_distribution_info = (
+            f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks"
+        )
+
+        if num_virtual_stages % parallel_dims.pp != 0:
+            raise ValueError(
+                f"Number of virtual stages ({num_virtual_stages}) must be divisible by "
+                f"pipeline parallel size ({parallel_dims.pp}). "
+                f"{model_config_info}. "
+                f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages "
+                f"divisible by {parallel_dims.pp}."
+            )
+
+        stages_per_rank = num_virtual_stages // parallel_dims.pp
+
+        if is_single_stage_schedule and stages_per_rank != 1:
+            raise ValueError(
+                f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. "
+                f"{model_config_info}, {stage_distribution_info}. "
+                f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher "
+                f"to achieve 1 stage per rank."
+            )
+
+        if not is_single_stage_schedule and stages_per_rank < 2:
+            raise ValueError(
+                f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. "
+                f"{model_config_info}, {stage_distribution_info}. "
+                f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank."
+            )
+    else:
+        # Fallback to default behavior when layers_per_stage is not provided
+        # For multi-stage schedules, default is 2 virtual stages per rank
+        # For single-stage schedules, default is 1 virtual stage per rank
+        stages_per_rank = 1 if is_single_stage_schedule else 2
+        num_virtual_stages = parallel_dims.pp * stages_per_rank
+
+    module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
+    if module_names_per_stage is None:
+        module_names_per_stage = generate_llm_fqn_per_model_part(
+            num_virtual_stages, num_layers, input_weight, output_weight
+        )
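+    # module_names_per_stage holds one list of module FQNs per virtual stage, e.g. the
+    # token embedding plus the first transformer layers for stage 0, and the final norm
+    # and output modules for the last stage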
+    for i, stage_ms in enumerate(module_names_per_stage):
+        logger.debug(f"Stage {i}: {stage_ms}")
+
+    stages, model_parts = pipeline_module_split(
+        model,
+        pp_mesh,
+        job_config.parallelism.pipeline_parallel_schedule,
+        device,
+        module_names_per_stage,
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
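+    # With a V-style schedule (e.g. ZBVZeroBubble), rank 0 holds both the first and the
+    # last stage, so both flags can be True on the same rank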
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
 def build_pipeline_schedule(
     job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
 ) -> _PipelineSchedule:
@@ -105,27 +223,6 @@ def build_pipeline_schedule(
     return schedule


-# TODO(whc) should this be a utility inside torch.pipelining?
-def stage_ids_this_rank(
-    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
-) -> tuple[int]:
-    """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
-    assert (
-        num_stages % pp_size == 0
-    ), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
-    stages_per_rank = num_stages // pp_size
-    if style == "loop":
-        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
-    elif style == "v":
-        assert (
-            stages_per_rank == 2
-        ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
-        stage_v_pairs = list(
-            zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
-        )
-        return stage_v_pairs[pp_rank]
-
-
 def generate_llm_fqn_per_model_part(
     num_stages: int,
     num_layers: int,
@@ -277,7 +374,7 @@ def pipeline_module_split(
         ]
     """
     pp_rank = pp_mesh.get_local_rank()
-    pp_size = pp_mesh.size()
+    pp_degree = pp_mesh.size()

     def _build_stage_from_modules(
         stage_idx: int, module_names: list[str], num_stages: int
@@ -286,7 +383,6 @@ def _build_stage_from_modules(
 
         # Create a set of modules to keep for faster lookup
         modules_to_keep = set(module_names)
-        logger.info(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
         for module_name, module_value in model.named_children():
             # Handle layer-like structures (e.g., "layers.0", "layers.1")
             if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
@@ -342,7 +438,27 @@ def _build_stage_from_modules(
342438 "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
343439 )
344440
345- for stage_idx in stage_ids_this_rank (pp_rank , pp_size , num_stages , style = style ):
441+ def _get_stage_indices () -> tuple [int ]:
442+ """
443+ Compute the stage ids for the stages that will run on this pp rank
444+ for either a looped or V style schedule
445+ """
+        assert (
+            num_stages % pp_degree == 0
+        ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
+        stages_per_rank = num_stages // pp_degree
+        if style == "loop":
+            return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
+        elif style == "v":
+            assert (
+                stages_per_rank == 2
+            ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
+            stage_v_pairs = list(
+                zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
+            )
+            return stage_v_pairs[pp_rank]
+
+    for stage_idx in _get_stage_indices():
         module_names = module_names_per_stage[stage_idx]
         stage, model_chunk = _build_stage_from_modules(
             stage_idx,