Enable CP #433
Conversation
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.
ghstack-source-id: d57fcdae2fdc2481722471d8d4efbb4f416fe396
Pull Request resolved: #433
reshard_after_forward = (
    int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
fully_shard(
Is this PR still a WIP? It seems like this just does the exact same thing as applying FSDP.
This PR is already working. CP requires FSDP (or DDP) to do parameter reduction for all parameters. The other CP-related calls are enable_context_parallel() and context_parallel_buffers.
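Below is a minimal single-process sketch of why that parameter reduction is needed; it illustrates the math only and does not use the PR's experimental API. With CP the weights stay replicated while activations are sharded along the sequence dimension, so each CP rank only produces the gradient contribution of its sequence shard, and those contributions must be summed, which is exactly the reduction FSDP/DDP already performs.

```python
import torch

torch.manual_seed(0)
w = torch.randn(8, 8, requires_grad=True)   # replicated "parameter"
x = torch.randn(4, 16, 8)                   # (batch, seq, dim) activations

# Gradient from the full sequence.
(x @ w).sum().backward()
full_grad = w.grad.clone()
w.grad = None

# Gradient contribution of each "CP rank", computed from its sequence shard.
shard_grads = []
for shard in x.chunk(2, dim=1):             # shard along the sequence dim
    (shard @ w).sum().backward()
    shard_grads.append(w.grad.clone())
    w.grad = None

# Summing (reducing) the per-shard gradients recovers the full gradient.
assert torch.allclose(full_grad, sum(shard_grads))
```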
Ohhh yes, that makes sense.
Could the application of fully_shard to the model be factored out, so that if we change the "wrapping" we do not need to make changes in two places?
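For illustration, one way that factoring could look. apply_fsdp and its arguments are hypothetical, not torchtitan's code, and the fully_shard import path may differ across PyTorch versions.

```python
from torch.distributed._composable.fsdp import fully_shard  # path may differ by version


def apply_fsdp(model, fsdp_config: dict, pp_enabled: bool):
    """Single place that owns the FSDP "wrapping" of the model."""
    for layer_id, transformer_block in model.layers.items():
        # Match the inline code above: the last layer keeps its parameters
        # unsharded after forward, and PP disables resharding entirely.
        reshard_after_forward = (
            int(layer_id) < len(model.layers) - 1 and not pp_enabled
        )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
    return model
```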
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
if parallel_dims.cp_enabled:
    # Manually create another device mesh for now as we don't support
    # submesh flattening/reshape yet.
cc @wz337
train.py
Outdated
del pred
loss.backward()
with context_parallel_ctx(
    buffers=[input_ids, labels, model.freqs_cis],
Context parallelism shards on the sequence dimension of the data. Other buffers that are sharded along the sequence dimension also need to be adjusted accordingly. In this case, freqs_cis has to be sharded along the sequence dimension.
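As a sketch of the idea (shapes and sequence dims below are illustrative assumptions, and the contiguous chunking stands in for whatever sharding scheme CP actually uses): every buffer laid out along the sequence dimension must be sliced consistently with the tokens, including freqs_cis.

```python
import torch

cp_degree, cp_rank = 2, 0
batch, seq_len, head_dim = 4, 16, 8

input_ids = torch.randint(0, 100, (batch, seq_len))   # sharded on dim 1
labels = torch.randint(0, 100, (batch, seq_len))      # sharded on dim 1
freqs_cis = torch.randn(seq_len, head_dim // 2)       # sharded on dim 0


def shard_along_seq(buf: torch.Tensor, seq_dim: int) -> torch.Tensor:
    # Keep only this CP rank's slice of the sequence dimension.
    return buf.chunk(cp_degree, dim=seq_dim)[cp_rank]


local_inputs = shard_along_seq(input_ids, seq_dim=1)
local_labels = shard_along_seq(labels, seq_dim=1)
local_freqs = shard_along_seq(freqs_cis, seq_dim=0)

# Positions and rotary frequencies stay aligned only if all three agree.
assert local_inputs.shape[1] == local_labels.shape[1] == local_freqs.shape[0]
```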
torchtitan/parallelisms/__init__.py
Outdated
    dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert dp * cp * tp * pp == self.world_size, (
    f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
nit: for consistency
f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
f"Invalid parallel dims: dp({dp}) * cp({cp}) * tp({tp}) * pp({pp}) "
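For context, a small standalone sketch of the invariant this assert enforces. The dataclass is illustrative, not torchtitan's ParallelDims, and treating dp == -1 as "infer from world size" is an assumption.

```python
from dataclasses import dataclass


@dataclass
class Dims:
    dp: int
    cp: int
    tp: int
    pp: int
    world_size: int

    def validate(self) -> None:
        if self.dp == -1:
            # Infer the data-parallel degree from the remaining dimensions.
            self.dp = self.world_size // (self.cp * self.tp * self.pp)
        assert self.dp * self.cp * self.tp * self.pp == self.world_size, (
            f"Invalid parallel dims: dp({self.dp}) * cp({self.cp}) * tp({self.tp}) "
            f"* pp({self.pp}) != WORLD_SIZE({self.world_size})"
        )


dims = Dims(dp=-1, cp=2, tp=2, pp=1, world_size=8)
dims.validate()
assert dims.dp == 2  # inferred so that dp * cp * tp * pp == 8
```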
reshard_after_forward = (
    int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
fully_shard(
Could the application of fully_shard to the model be factored out, so that if we change the "wrapping" we do not need to make changes in two places?
model = fully_shard(
    model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)
nit (not from this PR): We should probably change this in the original place too:
torchtitan/torchtitan/parallelisms/parallelize_llama.py
Lines 537 to 539 in 81012d1
model = fully_shard(
    model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)
We do not need the model = fully_shard(model, ...) and can just call fully_shard(model, ...).
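A tiny sketch of that point, assuming FSDP2's fully_shard (the import path varies across PyTorch versions, and this would be launched under torchrun with a process group already initialized): fully_shard installs its state on the module it receives and hands the same object back, so the re-assignment is redundant.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # path may differ by version

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
returned = fully_shard(model)   # wraps `model` in place
assert returned is model        # so a bare `fully_shard(model, ...)` suffices
```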
train.py
Outdated
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()
with context_parallel_ctx(
nit for discussion: These nested context managers lead to a lot of indentation and then, combined with the formatter, lead to a lot of verticality. I wonder if we should fold context_parallel_ctx into train_context, since it should not hurt to make train_context have a larger scope?
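One possible shape for that, sketched with contextlib. The names train_context and context_parallel_ctx follow the discussion above, but the exact torchtitan signatures are assumptions; an ExitStack-based train_context can absorb the optional CP context so the loop body keeps a single indentation level.

```python
import contextlib

from torch.distributed.tensor.parallel import loss_parallel


@contextlib.contextmanager
def train_context(enable_loss_parallel: bool, cp_context=None):
    """Single context for the training step; the CP context is optional."""
    with contextlib.ExitStack() as stack:
        if enable_loss_parallel:
            stack.enter_context(loss_parallel())
        if cp_context is not None:
            stack.enter_context(cp_context)  # e.g. a context_parallel_ctx(...) object
        yield


# Usage sketch:
# with train_context(parallel_dims.loss_parallel_enabled, cp_ctx):
#     pred = model(input_ids)
#     ...
```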
# submesh flattening/reshape yet.
dp_mesh = init_device_mesh(
    world_mesh.device_type,
    (parallel_dims.dp * parallel_dims.cp,),
I did not follow this exactly. What happens when we call init_device_mesh with a mesh shape that does not cover the global world? For example, suppose we are composing FSDP + TP + CP. How does this init_device_mesh call know to combine the ranks from the CP mesh and FSDP mesh while accounting for the existence of a TP mesh?
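To make the concern concrete, here is a single-process, pure-Python illustration (no real mesh is created, and the row-major ["dp", "cp", "tp"] layout is an assumption): the ranks that should form one flattened dp x cp group are strided by the TP degree, whereas init_device_mesh(device_type, (dp * cp,)) by itself only describes a shape, not which global ranks it should contain.

```python
import itertools

dp, cp, tp = 2, 2, 2  # 8-GPU world, FSDP + CP + TP


# Global rank of coordinate (d, c, t) under a row-major ["dp", "cp", "tp"] mesh.
def global_rank(d: int, c: int, t: int) -> int:
    return (d * cp + c) * tp + t


# For each TP index, the dp x cp group over which FSDP must reduce gradients.
for t in range(tp):
    group = [global_rank(d, c, t) for d, c in itertools.product(range(dp), range(cp))]
    print(f"tp={t}: flattened dp x cp ranks = {group}")
# tp=0: flattened dp x cp ranks = [0, 2, 4, 6]
# tp=1: flattened dp x cp ranks = [1, 3, 5, 7]
```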
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested. [ghstack-poisoned]
ad945df to bc51495
Closed in favor of #592.
Stack from ghstack (oldest at bottom):
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.