From b294ea85ba857e4c9c98d14df44db74f4b2f087e Mon Sep 17 00:00:00 2001
From: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
Date: Tue, 11 Oct 2022 16:55:02 -0700
Subject: [PATCH] Register unfold key for MPS (#134)

---
 aten/src/ATen/native/native_functions.yaml |  2 +-
 test/test_mps.py                           | 35 +++++++++++++++++-----
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b0621235e56d4..622f5ac5da13a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9430,7 +9430,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CPU, CUDA, Meta: unfold
+    CPU, CUDA, Meta, MPS: unfold
     QuantizedCPU, QuantizedCUDA: unfold
 
 - func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index aa3b444bce448..79f22f8373f04 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1956,7 +1956,29 @@ def test_as_strided(self):
         strided_mps_out = strided_mps1 - strided_mps2
         self.assertEqual(strided_cpu_out, strided_mps_out)
 
+    def test_unfold(self):
+        x = torch.arange(1., 8)
+        x_mps = torch.arange(1., 8, device="mps")
+        y = x.unfold(0, 2, 1)
+        y_mps = x_mps.unfold(0, 2, 1)
+
+        self.assertEqual(y, y_mps)
+
+    def test_unfold_all_devices_and_dtypes(self):
+        supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
+        for dt in supported_dtypes:
+            x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
+            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
+
+    def test_unfold_scalars(self):
+        x = torch.tensor(0.5, device="mps")
+        # unfold on a 0-dimensional tensor should always return a 1-d dimensional
+        # tensor of shape [size] (i.e., the second parameter to unfold)
+
+        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
+        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
+        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
     def test_sum_backward(self):
         def helper(n, c):
             values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
             cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
@@ -5596,14 +5618,13 @@ def test_T_view(self, device="mps"):
         v[0, 1] = 0
         self.assertEqual(t[1, 0], v[0, 1])
 
-    # requires aten::unfold
-    # def test_unfold_view(self, device="mps"):
-    #     t = torch.ones(10, device=device)
-    #     v = t.unfold(0, 3, 2)
-    #     self.assertTrue(self.is_view_of(t, v))
+    def test_unfold_view(self, device="mps"):
+        t = torch.ones(10, device=device)
+        v = t.unfold(0, 3, 2)
+        self.assertTrue(self.is_view_of(t, v))
 
-    #     v[1, 0] = 0
-    #     self.assertEqual(t[2], v[1, 0])
+        v[1, 0] = 0
+        self.assertEqual(t[2], v[1, 0])
 
     def test_squeeze_view(self, device="mps"):
         t = torch.ones(5, 1, 5, device=device)