Skip to content

Commit 656bba9

Browse files
DenisVieriu97Denis Vieriu
andcommitted
Fix slice followed by reshape (#237)
Co-authored-by: Denis Vieriu <denisvieriu@mac-3045BF.local>
1 parent 50fd99e commit 656bba9

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
109109
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
110110
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
111111
// Output tensor should have been promoted but it remains an int32 tensor
112-
if (outputDataType != common_dtype) {
112+
if (outputDataType != common_dtype ||
113+
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
113114
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
114115
}
115116
}

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,11 +536,26 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
536536
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
537537
int src_ndim_view = src_view_shape.size();
538538
if (src_ndim_base == src_ndim_view) {
539-
for (const auto i : c10::irange(src_ndim_base)) {
539+
for (const auto i: c10::irange(src_ndim_base)) {
540540
if (src_view_shape[i] > src_base_shape[i]) {
541541
return false;
542542
}
543543
}
544+
} else {
545+
// Detect slice followed by reshape cases, e.g (1,4800,2) -> (1,4800)
546+
bool allDimsEqual = true;
547+
auto min_ndim = std::min(src_ndim_base, src_ndim_view);
548+
for (const auto i: c10::irange(min_ndim)) {
549+
if (src_view_shape[i] > src_base_shape[i]) {
550+
return false;
551+
}
552+
else if (src_view_shape[i] != src_base_shape[i]) {
553+
allDimsEqual = false;
554+
}
555+
}
556+
if (allDimsEqual) {
557+
return false;
558+
}
544559
}
545560
return true;
546561
}

test/test_mps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,9 +1617,23 @@ def test_slice_reshape(self):
16171617

16181618
x = x[:,3:].view(2, 3, 4, 1)
16191619
x_cpu = x_cpu[:,3:].view(2, 3, 4, 1)
1620+
self.assertEqual(x, x_cpu)
16201621

1622+
x = x + 2
1623+
x_cpu = x_cpu + 2
16211624
self.assertEqual(x, x_cpu)
16221625

1626+
def test_slice_reshape_contg_view(self):
1627+
import torch
1628+
1629+
x_mps = torch.randn(1, 4800, 2, device="mps")
1630+
x_cpu = x_mps.detach().clone().cpu()
1631+
1632+
r_mps = x_mps + 2
1633+
r_cpu = x_cpu + 2
1634+
1635+
self.assertEqual(r_mps, r_cpu)
1636+
16231637
def test_view_slice(self):
16241638
# https://github.com/pytorch/pytorch/issues/83995
16251639
NUM_SAMPLES = 60

0 commit comments

Comments
 (0)