[nvfuserex] Decomposed torch._scaled_mm
#1749
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Write a decomposed, emulate nvfuser definition for
torch._scaled_mm
so that we can be free from worrying about column/row-major of input FP8 matrices oftorch._scaled_mm
, especially in backward.The backward (bottom trace) correctly uses
nv_decomposed_scaled_mm
but the forward, not.The reason does not look clear to me at the moment.
The decomposed
torch._scaled_mm
consists of (1) upcasts of 2 FP8 matrices to FP32, (2) scaling of the two matrices, (3) matmul of the two.