dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX #471

Merged · 3 commits · Jul 21, 2024
16 changes: 15 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -413,13 +413,27 @@ def apply_tp(
parallelize_plan=layer_plan,
)

# updates expressly for async tensor parallel
if job_config.experimental.enable_async_tensor_parallel:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._dynamo.config.cache_size_limit = 10000
Contributor

I wonder if we should unconditionally set this since we are doing per-transformer block compile. @Chillee do you know if it's the right thing to do?

Contributor Author

Good point, and I'm happy to refine this further; at the moment there is no guarantee the setting above is optimal, but something like it is required to avoid the silent failure.
I'm going to go ahead and merge as-is, since it's currently needed to keep users from silently failing when they enable async_tp, and I'm happy to update with more refined limit(s) based on @Chillee's feedback.

Contributor

IMO the right thing to do here is to avoid recompilation for per-TransformerBlock compile, which requires fixing pytorch/pytorch#125836 and enabling inline_built_nn_modules by default

logger.info(
"Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP"
)

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info("Applied Tensor Parallelism to the model")
if not job_config.training.compile:
logger.warning(
"Async TP requires compilation...auto enabling compile = True for this job to resolve."
)
job_config.training.compile = True

logger.info(
f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
)
return model
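For context on the review thread above: the model is compiled per TransformerBlock rather than as a whole, so if Dynamo recompiles the same block code once per layer instance, the default cache_size_limit can be exhausted and compilation silently falls back to eager, which is the failure mode the bump to 10000 guards against until pytorch/pytorch#125836 lands. Below is a minimal sketch of that per-block compile pattern; the helper name `compile_per_block` and the `model.layers` attribute are illustrative assumptions, not code from this PR.

```python
# Sketch only (not the torchtitan code): per-TransformerBlock compilation as
# discussed in the thread above. `compile_per_block` and `model.layers` are
# illustrative names.
import torch
import torch.nn as nn


def compile_per_block(model: nn.Module) -> nn.Module:
    for layer_id, block in model.layers.named_children():
        # Each block is its own compiled region. If Dynamo recompiles per block
        # instance, hitting cache_size_limit makes it silently fall back to
        # eager, which is what the cache_size_limit bump in this diff avoids.
        model.layers.register_module(layer_id, torch.compile(block))
    return model
```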


1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -43,6 +43,7 @@ dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = false
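As a usage note (not part of this diff): with the parallelize_llama.py change above, enabling async TP from a train config is a single flag flip, and compile is force-enabled if it was left off. A sketch of the relevant sections, reusing the option names shown in this PR:

```toml
[training]
compile = true  # optional: the PR auto-enables this when async TP is requested

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = true  # requires torch.compile; see parallelize_llama.py above
```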