Skip to content

Commit e4eadc3

Browse files
DenisVieriu97Denis Vieriu
authored andcommitted
Fix slice followed by reshape (#237)
Co-authored-by: Denis Vieriu <denisvieriu@mac-3045BF.local>
1 parent ceef445 commit e4eadc3

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,11 +540,26 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
540540
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
541541
size_t src_ndim_view = src_view_shape.size();
542542
if (src_ndim_base == src_ndim_view) {
543-
for (const auto i : c10::irange(src_ndim_base)) {
543+
for (const auto i: c10::irange(src_ndim_base)) {
544544
if (src_view_shape[i] > src_base_shape[i]) {
545545
return false;
546546
}
547547
}
548+
} else {
549+
// Detect slice followed by reshape cases, e.g (1,4800,2) -> (1,4800)
550+
bool allDimsEqual = true;
551+
auto min_ndim = std::min(src_ndim_base, src_ndim_view);
552+
for (const auto i: c10::irange(min_ndim)) {
553+
if (src_view_shape[i] > src_base_shape[i]) {
554+
return false;
555+
}
556+
else if (src_view_shape[i] != src_base_shape[i]) {
557+
allDimsEqual = false;
558+
}
559+
}
560+
if (allDimsEqual) {
561+
return false;
562+
}
548563
}
549564
return true;
550565
}

test/test_mps.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,29 @@ def test_cpu_to_strided_mps_copy(self):
16231623

16241624
self.assertEqual(a1, a2)
16251625

1626+
def test_slice_reshape(self):
1627+
x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
1628+
x_cpu = x.detach().clone().to("cpu")
1629+
1630+
x = x[:,3:].view(2, 3, 4, 1)
1631+
x_cpu = x_cpu[:,3:].view(2, 3, 4, 1)
1632+
self.assertEqual(x, x_cpu)
1633+
1634+
x = x + 2
1635+
x_cpu = x_cpu + 2
1636+
self.assertEqual(x, x_cpu)
1637+
1638+
def test_slice_reshape_contg_view(self):
1639+
import torch
1640+
1641+
x_mps = torch.randn(1, 4800, 2, device="mps")
1642+
x_cpu = x_mps.detach().clone().cpu()
1643+
1644+
r_mps = x_mps + 2
1645+
r_cpu = x_cpu + 2
1646+
1647+
self.assertEqual(r_mps, r_cpu)
1648+
16261649
def test_view_slice(self):
16271650
# https://github.com/pytorch/pytorch/issues/83995
16281651
NUM_SAMPLES = 60

0 commit comments

Comments
 (0)