diff --git a/train.py b/train.py
index d1973b6d..8c421fe2 100644
--- a/train.py
+++ b/train.py
@@ -141,32 +141,71 @@ def loss_fn(pred, labels):
     )
 
     # apply parallelisms and initialization
-    if parallel_dims.pp_enabled:
-        # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
-            model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        parallelize_module,
+        PrepareModuleInput,
+        RowwiseParallel,
+        SequenceParallel,
+    )
+    from torch.distributed._tensor import Replicate, Shard
+    tp_mesh = world_mesh["tp"]
+    loss_parallel = parallel_dims.loss_parallel_enabled
+
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+    rowwise_parallel, colwise_parallel, prepare_module_input = (
+        RowwiseParallel,
+        ColwiseParallel,
+        PrepareModuleInput,
+    )
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": colwise_parallel(),
+            "attention.wk": colwise_parallel(),
+            "attention.wv": colwise_parallel(),
+            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
+            "ffn_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": colwise_parallel(),
+            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
         )
-
-        # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
-        # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
-        # optimizer, and checkpointing
-        for m in model_parts:
-            # apply SPMD-style PT-D techniques
-            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-            m.to_empty(device="cuda")
-            m.init_weights()
-            m.train()
-    else:
-        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
-
-        # move sharded model to CPU/GPU and initialize weights via DTensor
-        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
-        model.to_empty(device=init_device)
-        model.init_weights()
-        model.train()
-
-        model_parts = [model]
+
+    from torch.distributed._composable import replicate
+    replicate(model, device_mesh=world_mesh["dp"])
+    model.to_empty(device="cuda")
+    model.init_weights()
+    model.train()
+    model_parts = [model]
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -294,6 +333,7 @@ def loss_fn(pred, labels):
         # need to free to before bwd to avoid peaking memory
         del pred
         loss.backward()
+        print(f"{model.layers[list(model.layers.keys())[0]].attention.wq.weight.grad=}")
 
         # clip gradients
         for m in model_parts:
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index bb3cd353..db8fb38b 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -35,9 +35,9 @@ seq_len = 2048
 warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
-data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
-tensor_parallel_degree = 1
+data_parallel_replicate_degree = 4
+data_parallel_shard_degree = 1
+tensor_parallel_degree = 2
 compile = false
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
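Note: below is a minimal standalone sketch of the same TP + replicate composition this diff hand-writes into train.py, useful for sanity-checking the pattern outside of torchtitan. It assumes 8 GPUs launched via torchrun and a 4 (dp) x 2 (tp) device mesh matching the toml above; the file name sketch_tp_replicate.py, the ToyMLP module, and its w1/w2 submodule names are illustrative, not part of this change.

# sketch_tp_replicate.py -- hypothetical sanity check, not part of this diff.
# Run with: torchrun --nproc_per_node=8 sketch_tp_replicate.py
import os

import torch
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class ToyMLP(nn.Module):
    """Stand-in for one feed_forward block (illustrative only)."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2D mesh: 4-way data parallel (outer) x 2-way tensor parallel (inner),
# mirroring data_parallel_replicate_degree = 4 / tensor_parallel_degree = 2.
mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = ToyMLP().cuda()

# Shard w1 column-wise and w2 row-wise over the tp submesh, the same
# pattern the diff applies to feed_forward.w1 / feed_forward.w2.
parallelize_module(
    model,
    mesh_2d["tp"],
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)

# Then replicate the TP-sharded model over the dp submesh, as the diff
# does with replicate(model, device_mesh=world_mesh["dp"]).
replicate(model, device_mesh=mesh_2d["dp"])

x = torch.randn(8, 16, device="cuda")
model(x).sum().backward()

# As with the print added after loss.backward() in the diff, the gradient
# is a DTensor whose placement matches the parameter's sharding.
print(model.w1.weight.grad)

This sketch uses the default Colwise/Rowwise layouts, so it skips sequence parallelism and loss parallelism; the plan in train.py adds those via SequenceParallel, PrepareModuleInput, and the explicit Shard(1)/Shard(-1) input and output layouts.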