Skip to content

Commit 5e02675

Browse files
committed
Hopefully fix cuda test
1 parent acc7a98 commit 5e02675

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
assert_no_warnings,
1717
cache,
1818
cpu_and_cuda,
19+
freeze_rng_state,
1920
ignore_jit_no_profile_information_warning,
2021
make_bounding_box,
2122
make_detection_mask,
@@ -61,8 +62,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
6162
input_cuda = input.as_subclass(torch.Tensor)
6263
input_cpu = input_cuda.to("cpu")
6364

64-
actual = kernel(input_cuda, *args, **kwargs)
65-
expected = kernel(input_cpu, *args, **kwargs)
65+
with freeze_rng_state():
66+
actual = kernel(input_cuda, *args, **kwargs)
67+
with freeze_rng_state():
68+
expected = kernel(input_cpu, *args, **kwargs)
6669

6770
assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)
6871

@@ -1772,11 +1775,11 @@ def test_cpu_vs_gpu(self, T):
17721775
batch_size = 3
17731776
H, W = 12, 12
17741777

1775-
imgs = torch.rand(batch_size, 3, H, W).to("cuda")
1776-
labels = torch.randint(0, num_classes, (batch_size,)).to("cuda")
1778+
imgs = torch.rand(batch_size, 3, H, W)
1779+
labels = torch.randint(0, num_classes, (batch_size,))
17771780
cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
17781781

1779-
_check_kernel_cuda_vs_cpu(cutmix_mixup, input=(imgs, labels), rtol=None, atol=None)
1782+
_check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)
17801783

17811784
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
17821785
def test_error(self, T):

0 commit comments

Comments
 (0)