|
6 | 6 | import os |
7 | 7 | from typing import Callable |
8 | 8 |
|
| 9 | +from torch import nn |
| 10 | + |
9 | 11 | from torch.distributed.pipelining.schedules import ( |
10 | 12 | _PipelineSchedule, |
11 | 13 | _PipelineScheduleRuntime, |
|
19 | 21 | from torchtitan.tools.logging import logger |
20 | 22 |
|
21 | 23 |
|
22 | | -__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] |
| 24 | +__all__ = [ |
| 25 | + "build_pipeline_schedule", |
| 26 | + "generate_split_points", |
| 27 | + "stage_ids_this_rank", |
| 28 | + "generate_module_names_per_stage", |
| 29 | + "module_split", |
| 30 | +] |
23 | 31 |
|
24 | 32 |
|
25 | 33 | # TODO: It's unclear if this API is general enough to be used by other models. |
@@ -206,6 +214,196 @@ def stage_ids_this_rank( |
206 | 214 | stages_per_rank == 2 |
207 | 215 | ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" |
208 | 216 | stage_v_pairs = list( |
209 | | - zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1)) |
| 217 | + zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1), strict=True) |
210 | 218 | ) |
211 | 219 | return stage_v_pairs[pp_rank] |
| 220 | + |
| 221 | + |
| 222 | +def generate_module_names_per_stage( |
| 223 | + num_stages: int, |
| 224 | + num_layers: int, |
| 225 | + input_weight: int = 1, |
| 226 | + output_weight: int = 1, |
| 227 | +) -> list[list[str]]: |
| 228 | + """ |
| 229 | + Programmatically generates module names per stage for pipeline parallelism with weighting. |
| 230 | +
|
| 231 | + Args: |
| 232 | + num_stages: Number of pipeline stages |
| 233 | + num_layers: Total number of transformer layers in the model |
| 234 | + input_weight: Weight for input modules (tok_embeddings) in layer calculation |
| 235 | + output_weight: Weight for output modules (norm + output) in layer calculation |
| 236 | +
|
| 237 | + Returns: |
| 238 | + List of lists containing module names for each stage |
| 239 | +
|
| 240 | + Example: |
| 241 | + generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2) |
| 242 | + treats embeddings as 2 layers and norm+output as 2 layers for distribution |
| 243 | + """ |
| 244 | + if num_stages < 1: |
| 245 | + raise ValueError("Number of stages must be at least 1") |
| 246 | + |
| 247 | + if num_stages == 1: |
| 248 | + # Single stage gets everything |
| 249 | + layer_names = [f"layers.{i}" for i in range(num_layers)] |
| 250 | + return [["tok_embeddings"] + layer_names + ["norm", "output"]] |
| 251 | + |
| 252 | + # Calculate effective layers including weights |
| 253 | + num_effective_layers = num_layers + input_weight + output_weight |
| 254 | + |
| 255 | + if num_stages > num_effective_layers: |
| 256 | + raise ValueError( |
| 257 | + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" |
| 258 | + ) |
| 259 | + |
| 260 | + # Calculate layers per stage (distribute evenly) |
| 261 | + layers_per_stage = num_effective_layers // num_stages |
| 262 | + extra_layers = num_effective_layers % num_stages |
| 263 | + |
| 264 | + # Ensure each stage gets at least the weight of input/output modules |
| 265 | + if layers_per_stage < max(input_weight, output_weight): |
| 266 | + raise ValueError( |
| 267 | + f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})" |
| 268 | + ) |
| 269 | + |
| 270 | + module_names_per_stage = [] |
| 271 | + current_layer = 0 |
| 272 | + |
| 273 | + for stage_idx in range(num_stages): |
| 274 | + stage_modules = [] |
| 275 | + |
| 276 | + # Calculate effective layers for this stage |
| 277 | + effective_layers_for_stage = layers_per_stage |
| 278 | + if stage_idx < extra_layers: |
| 279 | + effective_layers_for_stage += 1 |
| 280 | + |
| 281 | + # First stage: handle input modules with weighting |
| 282 | + if stage_idx == 0: |
| 283 | + stage_modules.append("tok_embeddings") |
| 284 | + # Account for input weight in layer distribution |
| 285 | + remaining_layers_for_stage = effective_layers_for_stage - input_weight |
| 286 | + |
| 287 | + # Add transformer layers |
| 288 | + for _ in range(remaining_layers_for_stage): |
| 289 | + if current_layer < num_layers: |
| 290 | + stage_modules.append(f"layers.{current_layer}") |
| 291 | + current_layer += 1 |
| 292 | + |
| 293 | + # Last stage: handle output modules with weighting |
| 294 | + elif stage_idx == num_stages - 1: |
| 295 | + # Account for output weight in layer distribution |
| 296 | + remaining_layers_for_stage = effective_layers_for_stage - output_weight |
| 297 | + |
| 298 | + # Add transformer layers |
| 299 | + for _ in range(remaining_layers_for_stage): |
| 300 | + if current_layer < num_layers: |
| 301 | + stage_modules.append(f"layers.{current_layer}") |
| 302 | + current_layer += 1 |
| 303 | + |
| 304 | + # Add output modules |
| 305 | + stage_modules.extend(["norm", "output"]) |
| 306 | + |
| 307 | + # Middle stages: only transformer layers |
| 308 | + else: |
| 309 | + for _ in range(effective_layers_for_stage): |
| 310 | + if current_layer < num_layers: |
| 311 | + stage_modules.append(f"layers.{current_layer}") |
| 312 | + current_layer += 1 |
| 313 | + |
| 314 | + module_names_per_stage.append(stage_modules) |
| 315 | + |
| 316 | + return module_names_per_stage |
| 317 | + |
| 318 | + |
| 319 | +def module_split( |
| 320 | + model: nn.Module, |
| 321 | + module_names_per_stage: list[list[str]], |
| 322 | +) -> list[nn.Module]: |
| 323 | + """ |
| 324 | + This API creates pipeline stages based on specified module names for each stage. |
| 325 | + This method updates the model in place. |
| 326 | +
|
| 327 | + Args: |
| 328 | + model: The complete model to be split |
| 329 | + module_names_per_stage: List of lists, where each inner list contains the module names |
| 330 | + that should be included in that stage. Module names should be |
| 331 | + dot-separated paths. Examples: |
| 332 | + - "tok_embeddings" for token embeddings |
| 333 | + - "layers.0", "layers.1" for specific transformer layers |
| 334 | + - "norm" for the final normalization layer |
| 335 | + - "output" for the output projection layer |
| 336 | +
|
| 337 | + Returns: |
| 338 | + List of model chunks |
| 339 | +
|
| 340 | + Example usage: |
| 341 | + module_names_per_stage = [ |
| 342 | + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer |
| 343 | + ["layers.1", "layers.2"], # Stage 1: middle layers |
| 344 | + ["norm", "output"] # Stage 2: final norm + output |
| 345 | + ] |
| 346 | + """ |
| 347 | + |
| 348 | + def _build_stage_from_modules(stage_idx: int, module_names: list[str]) -> nn.Module: |
| 349 | + stage_model = nn.Module() |
| 350 | + # Create a set of modules to keep for faster lookup |
| 351 | + modules_to_keep = set(module_names) |
| 352 | + print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") |
| 353 | + for module_name, module_value in model.named_children(): |
| 354 | + # Handle layer-like structures (e.g., "layers.0", "layers.1") |
| 355 | + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): |
| 356 | + layers_to_keep = { |
| 357 | + name.split(".", 1)[1] |
| 358 | + for name in modules_to_keep |
| 359 | + if name.startswith(f"{module_name}.") |
| 360 | + } |
| 361 | + |
| 362 | + if not layers_to_keep: |
| 363 | + continue |
| 364 | + |
| 365 | + # Keep only specified layers |
| 366 | + if isinstance(module_value, nn.ModuleDict): |
| 367 | + for layer_name in list(module_value.keys()): |
| 368 | + if layer_name in layers_to_keep: |
| 369 | + setattr( |
| 370 | + stage_model, |
| 371 | + f"{module_name}.{layer_name}", |
| 372 | + module_value[layer_name], |
| 373 | + ) |
| 374 | + else: |
| 375 | + indices_to_keep = { |
| 376 | + int(idx) for idx in layers_to_keep if idx.isdigit() |
| 377 | + } |
| 378 | + new_layers = nn.ModuleList( |
| 379 | + [ |
| 380 | + layer |
| 381 | + for i, layer in enumerate(module_value) |
| 382 | + if i in indices_to_keep |
| 383 | + ] |
| 384 | + ) |
| 385 | + setattr(stage_model, module_name, new_layers) |
| 386 | + |
| 387 | + continue |
| 388 | + |
| 389 | + # Handle simple module attributes (e.g., "linear", "norm") |
| 390 | + if module_name not in modules_to_keep: |
| 391 | + continue |
| 392 | + |
| 393 | + setattr(stage_model, module_name, module_value) |
| 394 | + |
| 395 | + return stage_model |
| 396 | + |
| 397 | + num_stages = len(module_names_per_stage) |
| 398 | + models = [] |
| 399 | + |
| 400 | + for stage_idx in range(num_stages): |
| 401 | + module_names = module_names_per_stage[stage_idx] |
| 402 | + model_chunk = _build_stage_from_modules( |
| 403 | + stage_idx, |
| 404 | + module_names, |
| 405 | + ) |
| 406 | + logger.info(f"building stage_idx {stage_idx} " f"with modules {module_names}") |
| 407 | + models.append(model_chunk) |
| 408 | + |
| 409 | + return models |
0 commit comments