|
16 | 16 | assert_no_warnings, |
17 | 17 | cache, |
18 | 18 | cpu_and_cuda, |
| 19 | + freeze_rng_state, |
19 | 20 | ignore_jit_no_profile_information_warning, |
20 | 21 | make_bounding_box, |
21 | 22 | make_detection_mask, |
@@ -61,8 +62,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs): |
61 | 62 | input_cuda = input.as_subclass(torch.Tensor) |
62 | 63 | input_cpu = input_cuda.to("cpu") |
63 | 64 |
|
64 | | - actual = kernel(input_cuda, *args, **kwargs) |
65 | | - expected = kernel(input_cpu, *args, **kwargs) |
| 65 | + with freeze_rng_state(): |
| 66 | + actual = kernel(input_cuda, *args, **kwargs) |
| 67 | + with freeze_rng_state(): |
| 68 | + expected = kernel(input_cpu, *args, **kwargs) |
66 | 69 |
|
67 | 70 | assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol) |
68 | 71 |
|
@@ -1772,11 +1775,11 @@ def test_cpu_vs_gpu(self, T): |
1772 | 1775 | batch_size = 3 |
1773 | 1776 | H, W = 12, 12 |
1774 | 1777 |
|
1775 | | - imgs = torch.rand(batch_size, 3, H, W).to("cuda") |
1776 | | - labels = torch.randint(0, num_classes, (batch_size,)).to("cuda") |
| 1778 | + imgs = torch.rand(batch_size, 3, H, W) |
| 1779 | + labels = torch.randint(0, num_classes, (batch_size,)) |
1777 | 1780 | cutmix_mixup = T(alpha=0.5, num_classes=num_classes) |
1778 | 1781 |
|
1779 | | - _check_kernel_cuda_vs_cpu(cutmix_mixup, input=(imgs, labels), rtol=None, atol=None) |
| 1782 | + _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None) |
1780 | 1783 |
|
1781 | 1784 | @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup]) |
1782 | 1785 | def test_error(self, T): |
|
0 commit comments