diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 7db4eb1318b04..25a485f5d0574 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -74,9 +74,22 @@ class Placeholder {
     return _value == nullptr;
   }
 
+  void allocateViewTensor(const at::Tensor& src)
+  {
+    assert (!_viewOutput.numel());
+    _viewOutput = at::native::empty_mps(
+                    src.sizes(),
+                    src.scalar_type(),
+                    c10::nullopt,
+                    kMPS,
+                    c10::nullopt,
+                    c10::nullopt);
+  }
+
  private:
   MPSGraphTensor* _placeholder;
   MPSGraphTensorData* _value;
+  Tensor _viewOutput;
 };
 
 void resize_tensor(Tensor* output);
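The allocateViewTensor() helper above pins the gathered copy of a view tensor to the Placeholder itself ('_viewOutput'), so the Metal buffer handed to the graph outlives the call that produced it; previously the gathered output was a local tensor inside gatherViewTensor() whose storage could be released before the graph ran. The other point motivating these changes is that a view can be fully contiguous and still start at a non-zero storage offset, so contiguity alone cannot decide whether a gather is needed. A minimal CPU-only illustration of that case (a sketch for exposition, not part of the patch):

    import torch

    base = torch.arange(4)             # storage holds [0, 1, 2, 3]
    view = base[2:]                    # shares storage with 'base'
    assert view.is_contiguous()        # contiguous...
    assert view.storage_offset() == 2  # ...but it does not start at offset 0
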
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 81d9908d09f49..d70608b0f9fc1 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -265,57 +265,79 @@ void printTensorNDArray(const Tensor& t) {
   [tdata printNDArray];
 }
 
-id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
-  assert (!src.is_contiguous());
+MPSCachedGraph* _getCachedGraph(const at::Tensor& src) {
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
+  MPSCachedGraph* cachedGraph = cache_->LookUp(key);
+
+  return cachedGraph;
+}
+
+id<MTLBuffer> _gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer, MPSCachedGraph* mpsCachedGraph, Tensor& output) {
+  TORCH_CHECK(mpsCachedGraph != nil);
+
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* stream = getCurrentMPSStream();
+
+  struct CachedGraph : public MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* inputTensor_ = nil;
+    MPSGraphTensor* outputTensor_ = nil;
+    IntArrayRef size_;
+    IntArrayRef stride_;
+    int64_t storage_offset_;
+  };
+
+  CachedGraph* cachedGraph = static_cast<CachedGraph*>(mpsCachedGraph);
+
   @autoreleasepool {
-    struct CachedGraph : public MPSCachedGraph
-    {
-      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-      MPSGraphTensor* inputTensor_ = nil;
-      MPSGraphTensor* outputTensor_ = nil;
-      IntArrayRef size_;
-      IntArrayRef stride_;
-      int64_t storage_offset_;
+    MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
+    MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
+                                                                                   shape: [inputTensor shape]
+                                                                                dataType: [inputTensor dataType]] autorelease];
+    id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
+    MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
+                                                                                    shape: getMPSShape(src.sizes())
+                                                                                 dataType: getMPSDataType(src.scalar_type())] autorelease];
+    NSDictionary* feeds = @{
+      inputTensor : inputTensorData
     };
-    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-    string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
-    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
-    if (cachedGraph) {
-      @autoreleasepool {
-        MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
-        auto output = at::native::empty_mps(
-                  src.sizes(),
-                  src.scalar_type(),
-                  c10::nullopt,
-                  kMPS,
-                  c10::nullopt,
-                  c10::nullopt);
-        MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
-                                                                                       shape: [inputTensor shape]
-                                                                                    dataType: [inputTensor dataType]] autorelease];
-        id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
-        MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
-                                                                                        shape: getMPSShape(src.sizes())
-                                                                                     dataType: getMPSDataType(src.scalar_type())] autorelease];
-        NSDictionary* feeds = @{
-          inputTensor : inputTensorData
-        };
-
-        NSDictionary* results = @{
-          cachedGraph->outputTensor_ : outputTensorData
-        };
-
-        runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-        return resultBuffer;
-      }
-    }
+    NSDictionary* results = @{
+      cachedGraph->outputTensor_ : outputTensorData
+    };
+
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    return resultBuffer;
+  }
+}
+
+id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
+  MPSCachedGraph* mpsCachedGraph = _getCachedGraph(src);
+  if (mpsCachedGraph) {
+    Tensor output = at::native::empty_mps(
+                  src.sizes(),
+                  src.scalar_type(),
+                  c10::nullopt,
+                  kMPS,
+                  c10::nullopt,
+                  c10::nullopt);
+
+    _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output);
+    return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
   }
+
   return nil;
 }
 
+id<MTLBuffer> gatherViewTensorWithAllocatedMem(const at::Tensor& src, id<MTLBuffer> sourceBuffer, Tensor& output, MPSCachedGraph* mpsCachedGraph) {
+  TORCH_CHECK(mpsCachedGraph != nil);
+
+  _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output);
+  return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
+}
+
 Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src,
                          MPSShape *mpsShape, bool check_view)
 {
@@ -323,15 +345,20 @@ void printTensorNDArray(const Tensor& t) {
   TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
-  if (check_view && !src.is_contiguous()) {
-    id<MTLBuffer> gatherTensor = gatherViewTensor(src, srcBuf);
-    if (gatherTensor) {
-      srcBuf = gatherTensor;
+  if (check_view) {
+    MPSCachedGraph* cachedGraph = _getCachedGraph(src);
+    if (cachedGraph) {
+      allocateViewTensor(src);
+      id<MTLBuffer> gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput, cachedGraph);
+      if (gatherTensor) {
+        srcBuf = gatherTensor;
+      }
     } else {
      src_ = src.contiguous();
      srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
    }
  }
+
  const size_t buf_size = [srcBuf length];
  // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
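With the change above, Placeholder consults the strided-graph cache via _getCachedGraph() for any view tensor instead of keying on !src.is_contiguous(), so contiguous views with a storage offset also take the gather path. Together with the BinaryOps.mm change below (which constructs the input placeholders with check_view = true), this addresses the comparison-op failure from issue #77750 that test_slice_contiguous_view exercises. A sketch of the repro, assuming a machine with an MPS device:

    import torch

    t = torch.tensor([1, 2, 3, 4], device="mps")
    x = t[2:]  # values [3, 4]; contiguous view, storage offset 2
    y = t[:2]  # values [1, 2]; contiguous view, storage offset 0

    # Before this fix, both operands could be read from offset 0 of the
    # shared buffer, so the result disagreed with CPU.
    print(x <= y)  # expected: tensor([False, False], device='mps:0')
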
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 9c37d930d2e72..9471bd180775a 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -21,22 +21,19 @@
 typedef MPSGraphTensor* (^BinaryOpBlock)(MPSGraph*, MPSGraphTensor*, MPSGraphTensor*);
 #define BinaryOpFn() MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
 
-void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
+void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
 {
   // it's possible to receive empty tensors here
-  if (self_t.numel() == 0 || other_t.numel() == 0) {
+  if (self.numel() == 0 || other.numel() == 0) {
     return;
   }
   MPSStream* mpsStream = getCurrentMPSStream();
 
-  const bool is_self_scalar = self_t.dim() == 0;
-  const bool is_other_scalar = other_t.dim() == 0;
+  const bool is_self_scalar = self.dim() == 0;
+  const bool is_other_scalar = other.dim() == 0;
 
-  Tensor self = is_self_scalar ? self_t : self_t.contiguous(at::MemoryFormat::Contiguous);
-  Tensor other = is_other_scalar ? other_t : other_t.contiguous(at::MemoryFormat::Contiguous);
-
-  const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other_t : self_t).scalar_type());
-  const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other_t : self_t).scalar_type());
+  const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other : self).scalar_type());
+  const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other : self).scalar_type());
 
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
@@ -58,20 +55,23 @@ void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& o
       cachedGraph = static_cast<BinaryOpCachedGraph*>(tmpCachedGraph);
     }
 
-    NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
+    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[NSMutableDictionary new] autorelease];
+    Placeholder selfPlaceholder;
+    Placeholder otherPlaceholder;
+
     if (is_self_scalar) {
       feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
     } else {
-      Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self);
+      selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
       feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
     }
     if (is_other_scalar) {
       feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
     } else {
-      Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other);
+      otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
     NSDictionary* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index e1e2a9e24fc04..cce7f4d51f614 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -114,7 +114,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
   // 0 sizes won't result in any change in the shape of the Tensor so we can
   // skip it. Also if the memory is contiguous we don't need to do
   // gather-scatter operations using graph.
-  if (size.size() > 0 && (!result.is_contiguous())) {
+  if (size.size() > 0) {
 
     // If self itself was a view tensor, that means we need to chain the graphs
     // else we will create a new entry in the cache
@@ -287,11 +287,6 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
   } else {
     dst = dst_;
   }
-  dst._set_conj(dst_.is_conj());
-  src._set_conj(src_.is_conj());
-
-  dst._set_neg(dst_.is_neg());
-  src._set_neg(src_.is_neg());
 
   auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
   id<MTLBuffer> sourceBuffer = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
@@ -399,6 +394,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
                                               options:options
                                           deallocator:nil];
     sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
+    if (src_.is_view() || !src_.is_contiguous())
+      sourceOffset += src_.storage_offset() * src_.itemsize();
 
     dispatch_sync(stream->queue(), ^() {
       @autoreleasepool {
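The Copy.mm changes fold the source tensor's storage offset into the blit offset when copying a view (or non-contiguous tensor) from CPU to MPS, so two views over the same CPU storage no longer copy the same element; the as_strided change also stops skipping the gather-scatter graph for contiguous views. This is the scenario from issue #78107 that test_index_storage_offset covers; a sketch, again assuming an MPS device:

    import torch

    a = torch.tensor([0.8267, -1.0293])
    b = a[0].to("mps")  # view with storage_offset() == 0
    c = a[1].to("mps")  # view with storage_offset() == 1, now added to sourceOffset

    print(b)  # tensor(0.8267, device='mps:0')
    print(c)  # tensor(-1.0293, device='mps:0'); previously copied element 0, like 'b'
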
diff --git a/test/test_mps.py b/test/test_mps.py
index 8215642929ead..3c1adbdc5a8e2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -315,10 +315,9 @@ def helper(input_shape, batch1_shape, batch2_shape):
             output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
             output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
 
-            print(output_cpu.shape)
-            print(output_mps.shape)
             self.assertEqual(output_cpu, output_mps)
             self.assertEqual(output_cpu.size(), output_mps.size())
+
         helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
         helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
         helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
@@ -1222,6 +1221,68 @@ def test_slice(self):
         mps_slice4 = mps_x[1, :].to('cpu')
         self.assertEqual(cpu_slice4, mps_slice4)
 
+    def test_slice_contiguous_view(self):
+        # https://github.com/pytorch/pytorch/issues/77750
+
+        def helper(operator):
+            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
+            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")
+
+            # contiguous view
+            x_mps = t_mps[2:]  # 3, 4
+            y_mps = t_mps[:2]  # 1, 2
+
+            x_cpu = t_cpu[2:]
+            y_cpu = t_cpu[:2]
+
+            res_mps = res_cpu = None
+            if operator == "<=":
+                res_mps = x_mps <= y_mps
+                res_cpu = x_cpu <= y_cpu
+            if operator == "<":
+                res_mps = x_mps < y_mps
+                res_cpu = x_cpu < y_cpu
+            if operator == ">=":
+                res_mps = x_mps >= y_mps
+                res_cpu = x_cpu >= y_cpu
+            if operator == ">":
+                res_mps = x_mps > y_mps
+                res_cpu = x_cpu > y_cpu
+            if operator == "==":
+                res_mps = x_mps == y_mps
+                res_cpu = x_cpu == y_cpu
+            if operator == "!=":
+                res_mps = x_mps != y_mps
+                res_cpu = x_cpu != y_cpu
+
+            self.assertEqual(res_mps, res_cpu)
+
+        for op in ["<=", "<", ">=", ">", "==", "!="]:
+            helper(op)
+
+    def test_index_storage_offset(self):
+        # https://github.com/pytorch/pytorch/issues/78107
+
+        a = torch.tensor([8.2670e-01, -1.0293e+00])
+        b_cpu = a[0]
+        c_cpu = a[1]
+
+        # both 'b' and 'c' are views of 'a'
+        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
+        # when copying from 'cpu' to 'mps', 'c' will have a storage offset of 1 which needs to be taken into account,
+        # otherwise it ends up with the same value as 'b'
+        b = b_cpu.to('mps')
+        c = c_cpu.to('mps')
+
+        res_mps = b > c
+        res_cpu = b_cpu > c_cpu
+        self.assertEqual(res_mps, res_cpu)
+
+
+        res_mps = c > b
+        res_cpu = c_cpu > b_cpu
+        self.assertEqual(res_mps, res_cpu)
+
     def test_flatten(self):
         values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
         cpu_x = torch.tensor(values, device='cpu')
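Both regression tests can be run on their own on an MPS-capable machine, for example with 'python test/test_mps.py -k test_slice_contiguous_view' and 'python test/test_mps.py -k test_index_storage_offset' (assuming the standard unittest -k filter; the exact invocation may vary with the local setup).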