Skip to content

Commit 6dc2203

Browse files
DenisVieriu97Denis Vieriu
andauthored
Fix slice followed by reshape (#237)
Co-authored-by: Denis Vieriu <denisvieriu@mac-3045BF.local>
1 parent 497300d commit 6dc2203

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
110110
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
111111
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
112112
// Output tensor should have been promoted but it remains an int32 tensor
113-
if (outputDataType != common_dtype) {
113+
if (outputDataType != common_dtype ||
114+
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
114115
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
115116
}
116117
}

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,26 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
450450
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
451451
size_t src_ndim_view = src_view_shape.size();
452452
if (src_ndim_base == src_ndim_view) {
453-
for (const auto i : c10::irange(src_ndim_base)) {
453+
for (const auto i: c10::irange(src_ndim_base)) {
454454
if (src_view_shape[i] > src_base_shape[i]) {
455455
return false;
456456
}
457457
}
458+
} else {
459+
// Detect slice followed by reshape cases, e.g (1,4800,2) -> (1,4800)
460+
bool allDimsEqual = true;
461+
auto min_ndim = std::min(src_ndim_base, src_ndim_view);
462+
for (const auto i: c10::irange(min_ndim)) {
463+
if (src_view_shape[i] > src_base_shape[i]) {
464+
return false;
465+
}
466+
else if (src_view_shape[i] != src_base_shape[i]) {
467+
allDimsEqual = false;
468+
}
469+
}
470+
if (allDimsEqual) {
471+
return false;
472+
}
458473
}
459474
return true;
460475
}

test/test_mps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,9 +1629,23 @@ def test_slice_reshape(self):
16291629

16301630
x = x[:,3:].view(2, 3, 4, 1)
16311631
x_cpu = x_cpu[:,3:].view(2, 3, 4, 1)
1632+
self.assertEqual(x, x_cpu)
16321633

1634+
x = x + 2
1635+
x_cpu = x_cpu + 2
16331636
self.assertEqual(x, x_cpu)
16341637

1638+
def test_slice_reshape_contg_view(self):
1639+
import torch
1640+
1641+
x_mps = torch.randn(1, 4800, 2, device="mps")
1642+
x_cpu = x_mps.detach().clone().cpu()
1643+
1644+
r_mps = x_mps + 2
1645+
r_cpu = x_cpu + 2
1646+
1647+
self.assertEqual(r_mps, r_cpu)
1648+
16351649
def test_view_slice(self):
16361650
# https://github.com/pytorch/pytorch/issues/83995
16371651
NUM_SAMPLES = 60

0 commit comments

Comments
 (0)