|  | 
| 15 | 15 | from torch.testing._internal import common_utils | 
| 16 | 16 | from torch.testing._internal.common_utils import run_tests | 
| 17 | 17 | 
 | 
|  | 18 | +from torchao.float8.inference import Float8MMConfig | 
| 18 | 19 | from torchao.quantization import ( | 
| 19 | 20 |     Float8DynamicActivationFloat8WeightConfig, | 
| 20 | 21 |     Float8WeightOnlyConfig, | 
|  | 22 | +    Granularity, | 
| 21 | 23 |     PerRow, | 
| 22 | 24 |     PerTensor, | 
| 23 | 25 |     quantize_, | 
| @@ -82,7 +84,7 @@ def test_fp8_linear_variants( | 
| 82 | 84 |         dtype: torch.dtype, | 
| 83 | 85 |         mode: str, | 
| 84 | 86 |         compile: bool, | 
| 85 |  | -        granularity, | 
|  | 87 | +        granularity: Granularity, | 
| 86 | 88 |         kernel_preference: KernelPreference, | 
| 87 | 89 |         sizes: Tuple, | 
| 88 | 90 |     ): | 
| @@ -148,6 +150,61 @@ def test_fp8_linear_variants( | 
| 148 | 150 |                 f"Quantization error is too high got a SQNR of {error}" | 
| 149 | 151 |             ) | 
| 150 | 152 | 
 | 
|  | 153 | +    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | 
|  | 154 | +    @unittest.skipIf( | 
|  | 155 | +        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" | 
|  | 156 | +    ) | 
|  | 157 | +    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | 
|  | 158 | +    @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) | 
|  | 159 | +    @common_utils.parametrize( | 
|  | 160 | +        "kernel_preference", | 
|  | 161 | +        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], | 
|  | 162 | +    ) | 
|  | 163 | +    # Inputs are (M,..), K, N | 
|  | 164 | +    @common_utils.parametrize( | 
|  | 165 | +        "sizes", | 
|  | 166 | +        [ | 
|  | 167 | +            ((128,), 256, 128), | 
|  | 168 | +            ((32, 128), 64, 256), | 
|  | 169 | +        ], | 
|  | 170 | +    ) | 
|  | 171 | +    def test_fp8_matmul( | 
|  | 172 | +        self, | 
|  | 173 | +        dtype: torch.dtype, | 
|  | 174 | +        granularity: Granularity, | 
|  | 175 | +        kernel_preference: KernelPreference, | 
|  | 176 | +        sizes: Tuple, | 
|  | 177 | +    ): | 
|  | 178 | +        if ( | 
|  | 179 | +            isinstance(granularity, PerTensor) | 
|  | 180 | +            and kernel_preference == KernelPreference.FBGEMM | 
|  | 181 | +        ): | 
|  | 182 | +            self.skipTest( | 
|  | 183 | +                "per tensor with fbgemm kernel preference does not work yet" | 
|  | 184 | +            ) | 
|  | 185 | +        M, K, N = sizes | 
|  | 186 | +        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") | 
|  | 187 | +        weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda") | 
|  | 188 | +        mm_config = Float8MMConfig() | 
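|  |  | +        # quantize both matmul operands from high precision to fp8 with the same granularity and kernel preference | 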
|  | 189 | +        input_tensor_fp8 = Float8Tensor.from_hp( | 
|  | 190 | +            input_tensor, | 
|  | 191 | +            granularity=granularity, | 
|  | 192 | +            mm_config=mm_config, | 
|  | 193 | +            kernel_preference=kernel_preference, | 
|  | 194 | +        ) | 
|  | 195 | +        weight_tensor_fp8 = Float8Tensor.from_hp( | 
|  | 196 | +            weight_tensor, | 
|  | 197 | +            granularity=granularity, | 
|  | 198 | +            mm_config=mm_config, | 
|  | 199 | +            kernel_preference=kernel_preference, | 
|  | 200 | +        ) | 
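|  |  | +        # compare the fp8 matmul result against the high-precision reference via SQNR | 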
|  | 201 | +        output_tensor = torch.matmul(input_tensor, weight_tensor) | 
|  | 202 | +        output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8) | 
|  | 203 | +        error = compute_error(output_tensor, output_tensor_fp8) | 
|  | 204 | +        assert error > 20, ( | 
|  | 205 | +            f"Quantization error is too high, got a SQNR of {error}" | 
|  | 206 | +        ) | 
|  | 207 | + | 
| 151 | 208 |     @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) | 
| 152 | 209 |     @unittest.skipIf( | 
| 153 | 210 |         not is_sm_at_least_90(), | 
| @@ -653,6 +710,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape): | 
| 653 | 710 | 
 | 
| 654 | 711 |         self.assertEqual(sliced_dequantized, sliced_original) | 
| 655 | 712 | 
 | 
|  | 713 | +    def test_to_dtype_layout(self): | 
|  | 714 | +        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) | 
|  | 715 | +        x_fp8 = Float8Tensor.from_hp(x) | 
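|  |  | +        # aten.to.dtype_layout with device="cpu" should move the Float8Tensor to CPU while keeping dtype and layout | 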
|  | 716 | +        y_fp8 = torch.ops.aten.to.dtype_layout( | 
|  | 717 | +            x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu" | 
|  | 718 | +        ) | 
|  | 719 | +        self.assertEqual(y_fp8.dtype, x_fp8.dtype) | 
|  | 720 | +        self.assertEqual(y_fp8.layout, x_fp8.layout) | 
|  | 721 | +        self.assertEqual(y_fp8.device, torch.device("cpu")) | 
|  | 722 | + | 
|  | 723 | +    def test_has_compatible_shallow_copy_type(self): | 
|  | 724 | +        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) | 
|  | 725 | +        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) | 
|  | 726 | +        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) | 
|  | 727 | +        x1_fp8 = Float8Tensor.from_hp(x1) | 
|  | 728 | +        x2_fp8 = Float8Tensor.from_hp(x2) | 
|  | 729 | +        x3_fp8 = Float8Tensor.from_hp(x3) | 
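|  |  | +        # shallow copies are only compatible between Float8Tensors of the same shape | 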
|  | 730 | +        self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8)) | 
|  | 731 | +        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2)) | 
|  | 732 | +        self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8)) | 
|  | 733 | +        # Wrong shape | 
|  | 734 | +        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8)) | 
|  | 735 | + | 
|  | 736 | +    def test_transpose(self): | 
|  | 737 | +        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) | 
|  | 738 | +        x_fp8 = Float8Tensor.from_hp(x) | 
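|  |  | +        # transposing should transpose qdata and scale and swap the block_size dims | 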
|  | 739 | +        x_fp8_t = x_fp8.t() | 
|  | 740 | +        torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0) | 
|  | 741 | +        torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0) | 
|  | 742 | +        self.assertEqual(x_fp8.block_size, (1, 512)) | 
|  | 743 | +        self.assertEqual(x_fp8_t.block_size, (512, 1)) | 
|  | 744 | + | 
| 656 | 745 | 
 | 
| 657 | 746 | common_utils.instantiate_parametrized_tests(TestFloat8Tensor) | 
| 658 | 747 | 
 | 
|  | 