Skip to content

Commit

Permalink
Add compile tests to test suite (#906)
Browse files Browse the repository at this point in the history
* Add compile tests to test suite

Summary:
This is a follow-up PR addressing #839 (comment)
We can add more compiler related tests in the future.

Next
* refactor a bit to use quantize_ API directly
* use the test suite in existing API tests

Test Plan:
python torchao/testing/utils.py

Reviewers:

Subscribers:

Tasks:

Tags:

* rename

* add result check
  • Loading branch information
jerryzh168 authored Sep 26, 2024
1 parent d267622 commit 64719d5
Showing 1 changed file with 63 additions and 2 deletions.
65 changes: 63 additions & 2 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def new_test(self, value=value):


class TorchAOBasicTestCase(common_utils.TestCase):
"""Basic test case for tensor subclasses
"""
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

Expand Down Expand Up @@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)


class TorchAOCompileTestCase(common_utils.TestCase):
# Devices the tests are parametrized over: always CPU, plus CUDA when
# torch.cuda.is_available() reports a usable device.
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
# Floating point dtypes the tests are parametrized over.
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

# Type that the tests check quantized tensors against.
TENSOR_SUBCLASS = AffineQuantizedTensor
# Callable invoked as FACTORY_FN(hp_tensor, **kwargs) to build the low precision tensor.
FACTORY_FN = to_affine_quantized_intx
# Keyword arguments forwarded unchanged to FACTORY_FN.
kwargs = {
"mapping_type": MappingType.ASYMMETRIC,
"block_size": (1, 32),
"target_dtype": torch.uint8,
}
# Minimum SQNR for the linear operation when the weight is quantized to low
# precision with the above settings.
LINEAR_MIN_SQNR = 40
# Lower bound on compute_error between the eager and compiled results in
# test_output_tensor_subclass.
COMPILE_MIN_SQNR = 50

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_output_tensor_subclass(self, device, dtype):
    """Compile an identity function that takes and returns the tensor subclass.

    Builds a low precision tensor with ``FACTORY_FN``, runs it through the
    eager and the compiled identity function, and checks that the compiled
    result is still a ``TENSOR_SUBCLASS`` whose dequantized values equal
    the eager ones.

    Args:
        device: device string the input tensor is created on.
        dtype: dtype of the high precision input tensor.
    """
    hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
    lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

    def f(tensor):
        return tensor

    ref = f(lp_tensor)
    compiled = torch.compile(f)(lp_tensor)
    # Check the result already computed above rather than invoking the
    # compiled function a second time just for the type assertion.
    self.assertIsInstance(compiled, self.TENSOR_SUBCLASS)
    self.assertEqual(ref.dequantize(), compiled.dequantize())

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_tensor_subclass(self, device, dtype):
    """Compile a function that takes the tensor subclass and dequantizes it.

    Builds a low precision tensor with ``FACTORY_FN``, calls ``dequantize()``
    on it in eager mode and under ``torch.compile``, and checks that the
    compiled result is not a ``TENSOR_SUBCLASS`` and equals the eager result.

    Args:
        device: device string the input tensor is created on.
        dtype: dtype of the high precision input tensor.
    """
    hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
    lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

    def f(tensor):
        return tensor.dequantize()

    ref = f(lp_tensor)
    compiled = torch.compile(f)(lp_tensor)
    # Check the result already computed above rather than invoking the
    # compiled function a second time just for the type assertion.
    self.assertNotIsInstance(compiled, self.TENSOR_SUBCLASS)
    self.assertEqual(ref, compiled)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_output_tensor_subclass(self, device, dtype):
    """Compile a function that builds the tensor subclass from a plain tensor.

    Runs ``FACTORY_FN`` on a high precision tensor in eager mode and under
    ``torch.compile``, checks that the compiled result is a
    ``TENSOR_SUBCLASS``, and (except for bfloat16) that ``compute_error``
    between the two dequantized results exceeds ``COMPILE_MIN_SQNR``.

    Args:
        device: device string the input tensor is created on.
        dtype: dtype of the high precision input tensor.
    """
    hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)

    # The parameter is named differently from the outer tensor so the inner
    # function does not shadow it.
    def f(tensor):
        return self.FACTORY_FN(tensor, **self.kwargs)

    ref = f(hp_tensor)
    compiled = torch.compile(f)(hp_tensor)
    # Check the result already computed above rather than invoking the
    # compiled function a second time just for the type assertion.
    self.assertIsInstance(compiled, self.TENSOR_SUBCLASS)
    # bfloat16 seems to result in much larger numerical differences, so the
    # numerical comparison is skipped for that dtype.
    if dtype != torch.bfloat16:
        sqnr = torchao.quantization.utils.compute_error(
            ref.dequantize(), compiled.dequantize()
        )
        self.assertGreater(sqnr, self.COMPILE_MIN_SQNR)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_linear_compile(self, device, dtype):
Expand All @@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
lp_res = torch.compile(l)(hp_act_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)



# Pass both test case classes to instantiate_parametrized_tests, since their
# test methods are decorated with common_utils.parametrize.
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit 64719d5

Please sign in to comment.