
[Op] Refactor qkv processing #46

Merged · 4 commits merged into awslabs:main on Feb 9, 2023

Conversation

@comaniac (Contributor) commented on Feb 8, 2023

Description

Pointed out by @szhengac: the current logic, which uses .chunk(3, dim=-1) to split qkv, assumes different data layouts for the TP=1 and TP>1 cases. Specifically, when TP=1, we assume the qkv weight is contiguous, meaning the layout is [q0 q1 ..., k0 k1 ..., v0 v1 ...]. However, when TP>1, the weight is sharded along axis=0, so each partition has shape [3 * H // TP], which assumes the qkv layout is interleaved (i.e., [q0 k0 v0, ...]).
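
A minimal sketch of the two layouts may make the mismatch concrete (variable names and shapes below are illustrative assumptions, not code from this repository):

```python
import torch

num_heads, head_dim = 4, 8
H = num_heads * head_dim            # hidden size of a single q/k/v projection
qkv = torch.randn(3 * H)            # one fused qkv row; its layout matters

# Contiguous layout: [q (H) | k (H) | v (H)] -- chunk(3, dim=-1) is the right split.
q_c, k_c, v_c = qkv.chunk(3, dim=-1)

# Interleaved layout: per head, [q_h | k_h | v_h], i.e. [q0 k0 v0 | q1 k1 v1 | ...].
# The correct split groups by head first, then by projection.
q_i, k_i, v_i = (t.reshape(H) for t in qkv.view(num_heads, 3, head_dim).unbind(dim=1))

# Applying chunk(3) to an interleaved tensor (or the head-wise split to a
# contiguous one) silently mixes q, k, and v entries.
```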

This is not a problem as long as the model is always run in the same configuration, but it produces incorrect results if, for example, we trained the model with TP=2 and now want to fine-tune it with TP=1. Transposing the trained weights would also resolve the mismatch, but that is not straightforward for users.

This PR fixes this issue by assuming the qkv weights are always interleaved. This is also the methodology used in Megatron-LM. Accordingly, we need to manually transpose the weights in the unit test to match the GPT-2 attention results.
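
For reference, the kind of layout conversion needed in the unit test could look roughly like the following (a hedged sketch; the function name and the [3 * H, in_features] weight orientation are assumptions, not the actual test code):

```python
import torch

def contiguous_to_interleaved(w: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Reorder a fused qkv weight from the contiguous layout [q | k | v] to the
    per-head interleaved layout [q0 k0 v0 | q1 k1 v1 | ...].

    Assumes w has shape [3 * H, in_features] with q, k, v stacked along dim 0.
    """
    three_h, in_features = w.shape
    head_dim = three_h // (3 * num_heads)
    return (
        w.view(3, num_heads, head_dim, in_features)  # [proj, head, dim, in]
         .transpose(0, 1)                            # [head, proj, dim, in]
         .reshape(three_h, in_features)              # flatten back, now interleaved
    )
```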

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@szhengac szhengac merged commit 9d6aed3 into awslabs:main Feb 9, 2023
@comaniac deleted the fix_qkv branch on February 9, 2023 at 21:58