Skip to content

Commit

Permalink
Reland: Fix CUDA device guard usage when first arg of kernel is scalar (
Browse files Browse the repository at this point in the history
pytorch#39956)

Summary:
Reland PR pytorch#39870

Closes pytorch#38889
Pull Request resolved: pytorch#39956

Differential Revision: D22027956

Pulled By: ngimel

fbshipit-source-id: e6029f450e2da3782b2d05bcc2012c19b82291da
  • Loading branch information
kurtamohler authored and xwang233 committed Jun 19, 2020
1 parent 88bceb5 commit 54e3d37
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Loops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
using arg2_t = typename traits::template arg<1>::type;
auto a = iter.scalar_value<arg1_t>(1);
iter.remove_operand(1);
const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
return f(a, b);
});
Expand Down
20 changes: 19 additions & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9558,7 +9558,7 @@ def run_test(input_):
M, N = input_.shape
input_.zero_()
for i in range(min(M, N)):
input_[i][i] = 1
input_[i][i] = 1
output1 = input_.argmax(dim=0)
output2 = input_.sum(dim=0)
for i in range(min(M, N)):
Expand Down Expand Up @@ -14177,6 +14177,24 @@ def test_cross_device_binary_ops(self, devices):
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
op(cpu_tensor, a)

# This test ensures that a scalar Tensor can be safely used
# in a binary operation in conjunction with a Tensor on all
# available CUDA devices
@deviceCountAtLeast(2)
@onlyCUDA
def test_binary_op_scalar_device_unspecified(self, devices):
    """Regression test for pytorch#38889: a zero-dim (scalar) tensor with
    no explicit device must combine correctly with a CUDA tensor on any
    device, regardless of which CUDA device is currently "current".
    """
    # Zero-dim CPU tensor created without specifying a device.
    scalar_val = torch.tensor(1.)
    # Make each available CUDA device the current one in turn, so the
    # kernel cannot accidentally rely on the current device matching
    # the non-scalar operand's device.
    for default_device in devices:
        with torch.cuda.device(default_device):
            for device in devices:
                device_obj = torch.device(device)
                x = torch.rand(3, device=device)
                # The result must land on x's device for both operand
                # orders, and both orders must agree elementwise.
                y0 = x * scalar_val
                self.assertEqual(y0.device, device_obj)
                y1 = scalar_val * x
                self.assertEqual(y1.device, device_obj)
                self.assertEqual(y0, y1)

# Tests that CPU scalars (including zero dim tensors) can be used in
# binary operations with CUDA tensors.
@onlyCUDA
Expand Down

0 comments on commit 54e3d37

Please sign in to comment.