aten/src/ATen/native/cuda/RowwiseScaledMM.cu (3 additions & 0 deletions)
@@ -790,6 +790,9 @@ void check_inputs(
const at::Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const at::Tensor& out) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  TORCH_CHECK(dprops->major == 9, "f8f8bf16_rowwise is sm_90 specific.");
Collaborator:
Shouldn't there be another change to not call into this at all?
We should fall back to another implementation for sm10+ right?

Contributor:
I also think that, at the very least, we should have a tracker for the functionality/features we skip on new hardware, so that full support can be added later.

Collaborator:
What about SM_90 minor version? Not relevant at all here?

Collaborator Author:
> What about SM_90 minor version? Not relevant at all here?

It is not relevant. But I missed SM_89; I will include it as well.

Collaborator Author:
> We should fall back to another implementation for sm10+ right?

Correct, the current approach/kernel is not compatible with SM_100+. Since there are no kernels for the Blackwell machines yet, I propose simply throwing an exception. Otherwise it fails with a CUTLASS error, which is not elegant behavior.
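A minimal sketch of how the combined guard could look (hypothetical code, not the exact diff; the sm_100+ message mirrors what the updated test below expects, and the sm_89 case follows the discussion above):

    auto dprops = at::cuda::getCurrentDeviceProperties();
    // Fail fast on Blackwell (sm_100) and newer with a clear error,
    // instead of letting the kernel abort inside CUTLASS.
    TORCH_CHECK(
        dprops->major < 10,
        "f8f8bf16_rowwise is not implemented on sm_100 or later.");
    // Accept Hopper (sm_90) and Ada (sm_89).
    TORCH_CHECK(
        dprops->major == 9 || (dprops->major == 8 && dprops->minor == 9),
        "f8f8bf16_rowwise is sm_89/sm_90 specific.");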

Collaborator Author:
> @Aidyn-A do they still fail with CUTLASS 3.7 btw? Or do we need to wait for 3.8?

That is a good question. I will need to check it. Thanks for reminding me about the CUTLASS update!

Collaborator:
@Aidyn-A We also need a cuDNN update (only for the versions we started the ManyLinux upgrade on, so 2.6/2.8).

Collaborator Author:
Thanks @Skylion007 for the PR! I just ran the tests with the latest CUTLASS 3.8 on SM_100 and got these errors:

FAILED [0.3029s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_rowwise_scaling_sanity_use_fast_accum_False_cuda - AssertionError: Tensor-likes are not close!
FAILED [2.5473s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_rowwise_scaling_sanity_use_fast_accum_True_cuda - AssertionError: Tensor-likes are not close!
FAILED [0.0025s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_mm_vs_emulated_row_wise_bfloat16_cuda - AssertionError: Tensor-likes are not close!

The reason these fail with numerical mismatches is that the kernel simply aborts with the following message:

ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.

Collaborator:
Seems like __CUDA_ARCH_FEAT_SM90_ALL is not defined; not sure if this is a CMake or CUTLASS bug.
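For context, a sketch of the mechanism (my reading of the usual CUTLASS guards, not code from this PR): nvcc defines __CUDA_ARCH_FEAT_SM90_ALL only when a translation unit is compiled for the sm_90a target (e.g. -gencode arch=compute_90a,code=sm_90a); building for plain sm_90 leaves it undefined, so the arch-conditional WGMMA path is compiled out and the kernel hits the runtime abort quoted above.

    // Simplified form of the feature gate (hypothetical illustration):
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && \
        defined(__CUDA_ARCH_FEAT_SM90_ALL)
      // Hopper arch-conditional MMA (WGMMA) instructions may be emitted here.
    #else
      // Built for plain sm_90 (or another arch): the conditional path is
      // unavailable, which surfaces as the "Aborting" error at runtime.
    #endif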


TORCH_CHECK(a.is_cuda());
TORCH_CHECK(a.device() == b.device());
TORCH_CHECK(scale_a.device() == a.device());
test/test_matmul_cuda.py (29 additions & 12 deletions)
@@ -43,9 +43,11 @@

_IS_SM8X = False
_IS_SM9X = False
+_IS_SM10X = False
if TEST_CUDA:
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
+    _IS_SM10X = torch.cuda.get_device_capability(0)[0] == 10

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
@@ -659,18 +661,33 @@ def test_float8_error_messages(self, device) -> None:
                out_dtype=torch.bfloat16,
            )

-        # Note re.compile is used, not re.escape. This is to accommodate the fn vs fnuz type message.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8.to(e5m2_type),
-                scale_a=torch.ones((M, 1), device="cuda"),
-                scale_b=torch.ones((1, N), device="cuda"),
-                out_dtype=torch.bfloat16,
-            )
+        if _IS_SM10X:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                re.escape(
+                    "f8f8bf16_rowwise is not implemented on sm_100 or later.",
+                ),
+            ):
+                torch._scaled_mm(
+                    x_fp8,
+                    y_fp8.to(e5m2_type),
+                    scale_a=torch.ones((M, 1), device="cuda"),
+                    scale_b=torch.ones((1, N), device="cuda"),
+                    out_dtype=torch.bfloat16,
+                )
+        else:
+            # Note re.compile is used, not re.escape. This is to accommodate the fn vs fnuz type message.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
+            ):
+                torch._scaled_mm(
+                    x_fp8,
+                    y_fp8.to(e5m2_type),
+                    scale_a=torch.ones((M, 1), device="cuda"),
+                    scale_b=torch.ones((1, N), device="cuda"),
+                    out_dtype=torch.bfloat16,
+                )

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")