From 338f1838346ecac6598b0d01ae4d99ef765bad9d Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 7 Aug 2024 14:15:08 -0700
Subject: [PATCH 1/3] simplify pp vs. non-pp set up

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 22 ++----
 train.py                                     | 77 ++++++++++----------
 2 files changed, 47 insertions(+), 52 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index a300c644..18a0f452 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -51,7 +51,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        model = apply_tp(
+        apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@ def parallelize_llama(
         )
 
     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config.activation_checkpoint)
+        apply_ac(model, job_config.activation_checkpoint)
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
@@ -69,14 +69,14 @@ def parallelize_llama(
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        model = apply_compile(model)
+        apply_compile(model)
 
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
             assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
-            model = apply_fsdp(
+            apply_fsdp(
                 model,
                 dp_mesh,
                 param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@ def parallelize_llama(
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
-            model = apply_ddp(
+            apply_ddp(
                 model,
                 world_mesh,
                 enable_compile=job_config.training.compile,
                 enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
             )
 
-    return model
-
 
 def apply_tp(
     model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. 
Parallelize the final linear output layer - model = parallelize_module( + parallelize_module( model, tp_mesh, { @@ -192,7 +190,6 @@ def apply_tp( f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" "Tensor Parallelism to the model" ) - return model # for selective op activation checkpointing @@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config): model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") - return model def apply_compile(model: nn.Module): @@ -286,7 +282,6 @@ def apply_compile(model: nn.Module): model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") - return model def apply_fsdp( @@ -329,8 +324,8 @@ def apply_fsdp( module._load_state_dict_pre_hooks.clear() assert len(module._state_dict_pre_hooks) <= 1 module._state_dict_pre_hooks.clear() + logger.info("Applied FSDP to the model") - return model def apply_ddp( @@ -347,7 +342,6 @@ def apply_ddp( else: torch._dynamo.config.optimize_ddp = "ddp_optimizer" - model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model") - return model diff --git a/train.py b/train.py index 58d23307..afb74408 100644 --- a/train.py +++ b/train.py @@ -115,17 +115,17 @@ def main(job_config: JobConfig): logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}") with torch.device("meta"): - whole_model = model_cls.from_model_args(model_config) + model = model_cls.from_model_args(model_config) # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear based on float8 configs - float8_handler.convert_to_float8_training(whole_model) + float8_handler.convert_to_float8_training(model) # log model size - model_param_count = utils.get_num_params(whole_model) + model_param_count = utils.get_num_params(model) num_flop_per_token = utils.get_num_flop_per_token( - utils.get_num_params(whole_model, exclude_embedding=True), + utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, ) @@ -134,41 +134,46 @@ def main(job_config: JobConfig): f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) - if parallel_dims.pp_enabled: - stages, model_parts = models_pipelining_fns[model_name]( - whole_model, pp_mesh, parallel_dims, job_config, device, model_config - ) - else: - # In 1D/2D cases or PP with simple schedules, model_parts is just one item - # for PP with looped schedules, each item is one stage-model-chunk - # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing - model_parts = [whole_model] - - # apply PT-D DP/TP parallelisms and activation checkpointing - model_parts = [ - models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) - for m in model_parts - ] - - init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" - for model in model_parts: - model.to_empty(device=init_device) - - # loss fn can be shared by pipeline-parallel or non-pp execution + # loss function to be shared by Pipeline Parallel and spmd training def loss_fn(pred, labels): return torch.nn.functional.cross_entropy( pred.flatten(0, 1), labels.flatten(0, 1) ) + # apply parallelisms and initialization if parallel_dims.pp_enabled: + # apply PT-D Pipeline Parallel + stages, model_parts = 
models_pipelining_fns[model_name]( + model, pp_mesh, parallel_dims, job_config, device, model_config + ) pp_schedule = build_pipeline_schedule( job_config, parallel_dims, stages, loss_fn ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for m in model_parts: + # apply spmd-style PT-D techniques + models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) + + # In PP, we cannot call init_weights directly because some layers are missing. + # In the future, we may make init_weights handle missing layers, but also have + # to consider RNG seed propagation. For now, we rely on a seed checkpoint to + # initialize the model. + m.to_empty(device="cuda") + m.train() else: - # If PP is enabled, we can't rely on init_weights, because some layers are missing. - # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation. - # allocate sharded model on GPU and initialize weights via DTensor - whole_model.init_weights() + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + + # move sharded model to CPU/GPU and initialize weights via DTensor + init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" + model.to_empty(device=init_device) + model.init_weights() + model.train() + + model_parts = [model] gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -183,10 +188,6 @@ def loss_fn(pred, labels): train_state = TrainState() - # train loop - for model in model_parts: - model.train() - # load initial checkpoint checkpoint = CheckpointManager( dataloader=data_loader, @@ -301,9 +302,9 @@ def loss_fn(pred, labels): loss.backward() # clip gradients - for model in model_parts: + for m in model_parts: torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True + m.parameters(), job_config.training.max_norm, foreach=True ) # sync float8 amaxes and scales @@ -393,14 +394,14 @@ def loss_fn(pred, labels): train_state.step, force=(train_state.step == job_config.training.steps) ) - # signals the profiler that the next profiling step has started + # signal the profiler that the next profiling step has started if torch_profiler: torch_profiler.step() - if memory_profiler: memory_profiler.step() - # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished) + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) if train_state.step == 1: utils.set_pg_timeouts( timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), From f58ca70ed6ce36f983bd1585bf87e63869de9917 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 7 Aug 2024 15:03:21 -0700 Subject: [PATCH 2/3] Update on "[BE][5/n] simplify pp vs. non-pp set up" This PR restructures the PP vs. non-PP setup in `train.py`: - Now we only have two main if-else for PP vs. non-PP, one in setup phase, the other in training phase. - I think it's already clear to read or copy-paste, and it's not necessary to create separate sub-functions to hold the code. This PR also removes unnecessary module returns in `parallelize_llama`, as we are modifying module in-place. Note that torch.compile and AC require returning and reassigning the module. 
But since we are doing per-block compile and AC, we achieve that in-place for the whole model by ``` transformer_block = compile/AC(transformer_block) model.layers.register_module(layer_id, transformer_block) ``` [ghstack-poisoned] --- estimation.py | 33 ++++++++------------- torchtitan/parallelisms/__init__.py | 2 -- torchtitan/parallelisms/pipeline_llama.py | 18 +++++++---- torchtitan/parallelisms/pipelining_utils.py | 2 +- train.py | 8 ++--- 5 files changed, 28 insertions(+), 35 deletions(-) diff --git a/estimation.py b/estimation.py index 70fb66cb..13ccd4c1 100644 --- a/estimation.py +++ b/estimation.py @@ -122,33 +122,25 @@ def loss_fn(pred, labels): f"Building {model_name} {job_config.model.flavor} with {model_config}" ) with torch.device("meta"): - whole_model = model_cls.from_model_args(model_config) + model = model_cls.from_model_args(model_config) # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear based on float8 configs - float8_handler.convert_to_float8_training(whole_model) + float8_handler.convert_to_float8_training(model) # apply PT-D DP/TP parallelisms and activation checkpointing - model_parts = [whole_model] - model_parts = [ - models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) - for m in model_parts - ] - - init_device = "cuda" - for model in model_parts: - model.to_empty(device=init_device) + models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + model.to_empty(device="cuda") if not active_fake_mode(): - whole_model.init_weights() + model.init_weights() + model.train() # build optimizer after applying parallelisms to the model - optimizers = build_optimizers(model_parts, job_config) + optimizers = build_optimizers([model], job_config) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) - for model in model_parts: - model.train() logger.info(f"Vocab size: {model_config.vocab_size}") # Create a dummy batch instead of loading from a dataset batch = ( @@ -165,7 +157,7 @@ def loss_fn(pred, labels): device="cuda", ), ) - fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0]) + fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0]) fsdp_memtracker.track_inputs(batch) with fsdp_memtracker: @@ -173,16 +165,15 @@ def loss_fn(pred, labels): input_ids, labels = batch # train step with train_context(): - pred = whole_model(input_ids) + pred = model(input_ids) loss = loss_fn(pred, labels) del pred loss.backward() # clip gradients - for model in model_parts: - torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True - ) + torch.nn.utils.clip_grad_norm_( + model.parameters(), job_config.training.max_norm, foreach=True + ) # sync float8 amaxes and scales float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index dc06d572..b75cb336 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -8,11 +8,9 @@ from torchtitan.parallelisms.parallel_dims import ParallelDims from torchtitan.parallelisms.parallelize_llama import parallelize_llama from torchtitan.parallelisms.pipeline_llama import pipeline_llama -from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule __all__ = [ - "build_pipeline_schedule", "models_parallelize_fns", "models_pipelining_fns", "ParallelDims", diff --git 
a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py index fa093b6e..67983270 100644 --- a/torchtitan/parallelisms/pipeline_llama.py +++ b/torchtitan/parallelisms/pipeline_llama.py @@ -7,7 +7,7 @@ # This file applies the PT-D pipeline parallelism to the Llama model. import copy -from typing import Union +from typing import Callable, Union import torch import torch.nn as nn @@ -18,7 +18,10 @@ from torchtitan.logging import logger from torchtitan.models.llama.model import ModelArgs from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank +from torchtitan.parallelisms.pipelining_utils import ( + build_pipeline_schedule, + stage_ids_this_rank, +) DeviceType = Union[int, str, torch.device] @@ -31,6 +34,7 @@ def pipeline_llama( job_config: JobConfig, device: DeviceType, model_config: ModelArgs, + loss_fn: Callable[..., torch.Tensor], ): split_mode = job_config.experimental.pipeline_parallel_split_mode valid_split_modes = ("manual", "tracer") @@ -39,14 +43,18 @@ def pipeline_llama( f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}" ) if split_mode == "manual": - return pipeline_llama_manual( + stages, models = pipeline_llama_manual( model, pp_mesh, parallel_dims, job_config, device, model_config ) elif split_mode == "tracer": - return pipeline_llama_tracer( + stages, models = pipeline_llama_tracer( model, pp_mesh, parallel_dims, job_config, device, model_config ) + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + return pp_schedule, models + def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"): """Get meta tensors with the right input shapes used for tracing""" @@ -218,4 +226,4 @@ def pipeline_llama_tracer( group=pp_mesh.get_group(), ) ) - return (stages, models) + return stages, models diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index aafe70fa..a5c61e62 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -14,7 +14,7 @@ from torchtitan.logging import logger -def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): +def build_pipeline_schedule(job_config, stages, loss_fn): looped_schedule = False if job_config.experimental.pipeline_parallel_schedule == "1f1b": diff --git a/train.py b/train.py index afb74408..390263bb 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,6 @@ from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import ( - build_pipeline_schedule, models_parallelize_fns, models_pipelining_fns, ParallelDims, @@ -143,11 +142,8 @@ def loss_fn(pred, labels): # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel - stages, model_parts = models_pipelining_fns[model_name]( - model, pp_mesh, parallel_dims, job_config, device, model_config - ) - pp_schedule = build_pipeline_schedule( - job_config, parallel_dims, stages, loss_fn + pp_schedule, model_parts = models_pipelining_fns[model_name]( + model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) # For PP with looped schedules, each item in model_parts is one stage-model-chunk. 
From ff53569d7ece82c9c9f8428e945a40ca68faa03c Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Thu, 8 Aug 2024 15:09:43 -0700
Subject: [PATCH 3/3] Update on "[BE][5/n] simplify pp vs. non-pp set up"

This PR refactors the PP vs. non-PP setup in `train.py`:
- moves `build_pipeline_schedule` into `pipeline_llama`, which reduces the PP interface exposed in `train.py`
- refactors the setup flow so that there are only two main if-else branches for PP vs. non-PP: one in the setup phase and one in the training phase
- I think the result is already clear to read or copy-paste, so it's not necessary to create separate sub-functions to hold the code

This PR also removes the unnecessary module returns in `parallelize_llama`, since we are modifying the module in place. Note that torch.compile and AC require returning and reassigning the module. But since we are doing per-block compile and AC, we achieve that in place for the whole model via
```
transformer_block = compile/AC(transformer_block)
model.layers.register_module(layer_id, transformer_block)
```

[ghstack-poisoned]
---
 train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index 390263bb..8b0941ee 100644
--- a/train.py
+++ b/train.py
@@ -133,7 +133,7 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
 
-    # loss function to be shared by Pipeline Parallel and spmd training
+    # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
             pred.flatten(0, 1), labels.flatten(0, 1)
@@ -150,7 +150,7 @@ def loss_fn(pred, labels):
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
         # optimizer, and checkpointing
         for m in model_parts:
-            # apply spmd-style PT-D techniques
+            # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
 
         # In PP, we cannot call init_weights directly because some layers are missing.
@@ -269,7 +269,7 @@ def loss_fn(pred, labels):
         optimizers.zero_grad()
 
         if parallel_dims.pp_enabled:
-            # pipeline parallel forward / backward inside step() call
+            # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
             with train_context():