From 0e92d1cfe962ceb2501b196cc284d451d51ca6a0 Mon Sep 17 00:00:00 2001
From: Denis Vieriu
Date: Fri, 20 May 2022 19:11:30 -0700
Subject: [PATCH 1/3] Fix mps tensor comparison when storage_offset is
 different from 0

---
 aten/src/ATen/native/mps/OperationUtils.mm   | 22 +++++++++++++++++++
 .../ATen/native/mps/operations/BinaryOps.mm  | 21 ++++++++----------
 aten/src/ATen/native/mps/operations/Copy.mm  |  5 -----
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 81d9908d09f49..4d363463c1c4d 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -323,6 +323,7 @@ void printTensorNDArray(const Tensor& t) {
   TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
+  size_t srcSize = [srcBuf length];
   if (check_view && !src.is_contiguous()) {
     id<MTLBuffer> gatherTensor = gatherViewTensor(src, srcBuf);
     if (gatherTensor) {
@@ -332,6 +333,27 @@ void printTensorNDArray(const Tensor& t) {
       srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
     }
   }
+  else if (srcSize && src.storage_offset() && src.is_contiguous()) {
+    id<MTLDevice> device = MPSDevice::getInstance()->device();
+    MPSStream* mpsStream = getCurrentMPSStream();
+    id<MTLCommandQueue> commandQueue = mpsStream->commandQueue();
+    id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
+    id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
+    id<MTLBuffer> dstBuf = [[device newBufferWithLength: srcSize
+                                                options: srcBuf.resourceOptions] autorelease];
+
+    [blitEncoder copyFromBuffer: srcBuf
+                   sourceOffset: src.storage_offset() * src.element_size()
+                       toBuffer: dstBuf
+              destinationOffset: 0
+                           size: srcSize - (src.storage_offset() * src.element_size())];
+#if MTL_SUPPORT_MANAGED_STORAGE
+    [blitEncoder synchronizeResource:dstBuf];
+#endif
+    [blitEncoder endEncoding];
+    [commandBuffer commit];
+    srcBuf = dstBuf;
+  }
 
   const size_t buf_size = [srcBuf length];
   // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 9c37d930d2e72..2c81236fea1e0 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -21,22 +21,19 @@ typedef MPSGraphTensor* (^BinaryOpBlock)(MPSGraph*, MPSGraphTensor*, MPSGraphTensor*);
 #define BinaryOpFn() MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
 
-void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
+void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
 {
   // it's possible to receive empty tensors here
-  if (self_t.numel() == 0 || other_t.numel() == 0) {
+  if (self.numel() == 0 || other.numel() == 0) {
     return;
   }
   MPSStream* mpsStream = getCurrentMPSStream();
 
-  const bool is_self_scalar = self_t.dim() == 0;
-  const bool is_other_scalar = other_t.dim() == 0;
+  const bool is_self_scalar = self.dim() == 0;
+  const bool is_other_scalar = other.dim() == 0;
 
-  Tensor self = is_self_scalar ? self_t : self_t.contiguous(at::MemoryFormat::Contiguous);
-  Tensor other = is_other_scalar ? other_t : other_t.contiguous(at::MemoryFormat::Contiguous);
-
-  const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other_t : self_t).scalar_type());
-  const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other_t : self_t).scalar_type());
+  const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other : self).scalar_type());
+  const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other : self).scalar_type());
 
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   @autoreleasepool {
@@ -62,16 +59,16 @@ void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& o
     if (is_self_scalar) {
       feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
     } else {
-      Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self);
+      Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
       feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
     }
     if (is_other_scalar) {
       feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
     } else {
-      Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other);
+      Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
 
     NSDictionary* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index e1e2a9e24fc04..313e75cd5f774 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -287,11 +287,6 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
   } else {
     dst = dst_;
   }
-  dst._set_conj(dst_.is_conj());
-  src._set_conj(src_.is_conj());
-
-  dst._set_neg(dst_.is_neg());
-  src._set_neg(src_.is_neg());
 
   auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
   id<MTLBuffer> sourceBuffer = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());

From c8306dd9318697d63a12e9adfce4640722d20a33 Mon Sep 17 00:00:00 2001
From: Denis Vieriu
Date: Tue, 24 May 2022 14:40:55 -0700
Subject: [PATCH 2/3] Modify as_strided_tensorimpl_mps and gatherViewTensor to
 work with contiguous tensors; take storage_offset for views into account for
 cpu->mps copies

---
 aten/src/ATen/native/mps/OperationUtils.h    |  13 ++
 aten/src/ATen/native/mps/OperationUtils.mm   | 139 +++++++++---------
 .../ATen/native/mps/operations/BinaryOps.mm  |   7 +-
 aten/src/ATen/native/mps/operations/Copy.mm  |   4 +-
 test/test_mps.py                             |  62 ++++++++
 5 files changed, 155 insertions(+), 70 deletions(-)

diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 7db4eb1318b04..0a158b56bb39b 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -74,9 +74,22 @@ class Placeholder {
     return _value == nullptr;
   }
 
+  void allocateViewTensor(const at::Tensor& src)
+  {
+    assert (src.is_view() && !_viewOutput.numel());
+    _viewOutput = at::native::empty_mps(
+                    src.sizes(),
+                    src.scalar_type(),
+                    c10::nullopt,
+                    kMPS,
+                    c10::nullopt,
+                    c10::nullopt);
+  }
+
 private:
   MPSGraphTensor* _placeholder;
   MPSGraphTensorData* _value;
+  Tensor _viewOutput;
 };
 
 void resize_tensor(Tensor* output);
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 4d363463c1c4d..d11063cc175a8 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -265,54 +265,81 @@ void printTensorNDArray(const Tensor& t) {
   [tdata printNDArray];
 }
 
-id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
-  assert (!src.is_contiguous());
+struct CachedGraph : public MPSCachedGraph
+{
+  CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor* inputTensor_ = nil;
+  MPSGraphTensor* outputTensor_ = nil;
+  IntArrayRef size_;
+  IntArrayRef stride_;
+  int64_t storage_offset_;
+};
+
+CachedGraph* _getCachedGraph(const at::Tensor& src) {
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
+  CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+
+  return cachedGraph;
+}
+
+id<MTLBuffer> _gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer, CachedGraph* cachedGraph, Tensor& output) {
+  assert (src.is_view());
+
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* stream = getCurrentMPSStream();
   @autoreleasepool {
-    struct CachedGraph : public MPSCachedGraph
-    {
-      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-      MPSGraphTensor* inputTensor_ = nil;
-      MPSGraphTensor* outputTensor_ = nil;
-      IntArrayRef size_;
-      IntArrayRef stride_;
-      int64_t storage_offset_;
-    };
-    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-    string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
-    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
-    if (cachedGraph) {
-      @autoreleasepool {
-        MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
-        auto output = at::native::empty_mps(
-                        src.sizes(),
-                        src.scalar_type(),
-                        c10::nullopt,
-                        kMPS,
-                        c10::nullopt,
-                        c10::nullopt);
-        MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
-                                                                                       shape: [inputTensor shape]
-                                                                                    dataType: [inputTensor dataType]] autorelease];
-        id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
-        MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
-                                                                                        shape: getMPSShape(src.sizes())
-                                                                                     dataType: getMPSDataType(src.scalar_type())] autorelease];
-        NSDictionary* feeds = @{
-          inputTensor : inputTensorData
-        };
-
-        NSDictionary* results = @{
-          cachedGraph->outputTensor_ : outputTensorData
-        };
-
-        runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-        return resultBuffer;
-      }
-    }
+    MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
+    MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
+                                                                                   shape: [inputTensor shape]
+                                                                                dataType: [inputTensor dataType]] autorelease];
+    id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
+    MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
+                                                                                    shape: getMPSShape(src.sizes())
+                                                                                 dataType: getMPSDataType(src.scalar_type())] autorelease];
+    NSDictionary* feeds = @{
+      inputTensor : inputTensorData
+    };
+
+    NSDictionary* results = @{
+      cachedGraph->outputTensor_ : outputTensorData
+    };
+
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    return resultBuffer;
+  }
+}
+
+id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
+  assert (src.is_view());
+
+  CachedGraph* cachedGraph = _getCachedGraph(src);
+  if (cachedGraph) {
+
+    Tensor output = at::native::empty_mps(
+                      src.sizes(),
+                      src.scalar_type(),
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+
+    _gatherViewTensor(src, sourceBuffer, cachedGraph, output);
+    return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
+  }
+
+  return nil;
+}
+
+id<MTLBuffer> gatherViewTensorWithAllocatedMem(const at::Tensor& src, id<MTLBuffer> sourceBuffer, Tensor& output) {
+  assert (src.is_view());
+
+  CachedGraph* cachedGraph = _getCachedGraph(src);
+  if (cachedGraph) {
+    _gatherViewTensor(src, sourceBuffer, cachedGraph, output);
+    return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
+  }
+
+  return nil;
+}
 
 Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src,
@@ -323,9 +350,9 @@ void printTensorNDArray(const Tensor& t) {
   TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
-  size_t srcSize = [srcBuf length];
-  if (check_view && !src.is_contiguous()) {
-    id<MTLBuffer> gatherTensor = gatherViewTensor(src, srcBuf);
+  if (check_view && src.is_view()) {
+    allocateViewTensor(src);
+    id<MTLBuffer> gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput);
     if (gatherTensor) {
       srcBuf = gatherTensor;
     } else {
@@ -333,27 +360,7 @@ void printTensorNDArray(const Tensor& t) {
       srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
     }
   }
-  else if (srcSize && src.storage_offset() && src.is_contiguous()) {
-    id<MTLDevice> device = MPSDevice::getInstance()->device();
-    MPSStream* mpsStream = getCurrentMPSStream();
-    id<MTLCommandQueue> commandQueue = mpsStream->commandQueue();
-    id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
-    id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
-    id<MTLBuffer> dstBuf = [[device newBufferWithLength: srcSize
-                                                options: srcBuf.resourceOptions] autorelease];
-
-    [blitEncoder copyFromBuffer: srcBuf
-                   sourceOffset: src.storage_offset() * src.element_size()
-                       toBuffer: dstBuf
-              destinationOffset: 0
-                           size: srcSize - (src.storage_offset() * src.element_size())];
-#if MTL_SUPPORT_MANAGED_STORAGE
-    [blitEncoder synchronizeResource:dstBuf];
-#endif
-    [blitEncoder endEncoding];
-    [commandBuffer commit];
-    srcBuf = dstBuf;
-  }
+
 
   const size_t buf_size = [srcBuf length];
   // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 2c81236fea1e0..3ec086fca22ce 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -55,17 +55,18 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& outpu
       cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
 
-    NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
+    NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
+    Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
+
     if (is_self_scalar) {
       feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
     } else {
-      Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
      feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
     }
     if (is_other_scalar) {
       feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
     } else {
-      Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
     Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 313e75cd5f774..66daaab66b785 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -114,7 +114,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
   // 0 sizes won't result in any change in the shape of the Tensor so we can
   // skip it. Also if the memory is contiguous we don't need to do
   // gather-scatter operations using graph.
-  if (size.size() > 0 && (!result.is_contiguous())) {
+  if (size.size() > 0) {
 
     // If self itself was a view tensor, that means we need to chain the graphs
     // else we will create a new entry in the cache
@@ -394,6 +394,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
                                    options:options
                                deallocator:nil];
     sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
+    if (src_.is_view())
+      sourceOffset += src_.storage_offset() * src_.itemsize();
 
     dispatch_sync(stream->queue(), ^() {
       @autoreleasepool {
diff --git a/test/test_mps.py b/test/test_mps.py
index 8215642929ead..c85747adce181 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1222,6 +1222,68 @@ def test_slice(self):
         mps_slice4 = mps_x[1, :].to('cpu')
         self.assertEqual(cpu_slice4, mps_slice4)
 
+    def test_slice_contiguous_view(self):
+        # https://github.com/pytorch/pytorch/issues/77750
+
+        def helper(operator):
+            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
+            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")
+
+            # contiguous view
+            x_mps = t_mps[2:]  # 3, 4
+            y_mps = t_mps[:2]  # 1, 2
+
+            x_cpu = t_cpu[2:]
+            y_cpu = t_cpu[:2]
+
+            res_mps = res_cpu = None
+            if operator == "<=":
+                res_mps = x_mps <= y_mps
+                res_cpu = x_cpu <= y_cpu
+            if operator == "<":
+                res_mps = x_mps < y_mps
+                res_cpu = x_cpu < y_cpu
+            if operator == ">=":
+                res_mps = x_mps >= y_mps
+                res_cpu = x_cpu >= y_cpu
+            if operator == ">":
+                res_mps = x_mps > y_mps
+                res_cpu = x_cpu > y_cpu
+            if operator == "==":
+                res_mps = x_mps == y_mps
+                res_cpu = x_cpu == y_cpu
+            if operator == "!=":
+                res_mps = x_mps != y_mps
+                res_cpu = x_cpu != y_cpu
+
+            self.assertEqual(res_mps, res_cpu)
+
+        for op in ["<=", "<", ">=", ">", "==", "!="]:
+            helper(op)
+
+    def test_index_storage_offset(self):
+        # https://github.com/pytorch/pytorch/issues/78107
+
+        a = torch.tensor([8.2670e-01, -1.0293e+00])
+        b_cpu = a[0]
+        c_cpu = a[1]
+
+        # both 'b' and 'c' are views of 'a'
+        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
+        # when copying from 'cpu' to 'mps', 'c' will have a storage_offset of 1 which needs to be taken
+        # into account, otherwise it ends up with the same value as 'b'
+        b = b_cpu.to('mps')
+        c = c_cpu.to('mps')
+
+        res_mps = b > c
+        res_cpu = b_cpu > c_cpu
+        self.assertEqual(res_mps, res_cpu)
+
+
+        res_mps = c > b
+        res_cpu = c_cpu > b_cpu
+        self.assertEqual(res_mps, res_cpu)
+
     def test_flatten(self):
         values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
         cpu_x = torch.tensor(values, device='cpu')

From f9523607827062ee23234bb04b2239c18f7b11a4 Mon Sep 17 00:00:00 2001
From: Denis Vieriu
Date: Wed, 25 May 2022 11:24:51 -0700
Subject: [PATCH 3/3] Address PR comments; remove is_view() checks; if a
 cached graph is found for a gather op, assume the respective Tensor is a
 view

---
 aten/src/ATen/native/mps/OperationUtils.h    |  2 +-
 aten/src/ATen/native/mps/OperationUtils.mm   | 66 +++++++++----------
 .../ATen/native/mps/operations/BinaryOps.mm  |  6 +-
 aten/src/ATen/native/mps/operations/Copy.mm  |  2 +-
 test/test_mps.py                             |  3 +-
 5 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 0a158b56bb39b..25a485f5d0574 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -76,7 +76,7 @@ class Placeholder {
 
   void allocateViewTensor(const at::Tensor& src)
   {
-    assert (src.is_view() && !_viewOutput.numel());
+    assert (!_viewOutput.numel());
     _viewOutput = at::native::empty_mps(
                     src.sizes(),
                     src.scalar_type(),
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index d11063cc175a8..d70608b0f9fc1 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -265,29 +265,32 @@ void printTensorNDArray(const Tensor& t) {
   [tdata printNDArray];
 }
 
-struct CachedGraph : public MPSCachedGraph
-{
-  CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-  MPSGraphTensor* inputTensor_ = nil;
-  MPSGraphTensor* outputTensor_ = nil;
-  IntArrayRef size_;
-  IntArrayRef stride_;
-  int64_t storage_offset_;
-};
-
-CachedGraph* _getCachedGraph(const at::Tensor& src) {
+MPSCachedGraph* _getCachedGraph(const at::Tensor& src) {
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
-  CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+  MPSCachedGraph* cachedGraph = cache_->LookUp(key);
 
   return cachedGraph;
 }
 
-id<MTLBuffer> _gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer, CachedGraph* cachedGraph, Tensor& output) {
-  assert (src.is_view());
+id<MTLBuffer> _gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer, MPSCachedGraph* mpsCachedGraph, Tensor& output) {
+  TORCH_CHECK(mpsCachedGraph != nil);
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* stream = getCurrentMPSStream();
+
+  struct CachedGraph : public MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* inputTensor_ = nil;
+    MPSGraphTensor* outputTensor_ = nil;
+    IntArrayRef size_;
+    IntArrayRef stride_;
+    int64_t storage_offset_;
+  };
+
+  CachedGraph* cachedGraph = static_cast<CachedGraph*>(mpsCachedGraph);
+
   @autoreleasepool {
     MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
     MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
@@ -311,11 +314,8 @@ void printTensorNDArray(const Tensor& t) {
 }
 
 id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
-  assert (src.is_view());
-
-  CachedGraph* cachedGraph = _getCachedGraph(src);
-  if (cachedGraph) {
-
+  MPSCachedGraph* mpsCachedGraph = _getCachedGraph(src);
+  if (mpsCachedGraph) {
     Tensor output = at::native::empty_mps(
                       src.sizes(),
                       src.scalar_type(),
@@ -324,23 +324,18 @@ void printTensorNDArray(const Tensor& t) {
                       c10::nullopt,
                       c10::nullopt);
 
-    _gatherViewTensor(src, sourceBuffer, cachedGraph, output);
+    _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output);
     return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
   }
 
   return nil;
 }
 
-id<MTLBuffer> gatherViewTensorWithAllocatedMem(const at::Tensor& src, id<MTLBuffer> sourceBuffer, Tensor& output) {
-  assert (src.is_view());
+id<MTLBuffer> gatherViewTensorWithAllocatedMem(const at::Tensor& src, id<MTLBuffer> sourceBuffer, Tensor& output, MPSCachedGraph* mpsCachedGraph) {
+  TORCH_CHECK(mpsCachedGraph != nil);
 
-  CachedGraph* cachedGraph = _getCachedGraph(src);
-  if (cachedGraph) {
-    _gatherViewTensor(src, sourceBuffer, cachedGraph, output);
-    return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
-  }
-
-  return nil;
+  _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output);
+  return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
 }
 
 Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src,
@@ -350,11 +345,14 @@ void printTensorNDArray(const Tensor& t) {
   TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
-  if (check_view && src.is_view()) {
-    allocateViewTensor(src);
-    id<MTLBuffer> gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput);
-    if (gatherTensor) {
-      srcBuf = gatherTensor;
+  if (check_view) {
+    MPSCachedGraph* cachedGraph = _getCachedGraph(src);
+    if (cachedGraph) {
+      allocateViewTensor(src);
+      id<MTLBuffer> gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput, cachedGraph);
+      if (gatherTensor) {
+        srcBuf = gatherTensor;
+      }
     } else {
       src_ = src.contiguous();
       srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
     }
   }
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 3ec086fca22ce..9471bd180775a 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -56,17 +56,19 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& outpu
     }
 
     NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
-    Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
+    Placeholder selfPlaceholder;
+    Placeholder otherPlaceholder;
 
     if (is_self_scalar) {
       feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
     } else {
+      selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
       feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
     }
     if (is_other_scalar) {
       feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
     } else {
+      otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
     Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 66daaab66b785..cce7f4d51f614 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -394,7 +394,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
                                    options:options
                                deallocator:nil];
     sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
-    if (src_.is_view())
+    if (src_.is_view() || !src_.is_contiguous())
       sourceOffset += src_.storage_offset() * src_.itemsize();
 
     dispatch_sync(stream->queue(), ^() {
diff --git a/test/test_mps.py b/test/test_mps.py
index c85747adce181..3c1adbdc5a8e2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -315,10 +315,9 @@ def helper(input_shape, batch1_shape, batch2_shape):
             output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
             output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
 
-            print(output_cpu.shape)
-            print(output_mps.shape)
             self.assertEqual(output_cpu, output_mps)
             self.assertEqual(output_cpu.size(), output_mps.size())
+
         helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
         helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
         helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
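
Note: for a quick standalone check of the behavior this series fixes, the snippet below is a minimal sketch distilled from the new test_index_storage_offset test in patch 2 (it assumes a PyTorch build with the MPS backend available; it uses no API beyond what the patches themselves exercise):

    import torch

    # 'b_cpu' and 'c_cpu' are 0-dim views of 'a'; 'b_cpu' has storage_offset 0,
    # 'c_cpu' has storage_offset 1.
    a = torch.tensor([8.2670e-01, -1.0293e+00])
    b_cpu = a[0]
    c_cpu = a[1]

    # Before these patches, the cpu->mps copy ignored the view's storage_offset,
    # so 'c' arrived on the device holding the same value as 'b'.
    b = b_cpu.to("mps")
    c = c_cpu.to("mps")

    # With the fix, device-side comparisons agree with the CPU reference.
    assert (b > c).item() == (b_cpu > c_cpu).item()
    assert (c > b).item() == (c_cpu > b_cpu).item()

The same storage_offset handling is what makes the contiguous-slice comparisons in test_slice_contiguous_view (issue #77750) produce CPU-matching results.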