diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index 287731e7fef43..da4d38567cce5 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -90,6 +90,7 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
     using arg2_t = typename traits::template arg<1>::type;
     auto a = iter.scalar_value<arg1_t>(1);
     iter.remove_operand(1);
+    const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
     gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
       return f(a, b);
     });
diff --git a/test/test_torch.py b/test/test_torch.py
index ad8be07ca2521..cf045c8fcedc6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9558,7 +9558,7 @@ def run_test(input_):
             M, N = input_.shape
             input_.zero_()
             for i in range(min(M, N)):
-                input_[i][i] = 1
+                input_[i][i] = 1
             output1 = input_.argmax(dim=0)
             output2 = input_.sum(dim=0)
             for i in range(min(M, N)):
@@ -14177,6 +14177,24 @@ def test_cross_device_binary_ops(self, devices):
                 with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                     op(cpu_tensor, a)
 
+    # This test ensures that a scalar Tensor can be safely used
+    # in a binary operation in conjunction with a Tensor on all
+    # available CUDA devices
+    @deviceCountAtLeast(2)
+    @onlyCUDA
+    def test_binary_op_scalar_device_unspecified(self, devices):
+        scalar_val = torch.tensor(1.)
+        for default_device in devices:
+            with torch.cuda.device(default_device):
+                for device in devices:
+                    device_obj = torch.device(device)
+                    x = torch.rand(3, device=device)
+                    y0 = x * scalar_val
+                    self.assertEqual(y0.device, device_obj)
+                    y1 = scalar_val * x
+                    self.assertEqual(y1.device, device_obj)
+                    self.assertEqual(y0, y1)
+
     # Tests that CPU scalars (including zero dim tensors) can be used in
     # binary operations with CUDA tensors.
     @onlyCUDA
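
For context, a minimal sketch of the scenario the new test exercises (assumes at least two visible CUDA devices; the variable names below are illustrative, not part of the patch). The added OptionalDeviceGuard pins the gpu_kernel launch to the device of the remaining tensor operand; without it, combining a zero-dim CPU scalar tensor with a CUDA tensor could run under whatever CUDA device happens to be current.

import torch

# Repro sketch (hypothetical device indices; requires >= 2 CUDA devices).
scalar_val = torch.tensor(1.)            # zero-dim CPU scalar tensor
with torch.cuda.device('cuda:0'):        # current device differs from x's device
    x = torch.rand(3, device='cuda:1')
    y0 = x * scalar_val                  # takes the gpu_kernel_with_scalars path
    y1 = scalar_val * x
    # Results should live on x's device regardless of the current device.
    assert y0.device == torch.device('cuda:1')
    assert y1.device == torch.device('cuda:1')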