1212from torch .testing ._internal import common_utils
1313
1414from torchao .dtypes import MarlinSparseLayout , SemiSparseLayout
15+ from torchao .quantization import (
16+ Float8DynamicActivationFloat8SemiSparseWeightConfig ,
17+ Float8DynamicActivationFloat8WeightConfig ,
18+ )
1519from torchao .quantization .quant_api import (
1620 int4_weight_only ,
1721 int8_dynamic_activation_int8_weight ,
1822 quantize_ ,
1923)
2024from torchao .sparsity import apply_fake_sparsity , semi_sparse_weight , sparsify_
25+ from torchao .utils import is_sm_at_least_90
2126
# Configure root logging once at import time: timestamped records tagged
# with logger name and severity, at INFO level and above.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
2530
26- from torchao .quantization import (
27- Float8DynamicActivationFloat8SemiSparseWeightConfig ,
28- Float8DynamicActivationFloat8WeightConfig ,
29- )
30-
3131
3232class TestSemiStructuredSparse (common_utils .TestCase ):
3333 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
@@ -126,6 +126,7 @@ def test_sparse_marlin(self, compile):
126126
127127 torch .testing .assert_close (dense_result , sparse_result , atol = 3e-1 , rtol = 3e-1 )
128128
129+ @unittest .skipIf (not is_sm_at_least_90 (), "Need H100 to run" )
129130 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
130131 @common_utils .parametrize ("compile" , [True , False ])
131132 def test_fp8_cutlass_sparse (self , compile ):
@@ -155,6 +156,7 @@ def test_fp8_cutlass_sparse(self, compile):
155156
156157 torch .testing .assert_close (dense_result , sparse_result , atol = 3e-1 , rtol = 3e-1 )
157158
159+ @unittest .skipIf (not is_sm_at_least_90 (), "Need H100 to run" )
158160 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
159161 def test_fp8_cutlass_sparse_lowering_op_clone (self ):
160162 with torch .inference_mode ():
@@ -168,6 +170,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self):
168170 for o , c in zip (original , cloned ):
169171 torch .testing .assert_close (o , c , atol = 0.0 , rtol = 0.0 )
170172
173+ @unittest .skipIf (not is_sm_at_least_90 (), "Need H100 to run" )
171174 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
172175 def test_fp8_cutlass_sparse_lowering_op_to (self ):
173176 # Need to run with inference mode to avoid dispatching to `aten.to_copy`
# (diff excerpt truncated here)