Refactor freqs_cis slice to be safer for PP #321
Conversation
Unchanged: we precompute freqs_cis for max_seqlen, >> seqlen for a given batch.

Changed: instead of slicing self.freqs_cis down to seqlen at the top-level transformer based on the input token shape, we slice it down to seqlen inside a transformer layer, after it has been re-expanded to the full seqlen in cases where TP has sharded across seqlen.

In the PP case, stage 1's input may be seqlen/TP instead of seqlen, but we do not generally know this, which makes it hard for stage 1 to slice freqs_cis correctly. It is easy to do the slicing deeper inside, since at that point we know the full seqlen unambiguously.

Note: the full self.freqs_cis is stored in memory either way, and the thing passed into every layer is just a view, so this change should not be material for memory usage or otherwise.

ghstack-source-id: 20ef05e0734e53260366878dfe0fac5e1ab48f1d
Pull Request resolved: #321
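To make the change concrete, here is a minimal sketch of the idea, not the actual torchtitan code: `precompute_freqs_cis` below is a generic Llama-style RoPE table builder, and the sizes are made-up illustrations. The point is that the table is built once for max_seqlen, and the per-batch slice is taken where the full seqlen is unambiguous, as a view rather than a copy.

```python
import torch

def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0) -> torch.Tensor:
    # Generic Llama-style RoPE table, built once for the *maximum* sequence length.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seqlen).float()
    return torch.polar(torch.ones(max_seqlen, dim // 2), torch.outer(t, freqs))

freqs_cis = precompute_freqs_cis(dim=64, max_seqlen=4096)

# Before: the top-level model sliced using the input token shape. On the first
# PP stage that shape can already be seqlen/TP (TP sharded across seqlen), so
# it is the wrong length to slice by.
# After: each transformer layer slices from its own activation shape, where the
# sequence has been re-expanded to the full seqlen.
seqlen = 2048                      # full seqlen as seen inside the layer
layer_freqs = freqs_cis[:seqlen]   # a view of the precomputed table, not a copy
assert layer_freqs.data_ptr() == freqs_cis.data_ptr()  # shares the same storage
```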
makes sense - lgtm!
lgtm!
@@ -76,7 +79,9 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
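For context on the hunk above: this helper is the standard Llama-style `reshape_for_broadcast`. The sketch below is a rough reconstruction of what the surrounding function typically looks like, not this repo's exact code, and it does not reproduce the lines the PR adds; the shapes in the usage example are illustrative assumptions.

```python
import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Reshape a (seqlen, head_dim // 2) freqs_cis so it broadcasts against
    # x of shape (bsz, seqlen, n_heads, head_dim // 2).
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

# Illustrative shapes: bsz=2, seqlen=128, n_heads=8, head_dim // 2 = 32.
x = torch.randn(2, 128, 8, 32)
freqs_cis = torch.polar(torch.ones(128, 32), torch.zeros(128, 32))
print(reshape_for_broadcast(freqs_cis, x).shape)  # torch.Size([1, 128, 1, 32])
```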
not from this PR: I wonder what the point of the `0 <= 1` part is 😃.
lol. it's always good to check your assumptions
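An aside on the question above: Python chains comparisons, so `0 <= 1 < ndim` parses as `(0 <= 1) and (1 < ndim)`. The `0 <= 1` half is always true, which is why the assert is effectively just a check that ndim > 1. A tiny demonstration:

```python
ndim = 4
# Chained comparison: `0 <= 1` is vacuously true, so the whole
# expression reduces to `1 < ndim`, i.e. ndim > 1.
assert (0 <= 1 < ndim) == ((0 <= 1) and (1 < ndim)) == (ndim > 1)
```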