
Enable CP #433

Closed · wants to merge 12 commits

Conversation

@fegin fegin commented Jun 27, 2024

Stack from ghstack (oldest at bottom):

This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jun 27, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: d57fcdae2fdc2481722471d8d4efbb4f416fe396
Pull Request resolved: #433
@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Meta Open Source bot) Jun 27, 2024
@fegin fegin changed the title from "This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested." to "Enable CP" Jun 27, 2024
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jun 27, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: dce65fea12a547741209b602245f644562068e98
Pull Request resolved: #433
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jun 27, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: a10894ac7c9835dd47e28afe2ff1ec2924ee51c8
Pull Request resolved: #433
reshard_after_forward = (
    int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
fully_shard(
Contributor:

Is this PR still a WIP? It seems like this just does the exact same thing as applying FSDP.

fegin (Author):

This PR is already working. CP requires FSDP (or DDP) to do parameter reduction for all parameters. The other CP-related calls are enable_context_parallel() and context_parallel_buffers.

Contributor:

Ohhh yes, that makes sense.
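
For context, here is a minimal sketch (not code from this PR) of the reduction that FSDP/DDP supplies implicitly and that CP alone does not: every CP rank holds a full replica of the parameters but computes gradients from a different sequence shard, so parameter gradients must be averaged across the CP group. The helper name and the explicit loop are illustrative only.

import torch
import torch.distributed as dist


def reduce_grads_across_cp(model: torch.nn.Module, cp_group: dist.ProcessGroup) -> None:
    # This is the bookkeeping that DDP/FSDP performs as part of its own gradient
    # reduction; a CP-only setup would need a loop like this after loss.backward().
    cp_size = dist.get_world_size(cp_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
            param.grad.div_(cp_size)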

Contributor:

Could the application of fully_shard to the model be factored out so that if we change the "wrapping" we do not need to make changes in two places?
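
To make the suggestion concrete, a sketch of what that factoring could look like, assuming model.layers is a ModuleDict keyed by string ids as the quoted snippet implies; apply_fsdp and its arguments are hypothetical names, not this PR's API:

from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp(model, fsdp_config: dict, pp_enabled: bool) -> None:
    # Wrap each transformer block, then the root module, in one place.
    for layer_id, block in model.layers.items():
        # Mirror the quoted heuristic: reshard all but the last block after
        # forward, and do not reshard after forward at all when PP is enabled.
        reshard_after_forward = int(layer_id) < len(model.layers) - 1 and not pp_enabled
        fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

Both the plain-FSDP and the FSDP + CP code paths could then call this helper with a different mesh inside fsdp_config, so a change to the wrapping happens in exactly one place.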

@fegin fegin requested a review from tianyu-l July 8, 2024 18:12
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 8, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 7eb849b117af07987669ac1c97deedad2e1647d6
Pull Request resolved: #433
@fegin fegin requested review from wz337 and awgu July 8, 2024 22:45
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
if parallel_dims.cp_enabled:
    # Manually create another device mesh for now as we don't support
    # submesh flattening/reshape yet.
fegin (Author):

cc @wz337

train.py Outdated
del pred
loss.backward()
with context_parallel_ctx(
    buffers=[input_ids, labels, model.freqs_cis],
fegin (Author):

Context parallelism shards the data on the sequence dimension. Other buffers that are sharded along the sequence dimension also need to be adjusted accordingly. In this case, freqs_cis has to be sharded along the sequence dimension.
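
As a concrete illustration (not this PR's API, and real CP implementations may use a load-balanced rather than a contiguous split), sharding a buffer on the sequence dimension just means each CP rank keeps its own slice of that dimension, and freqs_cis must be sliced the same way as the tokens:

import torch


def shard_on_seq_dim(buf: torch.Tensor, seq_dim: int, cp_rank: int, cp_degree: int) -> torch.Tensor:
    # Return this CP rank's contiguous slice of `buf` along the sequence dimension.
    return buf.chunk(cp_degree, dim=seq_dim)[cp_rank]


# With seq_len=8 and cp_degree=2, rank 0 owns positions 0..3 and rank 1 owns 4..7;
# slicing freqs_cis identically keeps rotary embeddings aligned with the owned tokens.
freqs_cis = torch.randn(8, 64)  # (seq_len, head_dim); stand-in values
local_freqs = shard_on_seq_dim(freqs_cis, seq_dim=0, cp_rank=1, cp_degree=2)
assert local_freqs.shape == (4, 64)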

@fegin fegin requested a review from wconstab July 8, 2024 22:48
-    dp * tp * pp == self.world_size
-), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+assert dp * cp * tp * pp == self.world_size, (
+    f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
Contributor:

nit: for consistency

Suggested change:
-    f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
+    f"Invalid parallel dims: dp({dp}) * cp({cp}) * tp({tp}) * pp({pp}) "

Comment on lines 490 to 492
model = fully_shard(
    model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)
Contributor:

nit (not from this PR): We should probably change this in the original place too:

model = fully_shard(
    model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)

We do not need the model = fully_shard(model, ...) and can just call fully_shard(model, ...).

train.py Outdated
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()
with context_parallel_ctx(
Contributor:

nit for discussion: These nested context managers lead to a lot of indentation and then, combined with the formatter, lead to a lot of verticality.

I wonder if we should fold context_parallel_ctx into train_context since it should not hurt to make train_context have a larger scope?
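
A sketch of the folding being suggested, assuming train_context's existing job is to optionally enable loss parallel (as in torchtitan); the ExitStack composition and the cp_context argument are illustrative, not this PR's code:

import contextlib

from torch.distributed.tensor.parallel import loss_parallel


@contextlib.contextmanager
def train_context(enable_loss_parallel: bool, cp_context=None):
    # One context manager per training step, so the loop body stays flat.
    with contextlib.ExitStack() as stack:
        if enable_loss_parallel:
            stack.enter_context(loss_parallel())
        if cp_context is not None:
            # e.g. context_parallel_ctx(buffers=[input_ids, labels, model.freqs_cis], ...)
            stack.enter_context(cp_context)
        yield

The training loop then needs only a single "with train_context(...):" per step instead of nesting a second with-block for CP.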

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 9, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: d3d3fbaf1ee11d1bbc059b08548a0c552e64b485
Pull Request resolved: #433
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 9, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 7ccd54fd5cdc306d861a058bdcb787f9e2a7df42
Pull Request resolved: #433
    # submesh flattening/reshape yet.
    dp_mesh = init_device_mesh(
        world_mesh.device_type,
        (parallel_dims.dp * parallel_dims.cp,),
Contributor:

I did not follow this exactly. What happens when we call init_device_mesh with a mesh shape that does not cover the global world?

For example, suppose we are composing FSDP + TP + CP. How does this init_device_mesh call know to combine the ranks from the CP mesh and FSDP mesh while accounting for the existence of a TP mesh?
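
To make the concern concrete with assumed degrees (dp=2, cp=2, tp=2 on 8 ranks): init_device_mesh lays ranks out row-major over the global world, so the dp x cp group that shares a TP coordinate is strided rather than contiguous, and a standalone (dp * cp,) mesh would have to land on exactly those ranks to be correct. The arithmetic below only illustrates which ranks those are; it makes no claim about what this particular init_device_mesh call does.

# Row-major (dp, cp, tp) layout: rank = dp_idx * (cp * tp) + cp_idx * tp + tp_idx.
dp, cp, tp = 2, 2, 2  # assumed degrees, 8 ranks total

ranks_sharing_tp0 = sorted(
    dp_idx * (cp * tp) + cp_idx * tp + 0
    for dp_idx in range(dp)
    for cp_idx in range(cp)
)
print(ranks_sharing_tp0)  # [0, 2, 4, 6] -- strided, not the contiguous [0, 1, 2, 3]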

[ghstack-poisoned]
@fegin fegin mentioned this pull request Jul 16, 2024
fegin added a commit that referenced this pull request Jul 31, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 769512b50f250e58cf90771349ef1c8a71d47804
Pull Request resolved: #433
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

[ghstack-poisoned]
fegin added a commit that referenced this pull request Aug 7, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 5d4f276bcff9ff53c2caadd161161a9b7a33142a
Pull Request resolved: #433
@fegin fegin marked this pull request as draft August 7, 2024 06:45
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 20b884454fd5c989ac270fc925408396fad8bc52
Pull Request resolved: #433
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 5d4f276bcff9ff53c2caadd161161a9b7a33142a
Pull Request resolved: #433
[ghstack-poisoned]
fegin added a commit that referenced this pull request Aug 16, 2024
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP only. CP + TP is being tested.

ghstack-source-id: 923da26044006e9711dec1f1e106cb76fb501497
Pull Request resolved: #433
fegin commented Oct 1, 2024

Close in favor of #592

@fegin fegin closed this Oct 1, 2024