diff --git a/test/dtypes/test_bitpacking.py b/test/dtypes/test_bitpacking.py
index 0ed4462d5d..9c54631b45 100644
--- a/test/dtypes/test_bitpacking.py
+++ b/test/dtypes/test_bitpacking.py
@@ -8,9 +8,11 @@
 from torch.utils._triton import has_triton
 
 from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
+from torchao.utils import get_current_accelerator_device
 
 bit_widths = (1, 2, 3, 4, 5, 6, 7)
 dimensions = (0, -1, 1)
+_DEVICE = get_current_accelerator_device()
 
 
 @pytest.fixture(autouse=True)
@@ -30,17 +32,19 @@ def test_CPU(bit_width, dim):
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_GPU(bit_width, dim):
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
@@ -48,22 +52,26 @@ def test_compile(bit_width, dim):
     torch._dynamo.config.specialize_int = True
     torch.compile(pack, fullgraph=True)
     torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
 # these test cases are for the example pack walk through in the bitpacking.py file
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_pack_example():
     test_tensor = torch.tensor(
         [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    ).cuda()
+    ).to(_DEVICE)
     shard_4, shard_2 = pack(test_tensor, 6)
     print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    assert (
+        torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4)
+    )
+    assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2)
     unpacked = unpack([shard_4, shard_2], 6)
     assert unpacked.allclose(test_tensor)
 
diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py
index a3dd4d19e3..19a7ca4c56 100644
--- a/test/dtypes/test_floatx.py
+++ b/test/dtypes/test_floatx.py
@@ -33,10 +33,11 @@
     quantize_,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import is_fbcode
+from torchao.utils import get_current_accelerator_device, is_fbcode
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
+_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])
 
 
 class TestFloatxTensorCoreAQTTensorImpl(TestCase):
@@ -87,7 +88,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
         )
         torch.testing.assert_close(actual, expected)
 
-    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), reason="GPU not available")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
@@ -101,8 +102,8 @@ def test_to_copy_device(self, ebits, mbits):
         _layout = FloatxTensorCoreLayout(ebits, mbits)
         floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(
             x, scale, None, _layout
-        ).cuda()
-        assert floatx_tensor_impl.device.type == "cuda"
+        ).to(_DEVICE)
+        assert floatx_tensor_impl.device.type == _DEVICE.type
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
 
@@ -114,7 +115,7 @@
     @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
-        device = "cuda"
+        device = _DEVICE
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)