From 26e790df61e23b2ba340c36b84eb9940fec100bb Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Tue, 24 Sep 2024 10:12:07 -0700
Subject: [PATCH] clean up device checks in float8 unit test files (#923)

Summary:

While working on rowwise scaling, I noticed that some of the CUDA device
capability checks in the test files did not make sense. This PR cleans
them up.

Test Plan:

Tests pass on my H100. CI should skip fewer tests now, since CI only has
CUDA capability 8.9.

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/float8/test_base.py    | 23 -----------------------
 test/float8/test_compile.py |  3 ++-
 2 files changed, 2 insertions(+), 24 deletions(-)

diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index db66a206e..8fb3921f6 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -231,15 +231,6 @@ def test_linear(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
@@ -287,16 +278,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
@@ -334,10 +315,6 @@ def test_autocast_outputs(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        emulate = (
-            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
-        )
-
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(emulate=emulate)
         m = Float8Linear.from_float(copy.deepcopy(m), config)
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 8a0458bec..bae62bf77 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -224,7 +224,8 @@ def forward(self, x):
                 return x_hp
         return x_fp8
 
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    # TODO(future): figure out why the test below fails on CUDA capability 8.9
+    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
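
Note on the pattern: the deleted in-body checks duplicated gating that the decorators already express. The sketch below is illustrative only and not part of the patch. It assumes definitions of is_cuda_8_9 and is_H100 along the lines implied by the diff (device capability >= (8, 9) and >= (9, 0), respectively); the real torchao helpers are defined elsewhere and are not shown here, and the test names and bodies are placeholders.

# Minimal, self-contained sketch of decorator-based skip gating for pytest.
# The helper definitions below are assumptions mirroring the names used in
# the diff (is_cuda_8_9, is_H100); they are not copied from torchao.
import unittest

import pytest
import torch

# Assumed: capability (8, 9) covers GPUs with native float8 support
# (e.g. L4 / RTX 4090 class), capability (9, 0) covers H100-class GPUs.
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


# emulate=False is only generated when the GPU natively supports float8,
# so the test body needs no capability check followed by pytest.skip().
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_sketch(emulate: bool):
    # Placeholder body: by construction, emulate=False implies capability 8.9+.
    assert emulate or is_cuda_8_9


# A test that genuinely needs capability 9.0+ is gated at the decorator
# level instead of skipping from inside the test body.
@unittest.skipIf(
    not torch.cuda.is_available() or not is_H100,
    "CUDA with capability 9.0 or greater not available",
)
def test_h100_only_sketch():
    # Placeholder body: only reached on capability 9.0+ devices.
    assert is_H100

Run under pytest, this mirrors the effect of the patch: on a capability 8.9 machine the emulate=False cases are generated and run rather than being collected and then skipped from inside the test body, while the 9.0-only test is skipped by its decorator.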