Skip to content

Commit

Permalink
Register unfold key for MPS (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 authored and pytorchmergebot committed Dec 22, 2022
1 parent 55749b9 commit 0670a82
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9439,7 +9439,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta: unfold
CPU, CUDA, Meta, MPS: unfold
QuantizedCPU, QuantizedCUDA: unfold

- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
Expand Down
35 changes: 28 additions & 7 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,7 +2039,29 @@ def test_as_strided(self):
strided_mps_out = strided_mps1 - strided_mps2
self.assertEqual(strided_cpu_out, strided_mps_out)

def test_unfold(self):
x = torch.arange(1., 8)
x_mps = torch.arange(1., 8, device="mps")

y = x.unfold(0, 2, 1)
y_mps = x_mps.unfold(0, 2, 1)

self.assertEqual(y, y_mps)

def test_unfold_all_devices_and_dtypes(self):
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
for dt in supported_dtypes:
x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)

def test_unfold_scalars(self):
x = torch.tensor(0.5, device="mps")
# unfold on a 0-dimensional tensor should always return a 1-d dimensional
# tensor of shape [size] (i.e., the second parameter to unfold)

self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))

def test_sum_backward(self):
def helper(n, c):
Expand Down Expand Up @@ -5726,14 +5748,13 @@ def test_T_view(self, device="mps"):
v[0, 1] = 0
self.assertEqual(t[1, 0], v[0, 1])

# requires aten::unfold
# def test_unfold_view(self, device="mps"):
# t = torch.ones(10, device=device)
# v = t.unfold(0, 3, 2)
# self.assertTrue(self.is_view_of(t, v))
def test_unfold_view(self, device="mps"):
t = torch.ones(10, device=device)
v = t.unfold(0, 3, 2)
self.assertTrue(self.is_view_of(t, v))

# v[1, 0] = 0
# self.assertEqual(t[2], v[1, 0])
v[1, 0] = 0
self.assertEqual(t[2], v[1, 0])

def test_squeeze_view(self, device="mps"):
t = torch.ones(5, 1, 5, device=device)
Expand Down

0 comments on commit 0670a82

Please sign in to comment.