Skip to content

Commit b119fd1

Browse files
author
realAsma
committed
fixed minor tests
Signed-off-by: realAsma <you@example.com>
1 parent 1b2a57d commit b119fd1

File tree

2 files changed

+8
-17
lines changed

2 files changed

+8
-17
lines changed

tests/gpu/torch/quantization/test_tensor_quant_cuda.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,6 @@
3030
class TestFakeTensorQuantCuda(FakeTensorQuantTester):
3131
device = "cuda"
3232

33-
def test_non_current_gpu(self, need_2_gpus):
34-
device = torch.cuda.device_count() - 1
35-
assert torch.cuda.current_device() != device
36-
x = torch.randn(3, 4).cuda(device)
37-
quant_x = tensor_quant.fake_tensor_quant(x, torch.max(torch.abs(x)), None)
38-
quant_x_ref = quant(x, torch.max(torch.abs(x)), fake=True)
39-
assert torch.allclose(quant_x, quant_x_ref)
40-
4133

4234
class TestCudaExt:
4335
@pytest.mark.parametrize("num_bits", [3, 4, 5, 7, 8, 11])
@@ -145,15 +137,6 @@ def test_backward(self, device):
145137
loss.backward()
146138
assert torch.allclose(quant_x.grad, x.grad)
147139

148-
def test_non_current_gpu(self, need_2_gpus):
149-
torch.cuda.set_device(0)
150-
device = torch.cuda.device_count() - 1
151-
x = torch.randn(3, 4).cuda()
152-
quant_x_ref = tensor_quant.fp8_eager(x, torch.tensor(448.0, device=x.device))
153-
x = x.cuda(device)
154-
quant_x = tensor_quant.scaled_e4m3(x, None, None, 4, 3)
155-
assert torch.allclose(quant_x.cuda(), quant_x_ref)
156-
157140
@pytest.mark.parametrize("axis", [0, 1, 2])
158141
def test_e4m3_per_channel(self, axis):
159142
x = torch.randn(4, 4, 4, dtype=torch.float32).cuda()

tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def test_e4m3(self, E, M, axis): # noqa: N803
5555
ref = tensor_quant.scaled_e4m3(x, e4m3_quantizer._get_amax(x), None, E, M)
5656
assert torch.allclose(e4m3_x, ref)
5757

58+
def test_non_current_gpu(self, need_2_gpus):
59+
x = torch.randn(3, 4)
60+
e4m3_desc = QuantizerAttributeConfig(num_bits=(4, 3), axis=None)
61+
quantizer = tensor_quantizer.TensorQuantizer(e4m3_desc).cuda()
62+
xq_ref = quantizer(x.to("cuda:0"))
63+
xq_test = quantizer(x.to("cuda:1"))
64+
assert torch.allclose(xq_ref, xq_test.to("cuda:0"))
65+
5866

5967
@pytest.mark.skipif(get_cuda_ext_mx() is None, reason="cuda_ext_mx is not available")
6068
class TestTensorQuantizerfp4:

0 commit comments

Comments
 (0)