support flip and rot90 for complex dtype #37826

Closed
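This PR extends the CPU and CUDA dispatch macros used by flip (and thereby rot90, which is implemented via transpose and flip) to cover complex dtypes, and adds matching autograd and NumPy-comparison tests. A rough sketch of the user-facing behavior being enabled, checked against NumPy (illustrative only, not part of the diff):

import numpy as np
import torch

a = (np.arange(6, dtype=np.complex64) * (1 + 1j)).reshape(2, 3)
t = torch.from_numpy(a)

# flip/rot90 on complex tensors should now work and agree with NumPy.
# (.copy() gives positive strides so torch.from_numpy accepts the result.)
assert torch.equal(torch.flip(t, dims=[0]), torch.from_numpy(np.flip(a, axis=0).copy()))
assert torch.equal(torch.rot90(t, k=1, dims=[0, 1]), torch.from_numpy(np.rot90(a, k=1, axes=[0, 1]).copy()))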
aten/src/ATen/native/TensorTransformations.cpp (1 addition & 1 deletion)

@@ -61,7 +61,7 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
}
}

- AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
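The only change here is the dispatch macro: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND additionally instantiates the lambda (and hence flip_cpu_kernel) for the ComplexFloat and ComplexDouble scalar types. A minimal check of what this enables on CPU, assuming a build that includes this patch:

import torch

t = torch.tensor([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], dtype=torch.cfloat)
# Without the complex dispatch this would fail to dispatch for 'ComplexFloat';
# with it, flip treats complex elements like any other dtype.
print(torch.flip(t, dims=[0, 1]))
# tensor([[4.+4.j, 3.+3.j],
#         [2.+2.j, 1.+1.j]])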
aten/src/ATen/native/cuda/TensorTransformations.cu (2 additions & 2 deletions)

@@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {

// use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
- AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cuda", [&] {
auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
@@ -122,7 +122,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
}
}

- AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] {
flip_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
in_tensor.data_ptr<scalar_t>(), out_tensor.data_ptr<scalar_t>(), N,
flip_dims_t.cuda().data_ptr<int64_t>(),
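Both CUDA paths pick up complex support: the pointwise apply path used when a single contiguous first/last dimension is flipped, and the generic flip_cuda_kernel used for everything else. A sketch that would exercise both branches, assuming a CUDA-enabled build with this change:

import torch

if torch.cuda.is_available():
    t = torch.randn(4, 5, 6, dtype=torch.cfloat, device='cuda')
    a = torch.flip(t, dims=[0])      # contiguous, first dim -> pointwise apply path
    b = torch.flip(t, dims=[1, 2])   # multiple/middle dims -> generic flip_cuda_kernel
    # Flipping twice along the same dims is the identity.
    assert torch.equal(torch.flip(a, dims=[0]), t)
    assert torch.equal(torch.flip(b, dims=[1, 2]), t)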
test/test_autograd.py (1 addition & 1 deletion)

@@ -4319,7 +4319,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'tril', 'triu', 'fill_', 'eq_', 'ne_', 'permute', 'squeeze', 'unsqueeze',
'chunk', 'split', 'split_with_sizes', 'resize', 'resize_as', 'sin', 'cos',
'__rmul__', '__rdiv__', 'sum', 'transpose', 'round', 'add', 'roll',
- '__radd__', 'repeat', 'expand', 'mul', 'tanh'] + separate_complex_tests
+ '__radd__', 'repeat', 'expand', 'mul', 'tanh', 'flip', 'rot90'] + separate_complex_tests

def add_test(
name,
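Adding 'flip' and 'rot90' to this list makes the generated method tests in test_autograd.py exercise complex inputs as well. Roughly the property they check, written as a standalone sketch (assuming gradcheck accepts complex inputs on your build; names here are illustrative):

import torch
from torch.autograd import gradcheck

x = torch.randn(3, 3, dtype=torch.cdouble, requires_grad=True)
# flip and rot90 only permute elements, so numerical and analytical
# gradients should agree closely for complex inputs as well.
assert gradcheck(lambda t: torch.flip(t, dims=[0]), (x,))
assert gradcheck(lambda t: torch.rot90(t, k=1, dims=[0, 1]), (x,))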
test/test_torch.py (62 additions & 15 deletions)

@@ -19,6 +19,7 @@
from torch._six import inf, nan, string_classes, istuple
from itertools import product, combinations, combinations_with_replacement, permutations
from functools import reduce
+ from functools import partial
from random import randrange
from torch import multiprocessing as mp
from torch.testing._internal.common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
@@ -6368,6 +6369,28 @@ def _np_compare(self, fn_name, vals, device, dtype):
torch_result = torch_fn(t).cpu()
self.assertEqual(np_result, torch_result)

def _np_compare_func(self, fns, vals, device, dtype):
assert TEST_NUMPY

torch_fn, np_fn = fns

a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype])

# `numpy` may return an array with negative strides,
# which are currently not supported by `torch.from_numpy`.
# `.copy()` ensures that we get an array with positive strides only.
np_result = torch.from_numpy(np_fn(a).copy())

t = torch.tensor(vals, device=device, dtype=dtype)
torch_result = torch_fn(t).cpu()
self.assertEqual(np_result, torch_result)

def _rand_shape(self, dim, min_size, max_size):
shape = []
for i in range(dim):
shape.append(random.randint(min_size, max_size))
return tuple(shape)

@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float)
def test_isfinite_isinf_isnan(self, device, dtype):
@@ -8022,6 +8045,21 @@ def test_flip(self, device):
a = torch.tensor([False, True])
self.assertEqual(a.flip(0), torch.tensor([True, False]))

@dtypes(torch.cfloat, torch.cdouble)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_complex_flip(self, device, dtype):
rand_dim = random.randint(3, 4)
shape = self._rand_shape(rand_dim, 5, 10)

# Number of axes to flip for the given shape.
for i in range(1, rand_dim):
# Check all combinations of `i` axes.
for flip_dim in combinations(range(rand_dim), i):
data = torch.randn(*shape, dtype=dtype).tolist()
torch_fn = partial(torch.flip, dims=flip_dim)
np_fn = partial(np.flip, axis=flip_dim)
self._np_compare_func((torch_fn, np_fn), data, device, dtype)

def test_rot90(self, device):
data = torch.arange(1, 5, device=device).view(2, 2)
self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
Expand Down Expand Up @@ -8055,6 +8093,16 @@ def test_rot90(self, device):
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))

@dtypes(torch.cfloat, torch.cdouble)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_complex_rot90(self, device, dtype):
shape = self._rand_shape(random.randint(2, 4), 5, 10)
for rot_times in range(4):
data = torch.randn(*shape, dtype=dtype).tolist()
torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1])
np_fn = partial(np.rot90, k=rot_times, axes=[0, 1])
self._np_compare_func((torch_fn, np_fn), data, device, dtype)

def test_signal_window_functions(self, device):
if not TEST_SCIPY:
raise unittest.SkipTest('Scipy not found')
@@ -11886,14 +11934,11 @@ def test_vander(self, device):
with self.assertRaisesRegex(RuntimeError, "x must be a one-dimensional tensor."):
torch.vander(torch.stack((x, x)))

- # This passes on the xla backend
- if device != 'xla':
-     with self.assertRaises(RuntimeError):
-         torch.vander(x.to(torch.complex64))

@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@onlyOnCPUAndCUDA
- @dtypes(torch.bool, torch.uint8, torch.int8, torch.short, torch.int, torch.long, torch.float, torch.double)
+ @dtypes(torch.bool, torch.uint8, torch.int8, torch.short, torch.int, torch.long,
+         torch.float, torch.double,
+         torch.cfloat, torch.cdouble)
def test_vander_types(self, device, dtype):
if dtype is torch.uint8:
# Note: no negative uint8 values
@@ -11902,6 +11947,9 @@ def test_vander_types(self, device, dtype):
# Note: see https://github.com/pytorch/pytorch/issues/37398
# for why this is necessary.
X = [[True, True, True, True], [False, True, True, True, True]]
+ elif dtype in [torch.cfloat, torch.cdouble]:
+     X = [[1 + 1j, 1 + 0j, 0 + 1j, 0 + 0j],
+          [2 + 2j, 3 + 2j, 4 + 3j, 5 + 4j]]
else:
X = [[1, 2, 3, 5], [-math.pi, 0, 1 / 3, 1, math.pi, 3 / 7]]

@@ -18414,15 +18462,14 @@ def inner(self, device, dtype):
1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
('zero_', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
('new_zeros', '', _small_3d, lambda t, d: [1, 2, 3, 4], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('flip', 'd0', _small_3d, lambda t, d: [0], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('flip', 'd012', _small_3d, lambda t, d: [0, 1, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('flip', 'd02', _small_3d, lambda t, d: [0, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('flip', 'd20', _small_3d, lambda t, d: [2, 0], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('flip', 'neg_d', _small_3d, lambda t, d: [-1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('rot90', 'k1_d01', _small_2d, lambda t, d: [1, [0, 1]], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
- ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
+ ('flip', 'd0', _small_3d, lambda t, d: [0], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('flip', 'd02', _small_3d, lambda t, d: [0, 2], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('flip', 'd20', _small_3d, lambda t, d: [2, 0], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('flip', 'neg_d', _small_3d, lambda t, d: [-1], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('rot90', 'k1_d01', _small_2d, lambda t, d: [1, [0, 1]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
+ ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
('rsqrt', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-2, 1e-5, 1e-4, _float_types_no_half),
('sinh', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types),
('tan', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types),
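Taken together, the new tests randomize shapes and dims, run torch.flip/torch.rot90 on complex inputs, and compare the results against np.flip/np.rot90, while the device-generic op table above now includes _complex_types for flip and rot90. A condensed, standalone version of that comparison loop (illustrative; the helper name compare is made up here, and the real logic lives in _np_compare_func above):

from functools import partial
from itertools import combinations

import numpy as np
import torch

def compare(torch_fn, np_fn, shape, dtype=torch.cfloat):
    vals = torch.randn(*shape, dtype=dtype).tolist()
    a = np.array(vals, dtype=np.complex64)
    # .copy() ensures positive strides; np.flip/np.rot90 return views
    # that torch.from_numpy may otherwise reject.
    expected = torch.from_numpy(np_fn(a).copy())
    assert torch.allclose(torch_fn(torch.tensor(vals, dtype=dtype)), expected)

for dims in combinations(range(3), 2):
    compare(partial(torch.flip, dims=dims), partial(np.flip, axis=dims), (5, 6, 7))
for k in range(4):
    compare(partial(torch.rot90, k=k, dims=[0, 1]), partial(np.rot90, k=k, axes=[0, 1]), (5, 6))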