
[PyTorch] Float8Tensor uses cached transpose if available #524

Closed

Conversation

timmoon10 (Collaborator) commented on Nov 19, 2023:

This PR changes the transpose behavior of Float8Tensor:

  • If update_cache == True, it will compute the transpose and update the cache
  • If update_cache == False and the cache is empty, it will compute the transpose
  • If update_cache == False and the cache is populated, it will return the cached transpose
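
For illustration, the policy above amounts to something like the sketch below. This is a toy class, not the actual Float8Tensor implementation, and the class and attribute names are made up:

import torch
from typing import Optional

class TransposeCacheSketch:
    """Toy stand-in for a tensor that caches its transpose."""

    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose: Optional[torch.Tensor] = None  # cached transpose, if any

    def transpose(self, update_cache: bool = False) -> torch.Tensor:
        if update_cache:
            # Recompute the transpose and refresh the cache.
            self._transpose = self._data.t().contiguous()
            return self._transpose
        if self._transpose is None:
            # Cache is empty: compute the transpose without caching it.
            return self._data.t().contiguous()
        # Cache is populated: reuse it.
        return self._transpose

With this policy, a caller that knows it is on the first microbatch of a gradient-accumulation window can pass update_cache=True once and get cache hits for the rest of the window, while a caller that never touches the cache simply recomputes the transpose each time.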

This is somewhat of a kludge to support transpose caching with Megatron GPT (see NVIDIA/NeMo#7909). Its forward function doesn't keep track of gradient accumulation steps, so it doesn't pass is_first_microbatch to LayerNormLinear or Linear. E.g.:
https://github.com/NVIDIA/Megatron-LM/blob/9290c730d04b482be8fae92a4186fe4ff0c95270/megatron/core/transformer/attention.py#L271C31-L271C31
Compare to NeMo GPT, which contains TE-specific logic like is_first_microbatch:
https://github.com/NVIDIA/NeMo/blob/d81beac52423dbd04b48e4e04567b17df2428e3a/nemo/collections/nlp/modules/common/megatron/transformer.py#L1556

Discussion would be appreciated. This design ping-ponged a few times in #452, e.g. 00b9c31. This approach is convenient with an FP8-aware optimizer since the optimizer doesn't need any access to the TE modules, just the FP8 params. There are also some alternative approaches:

  • Add TE logic to Megatron, especially is_first_microbatch, to keep the current API (see the sketch after this list)
  • Add arguments to the TE module constructors to control transpose caching
  • Add attributes to the TE modules or FP8 tensors to control transpose caching
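
For context, the NeMo-style pattern that the first alternative refers to looks roughly like the sketch below. The sizes and the toy gradient-accumulation loop are made up for illustration; is_first_microbatch is the existing forward argument on TE's Linear and LayerNormLinear:

import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024, bias=True)
inp = torch.randn(32, 1024, device="cuda")

num_microbatches = 4
for step in range(num_microbatches):
    with te.fp8_autocast(enabled=True):
        # Telling TE which microbatch is first lets it cache the FP8 weights
        # (and their transposes) and reuse them for the remaining microbatches.
        out = linear(inp, is_first_microbatch=(step == 0))
    out.sum().backward()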

Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 added the bug (Something isn't working) and enhancement (New feature or request) labels on Nov 19, 2023
timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

@@ -263,7 +263,7 @@ def test_transpose(
     dims: DimsType,
     transpose_dims: Tuple[int, int],
     fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
-    scale: float = 1,
+    scale: float = 0.5,
A reviewer (Member) commented on this diff:

Why?

timmoon10 (Collaborator, Author) replied:

I thought there was a correctness issue that was hidden by scale=1, but I don't think it's actually an issue. In any case, a non-unit scale does a better job of stress-testing this.

timmoon10 marked this pull request as a draft on November 20, 2023 at 19:21
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 marked this pull request as ready for review on November 20, 2023 at 23:55
timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

timmoon10 requested a review from ptrendx on November 21, 2023 at 22:30
timmoon10 (Collaborator, Author) commented:

Test failures are ONNX-related. This is ready to go.

erhoo82 (Collaborator) commented on Dec 8, 2023:

@timmoon10
Can you re-open and close the PR? As I shared, I verified the functionality of this feature.

timmoon10 (Collaborator, Author) commented:

The work in this PR was merged in #529.
