[BE][5/n] simplify pp vs. non-pp set up #510
Changes from all commits
@@ -51,7 +51,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        model = apply_tp(
+        apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@ def parallelize_llama(
         )

     if job_config.activation_checkpoint.mode != "none":
-        model = apply_ac(model, job_config.activation_checkpoint)
+        apply_ac(model, job_config.activation_checkpoint)

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
@@ -69,14 +69,14 @@ def parallelize_llama(
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        model = apply_compile(model)
+        apply_compile(model)

     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
             assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

-            model = apply_fsdp(
+            apply_fsdp(
                 model,
                 dp_mesh,
                 param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
Review thread on the apply_fsdp( call:

Reviewer: Does this change break anything? IIRC one feature required us to keep returning the model. Maybe it was AC?

Author: Great point. Yes, torch.compile and AC require reassigning the model. But since we are doing per-block compile and AC, we achieve that in-place for the whole model by […]

Reviewer: Any motivation for removing the assignment? I thought an explicit assignment does not look bad.

Author: @wanchaol […]
@@ -88,15 +88,13 @@ def parallelize_llama(
         else:
             if world_mesh.ndim > 1:
                 raise RuntimeError("DDP has not supported > 1D parallelism")
-            model = apply_ddp(
+            apply_ddp(
                 model,
                 world_mesh,
                 enable_compile=job_config.training.compile,
                 enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
             )

-    return model
-

 def apply_tp(
     model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Parallelize the final linear output layer
-    model = parallelize_module(
+    parallelize_module(
         model,
         tp_mesh,
         {
@@ -192,7 +190,6 @@ def apply_tp(
         f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
-    return model


 # for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
         model.layers.register_module(layer_id, transformer_block)

     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-    return model


 def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")
-    return model


 def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
             module._load_state_dict_pre_hooks.clear()
             assert len(module._state_dict_pre_hooks) <= 1
             module._state_dict_pre_hooks.clear()

     logger.info("Applied FSDP to the model")
-    return model


 def apply_ddp(
@@ -347,7 +342,6 @@ def apply_ddp(
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

     logger.info("Applied DDP to the model")
-    return model
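Dropping the model = ... assignments also relies on the underlying PyTorch APIs mutating the module they are given: parallelize_module registers the sharded DTensor parameters on the module passed in and returns that same object, and the composable replicate / fully_shard APIs likewise attach their state in place. Below is a minimal single-process sketch of the parallelize_module behavior; the toy nn.Sequential, the world size of 1, and the CPU/gloo setup are illustrative assumptions, not part of this PR.

import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# single-process gloo process group so the sketch runs with a plain `python` call
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

tp_mesh = init_device_mesh("cpu", (1,))   # degenerate 1-way TP mesh
model = nn.Sequential(nn.Linear(16, 16))  # stand-in for a transformer block

ret = parallelize_module(model, tp_mesh, {"0": ColwiseParallel()})
print("parallelize_module returned the same object:", ret is model)  # True

dist.destroy_process_group()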
Review comments:

Reviewer: IIUC this set of changes is just to clean up estimation to be non-PP compatible, and estimation was copy-pasted from train.py, which is why it had model_chunks in the first place. This makes sense to me.

Author: Yes, you are right. In general, estimation.py only supports eager FSDP and doesn't support TP/PP/compile. So let's keep this file simple.