feat(): apply DDP+TP manually
yzs981130 committed Sep 13, 2024
1 parent d2a4904 commit 84db7f3
Showing 2 changed files with 66 additions and 27 deletions.
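Note: the new code below slices world_mesh["tp"] and world_mesh["dp"], so it assumes the trainer has already built a 2-D device mesh with those dimension names. A minimal sketch of such a setup (the dp_degree/tp_degree names and the direct init_device_mesh call are illustrative, not part of this commit):

# illustrative sketch, not part of this commit
from torch.distributed.device_mesh import init_device_mesh

dp_degree, tp_degree = 4, 2  # matches debug_model.toml below; 8 ranks total
world_mesh = init_device_mesh(
    "cuda", (dp_degree, tp_degree), mesh_dim_names=("dp", "tp")
)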
88 changes: 64 additions & 24 deletions train.py
@@ -141,32 +141,71 @@ def loss_fn(pred, labels):
)

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
    # apply PT-D Pipeline Parallel
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed._tensor import Replicate, Shard
tp_mesh = world_mesh["tp"]
loss_parallel = parallel_dims.loss_parallel_enabled
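# Parallelize the root modules over the TP mesh: shard the token embedding
# row-wise, run the final norm sequence-parallel, and shard the output
# projection column-wise (kept as a DTensor when loss parallelism is enabled).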

parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    },
)
rowwise_parallel, colwise_parallel, prepare_module_input = (
    RowwiseParallel,
    ColwiseParallel,
    PrepareModuleInput,
)
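# Apply tensor + sequence parallelism to every transformer block: the norms run
# sequence-parallel, the attention/FFN projections are column-/row-sharded, and
# PrepareModuleInput redistributes the Shard(1) activations to Replicate before
# each block's attention and feed-forward sub-modules.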
for layer_id, transformer_block in model.layers.items():
    layer_plan = {
        "attention_norm": SequenceParallel(),
        "attention": prepare_module_input(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": colwise_parallel(),
        "attention.wk": colwise_parallel(),
        "attention.wv": colwise_parallel(),
        "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": prepare_module_input(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": colwise_parallel(),
        "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
        "feed_forward.w3": colwise_parallel(),
    }

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_plan,
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for m in model_parts:
        # apply SPMD-style PT-D techniques
        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
        m.to_empty(device="cuda")
        m.init_weights()
        m.train()
else:
    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

    # move sharded model to CPU/GPU and initialize weights via DTensor
    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
    model.to_empty(device=init_device)
    model.init_weights()
    model.train()

    model_parts = [model]
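# DDP applied manually via the composable replicate() API: parameters stay
# TP-sharded within each "tp" group and gradients are synchronized across the
# "dp" mesh dimension; the meta-device module is then materialized on GPU and
# its weights are (re)initialized.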
from torch.distributed._composable import replicate
replicate(model, device_mesh=world_mesh["dp"])
model.to_empty(device="cuda")
model.init_weights()
model.train()
model_parts = [model]

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -294,6 +333,7 @@ def loss_fn(pred, labels):
# need to free pred before bwd to avoid peaking memory
del pred
loss.backward()
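# debug: inspect the gradient of the first transformer block's wq projection after backward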
print(f"{model.layers[list(model.layers.keys())[0]].attention.wq.weight.grad=}")

# clip gradients
for m in model_parts:
5 changes: 2 additions & 3 deletions train_configs/debug_model.toml
@@ -35,9 +35,8 @@ seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
data_parallel_degree = 4
tensor_parallel_degree = 2
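# 4 (data parallel) x 2 (tensor parallel) => this debug config expects 8 ranks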
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

