
[Schedule] Support sequence parallelism #6

Merged: 14 commits into awslabs:main on Jan 23, 2023
Conversation

@comaniac (Contributor) commented on Jan 19, 2023

Description

  • Support sequence parallelism. Specifically, now we can schedule as follows:
# The output of out_proj becomes partial after sharding its weight along axis=1.
sch["attention.out_proj"].shard("weight", axis=1)
# Sync the output in the forward pass, but defer the gather (at output axis=1)
# until resid_dropout. In other words, the schedule becomes:
# linear -> reduce_scatter -> dropout -> all_gather.
sch["attention.out_proj"].sync(
    mode="forward_defer_gather",
    gather_at=(sch["attention.resid_dropout"], 1)
)

Note that dist.reduce_scatter always scatters along the first dimension, so we implicitly transpose input and output tensors when needed.
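
A minimal sketch of this transpose trick, for illustration only (the helper name and the use of dist.reduce_scatter_tensor, the successor of dist._reduce_scatter_base, are my assumptions, not the PR's code); it assumes the target axis is divisible by the world size:

import torch
import torch.distributed as dist

def reduce_scatter_along_axis(tensor, axis, group=None):
    # Move the target axis to dim 0, since the collective scatters along dim 0.
    world_size = dist.get_world_size(group)
    if axis != 0:
        tensor = tensor.transpose(0, axis).contiguous()
    out_shape = list(tensor.shape)
    out_shape[0] //= world_size  # each rank keeps 1/world_size of dim 0
    output = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
    dist.reduce_scatter_tensor(output, tensor, group=group)
    if axis != 0:
        output = output.transpose(0, axis).contiguous()
    return output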


UPDATE
Per offline discussion, the programming model has been changed to:

sch["attention.out_proj"].sync(mode="fwd_post", sync_op_or_fn="reduce_scatter", axis=1)
sch["attention.resid_dropout"].sync(mode="fwd_post", sync_op_or_fn="all_gather", axis=1)

Other 1: For readability and flexibility, .sync now requires users to always specify the sync op.
Other 2: The .hook primitive is integrated into .sync, which now also accepts a custom hook function.
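
A hypothetical usage of the custom-hook form, based on the description above (the exact signature expected by sync_op_or_fn is not spelled out in this PR; this sketch assumes the function receives and returns the module output):

import torch.distributed as dist

def my_fwd_post_fn(output):
    # Custom sync logic instead of a built-in op, e.g. an in-place all-reduce.
    dist.all_reduce(output)
    return output

sch["attention.out_proj"].sync(mode="fwd_post", sync_op_or_fn=my_fwd_post_fn)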


  • Accordingly, a unit test was added to verify correctness of both the forward and backward passes.

  • Use dist.all_gather_into_tensor when available. This API replaces dist._all_gather_base, which will be deprecated soon. Note that dist.all_gather_into_tensor always concatenates along the first dimension, so we implicitly transpose the input and output tensors when needed (see the gather sketch after this list). Even with the transpose overhead, it still appears faster than dist.all_gather + torch.cat in my micro-benchmarks.

  • Move the sharding logic into its own module and establish a registration mechanism for shardable modules.

  • [Misc] Refactor task_lint.sh so that we can run it locally to verify linting without having to install transformers every time.

  • [Misc] Add conftest.py to enforce the test order; otherwise the distributed tests may get stuck when the distributed devices are not running the same test at the same time (see the conftest sketch after this list).

  • [Misc] Remove DeepSpeed from the Docker image for now, so that we can make the image public for CI.
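
The gather sketch referenced above, for illustration only (the helper name and the fallback logic are my assumptions, not the PR's code); it moves the target axis to dim 0 because dist.all_gather_into_tensor concatenates along dim 0:

import torch
import torch.distributed as dist

def all_gather_along_axis(tensor, axis, group=None):
    world_size = dist.get_world_size(group)
    if axis != 0:
        tensor = tensor.transpose(0, axis).contiguous()
    if hasattr(dist, "all_gather_into_tensor"):
        out_shape = list(tensor.shape)
        out_shape[0] *= world_size  # gathered shards are concatenated on dim 0
        output = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
        dist.all_gather_into_tensor(output, tensor, group=group)
    else:
        # Fallback path: gather a list of shards and concatenate them.
        shards = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(shards, tensor, group=group)
        output = torch.cat(shards, dim=0)
    if axis != 0:
        output = output.transpose(0, axis).contiguous()
    return output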
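
The conftest sketch referenced above, for illustration only (not necessarily the PR's conftest.py): one way to enforce a deterministic test order in pytest is to sort the collected items so that every rank runs the tests in the same sequence.

# conftest.py
def pytest_collection_modifyitems(session, config, items):
    # Sort tests by node id so all distributed ranks collect them in the same
    # order; mismatched orders can leave ranks blocked on different collectives.
    items.sort(key=lambda item: item.nodeid)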

cc @szhengac @chhzh123

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@comaniac changed the title from "Seq para" to "[Schedule] Support sequence parallelism" on Jan 19, 2023
Review threads (all resolved): examples/opt/schedule.py, examples/gpt/schedule.py, slapo/schedule.py, slapo/sharding/utils.py, conftest.py
@comaniac merged commit 54af148 into awslabs:main on Jan 23, 2023
@comaniac (Contributor, Author) commented:

Thanks @szhengac

@comaniac deleted the seq_para branch on January 23, 2023, 18:09