enable Context Parallel #592
@@ -306,6 +306,41 @@ def build_test_list():
             "hsdp+tp",
             ngpu=8,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "FSDP+CP",
+            "fsdp+cp",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.data_parallel_replicate_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "HSDP+CP",
+            "hsdp+cp",
+            ngpu=8,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.tensor_parallel_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "FSDP+TP+CP",
+            "fsdp+tp+cp",
+            ngpu=8,
+        ),
Contributor:
Question: Looks like FSDP/HSDP + TP + CP is working. How about PP? We can also mention progress in the .md doc later.

Contributor (Author):
Right, the next step is to test 4D/5D (w/ PP and HSDP).
         OverrideDefinitions(
             [
                 [
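The `ngpu` values in the new test entries follow from multiplying the parallelism degrees together (unspecified degrees default to 1). The check below spells that accounting out; it is an illustrative assumption about how the degrees compose, not code from this PR.

```python
# Degrees taken from the three new OverrideDefinitions entries above.
configs = {
    "fsdp+cp": dict(dp_shard=2, dp_replicate=1, cp=2, tp=1, ngpu=4),
    "hsdp+cp": dict(dp_shard=2, dp_replicate=2, cp=2, tp=1, ngpu=8),
    "fsdp+tp+cp": dict(dp_shard=2, dp_replicate=1, cp=2, tp=2, ngpu=8),
}
for name, cfg in configs.items():
    # World size is the product of the data, context, and tensor parallel degrees.
    world_size = cfg["dp_shard"] * cfg["dp_replicate"] * cfg["cp"] * cfg["tp"]
    assert world_size == cfg["ngpu"], name
```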
@@ -415,8 +415,9 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
         return precompute_freqs_cis(
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
-            # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            # TODO: explain in docs/composability.md why we removed the 2x
+            # relaxing in our CP enablement PR
+            self.model_args.max_seq_len,
Contributor:
cc @tianyu-l Want to understand: is this okay? For a general use case, we can also expand CP to support

Contributor:
Could you please elaborate a bit on why this change was needed by CP?

Contributor:
@tianyu-l CP parallelizes over the sequence dimension, so anything related to the sequence dimension needs to be sharded. freqs_cis is the positional embedding and must be sharded according to the sequence length, so it is easier to support CP if everything has the same sequence length.

Contributor:
Sounds reasonable to me. @awgu to confirm this is OK. Also we need to add a note in docs/composability.md to clarify why this (model change) is needed. It can be addressed in a separate PR; in that case please create an issue / leave a TODO.
             self.model_args.rope_theta,
         )
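To make the comment thread above concrete, here is a minimal illustration of why the per-position `freqs_cis` buffer needs the same sequence length as the token buffers once both are sharded along the sequence dimension. The shapes and the plain contiguous-chunk sharding are assumptions chosen for clarity; PyTorch's `context_parallel` uses its own sharding scheme.

```python
import torch

cp_degree = 2
seq_len = 8

tokens = torch.arange(seq_len)     # stand-in for a token buffer, sequence dim = 0
freqs_cis = torch.arange(seq_len)  # stand-in for the per-position RoPE table

# CP shards every sequence-dim buffer the same way; illustrated here with plain chunks.
token_shards = tokens.chunk(cp_degree, dim=0)
freqs_shards = freqs_cis.chunk(cp_degree, dim=0)

# Each rank's token positions line up with its freqs_cis positions only because the
# two buffers have the same sequence length.
for rank in range(cp_degree):
    assert torch.equal(token_shards[rank], freqs_shards[rank])

# With the old 2x padding, freqs_cis would have length 16; sharding it the same way
# would hand rank 1 the entries for positions 8..15 while its tokens sit at 4..7.
```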
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,51 @@ def parallelize_llama(
     )
     apply_compile(model)
 
-    if parallel_dims.dp_enabled:
-        if parallel_dims.dp_shard_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mesh = world_mesh["dp_replicate", "dp_shard"]
-            else:
-                dp_mesh = world_mesh["dp"]
-
-            apply_fsdp(
-                model,
-                dp_mesh,
-                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-                reduce_dtype=TORCH_DTYPE_MAP[
-                    job_config.training.mixed_precision_reduce
-                ],
-                tp_enabled=parallel_dims.tp_enabled,
-                pp_enabled=parallel_dims.pp_enabled,
-            )
-            if parallel_dims.dp_replicate_enabled:
-                logger.info("Applied HSDP to the model")
-            else:
-                logger.info("Applied FSDP to the model")
+    if (
+        parallel_dims.dp_shard_enabled
+    ):  # apply FSDP or HSDP, potentially with Context Parallel
+        # TODO: instead of flattening the mesh twice, we could've done it in a better way:
+        # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+        # However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
+        # a bug in DeviceMesh. We should fix it and then use the above line.
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        # note that mesh can only be flattened from the finest-grained mesh dimensions
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
Contributor:
We need to document this design in torchtitan, e.g. in the file

Contributor:
Another question: given the current implementation, does it mean DP and CP have to be adjacent to each other in the device mesh? E.g. it seems we can't do [TP, CP, PP, DP] (from inner to outer) as in the Llama 3.1 paper. Is that correct?

Contributor (Author):
Regarding the mesh order, the current implementation of mesh flattening requires that the flattened dimensions be contiguous, which means

Contributor (Author):
For the first question on
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )
+
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            tp_enabled=parallel_dims.tp_enabled,
+            pp_enabled=parallel_dims.pp_enabled,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
         else:
-            if world_mesh.ndim > 1:
-                raise RuntimeError("DDP has not supported > 1D parallelism")
-            apply_ddp(
-                model,
-                world_mesh,
-                enable_compile=job_config.training.compile,
-                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
-            )
             logger.info("Applied FSDP to the model")
+
+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        apply_ddp(
+            model,
+            world_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+        )
 
 
 def apply_tp(
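The flattening in this hunk is easier to see on a concrete mesh. The sketch below is illustrative rather than torchtitan code: the dim names, sizes, and 8-rank layout are assumptions, and it would need to run under `torchrun` with 8 ranks. It relies on the same private `DeviceMesh._flatten` API that the PR uses.

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks laid out, outermost to innermost, as 2-way FSDP x 2-way CP x 2-way TP.
world_mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),
    mesh_dim_names=("dp", "cp", "tp"),
)

# Flattening only combines adjacent (contiguous) mesh dims, which is why the
# data-parallel and CP dims sit next to each other: together they become the "dp_cp"
# dim that FSDP shards over, while SDPA runs context-parallel over the inner "cp" dim.
dp_cp_mesh = world_mesh[("dp", "cp")]._flatten("dp_cp")
assert dp_cp_mesh.size() == 4
```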
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import gc
 import os
 import subprocess
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Optional, Union
+from typing import Generator, List, Optional, Set, Union
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -101,6 +102,57 @@ def run(self, step_count):
 SKIP_CLEANUP = "3"
 
 
+def create_context_parallel_ctx(
+    cp_mesh: DeviceMesh,
+    cp_buffers: List[torch.Tensor],
+    cp_seq_dims: List[int],
+    cp_no_restore_buffers: Set[torch.Tensor],
+):
+    try:
+        from torch.distributed.tensor.experimental import context_parallel
+    except ImportError:
+        print(
+            f"PyTorch version {torch.__version__} does not include the experimental "
+            "Context Parallel API. Please update to a newer version."
+        )
+
+    return context_parallel(
+        cp_mesh,
+        buffers=cp_buffers,
+        buffer_seq_dims=cp_seq_dims,
+        no_restore_buffers=cp_no_restore_buffers,
+    )
+
+
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context(cp_context: Optional[Generator[None, None, None]] = None):
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            if cp_context is not None:
+                from torch.nn.attention import sdpa_kernel, SDPBackend
+
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
Just curious if you recall what the blocker for

Contributor (Author):
Hi @tmm1, it's simply that cuDNN attention has a different op signature. I'm adding support now and should be able to have the PR draft out by next week.

Contributor:
Hi @XilunWu, just curious. Is the cuDNN backend supported now?

Contributor (Author):
@xingchensong Yes!! It's merged in pytorch/pytorch#148537 cc @tmm1
+                stack.enter_context(
+                    sdpa_kernel(
+                        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                    )
+                )
+                stack.enter_context(cp_context)
+
+            yield
+
+    return context
+
+
 def init_distributed(job_config):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
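For context, a training step might wire the two helpers above together roughly as follows. This is a hedged sketch mirroring the signatures in the diff: the buffer names, the seq-dim values, the "cp" mesh dim name, and the surrounding training-loop objects are assumptions, not code from this PR.

```python
def train_step(model, input_ids, labels, loss_fn, world_mesh, parallel_dims, train_context):
    # Every buffer with a sequence dimension (inputs, labels, RoPE freqs) is handed to
    # the CP context so it gets sharded consistently across the CP ranks.
    optional_cp_ctx = (
        create_context_parallel_ctx(
            cp_mesh=world_mesh["cp"],
            cp_buffers=[input_ids, labels, model.freqs_cis],
            cp_seq_dims=[1, 1, 0],
            cp_no_restore_buffers={input_ids, labels},
        )
        if parallel_dims.cp_enabled
        else None
    )
    with train_context(optional_cp_ctx):
        pred = model(input_ids)
        loss = loss_fn(pred, labels)
        loss.backward()
    return loss


# where train_context comes from the helper above, e.g.
# train_context = get_train_context(enable_loss_parallel=False, enable_compiled_autograd=False)
```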