From c7f88081cd45aa1b65036742d7208602abf784ca Mon Sep 17 00:00:00 2001
From: Ramin Azarmehr
Date: Fri, 27 May 2022 18:46:52 -0400
Subject: [PATCH] Fix crashes in view tensors due to buffer size mismatch
 (#78247, #77886)

---
 aten/src/ATen/native/mps/OperationUtils.h     |  3 +--
 aten/src/ATen/native/mps/OperationUtils.mm    | 19 ++++++-------------
 .../ATen/native/mps/operations/BinaryOps.mm   |  6 +++---
 .../native/mps/operations/LinearAlgebra.mm    |  9 +++------
 .../src/ATen/native/mps/operations/LossOps.mm |  1 +
 .../ATen/native/mps/operations/UnaryOps.mm    |  5 ++---
 6 files changed, 16 insertions(+), 27 deletions(-)

diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 25a485f5d0574..7860fcb2de35c 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -62,8 +62,7 @@ class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr) {}
-  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
-              bool check_view = false);
+  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }

diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index d70608b0f9fc1..8f309c96e8884 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -338,14 +338,13 @@ void printTensorNDArray(const Tensor& t) {
   return __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
 }
 
-Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src,
-                         MPSShape *mpsShape, bool check_view)
+Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape)
 {
   Tensor src_ = src;
   TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
-  if (check_view) {
+  if (src.is_view()) {
     MPSCachedGraph* cachedGraph = _getCachedGraph(src);
     if (cachedGraph) {
       allocateViewTensor(src);
@@ -358,24 +357,18 @@ void printTensorNDArray(const Tensor& t) {
       srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
     }
   }
-
-  const size_t buf_size = [srcBuf length];
-
   // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
-  // if buf_size is zero in here, it's not a user error. It could be a missing check for
+  // if buffer size is zero in here, it's not a user error. It could be a missing check for
   // tensor.numel() == 0 in our internal implementations of ops.
-  TORCH_INTERNAL_ASSERT(buf_size > 0, "Placeholder tensor is empty!");
-
-  TORCH_CHECK(src_.storage().nbytes() <= buf_size, "Placeholder buffer size (", buf_size,
-              ") is not large enough to contain the Tensor storage of size ", src_.storage().nbytes());
+  TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
 
   const MPSDataType mpsDataType = src_.dim() == 0 ?
                   getMPSScalarType(src_.scalar_type()) : getMPSDataType(src_.scalar_type());
   if (!mpsShape)
     mpsShape = getMPSShape(src_);
   _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
-                                                    shape:mpsShape
-                                                 dataType:mpsDataType] autorelease];
+                                                   shape:mpsShape
+                                                dataType:mpsDataType] autorelease];
   TORCH_INTERNAL_ASSERT(_value);
   _placeholder = mpsGraphTensor;
 }

diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 9471bd180775a..19729aced6da3 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -62,16 +62,16 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& outpu
     if (is_self_scalar) {
       feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
     } else {
-      selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
+      selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self);
       feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
     }
     if (is_other_scalar) {
       feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
     } else {
-      otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
+      otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other);
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
     NSDictionary* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };

diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index c0d15d8f512cf..22e5fa822b36a 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -338,12 +338,9 @@ void prepare_matrices_for_broadcasting(
       cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self,
-                                              nullptr, true);
-    Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other,
-                                               nullptr, true);
-    Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias,
-                                              nullptr, false);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
+    Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
+    Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
     Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
 
     NSDictionary* feeds = @{

diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 35202fd70a5f4..26570b02ac0d5 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -288,6 +288,7 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target,
     Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed);
 
     NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
+    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
     feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData();
 
     if (weight.defined()) {

diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 528b1643ff6cf..23c9607b2cc79 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -13,9 +13,8 @@
 typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
 
-void unary_op(const Tensor& self_t, const Tensor& output, std::string op_name, UnaryOpBlock unaryBlock)
+void unary_op(const Tensor& self, const Tensor& output, std::string op_name, UnaryOpBlock unaryBlock)
 {
-  Tensor self = self_t.contiguous(at::MemoryFormat::Contiguous);
   if (!output.is_same_size(self)) {
     output.resize_(self.sizes());
   }
 
@@ -26,7 +25,7 @@ void unary_op(const Tensor& self_t, const Tensor& output, std::string op_name, U
   };
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self});
+    string key = op_name + getTensorsStringKey({self}, /*use_scalar_value*/ false);
     CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
     if(!cachedGraph) {
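
Note on the resulting call pattern: with the check_view flag removed, call sites
pass only the graph tensor, the ATen tensor, and optionally an MPSShape; the
constructor itself decides via src.is_view() whether the view needs gathering.
A minimal sketch follows. It mirrors the call sites patched above, but the
cachedGraph fields and the runMPSGraph/getCurrentMPSStream wiring are
illustrative assumptions, not code from this patch:

    // Sketch (assumed boilerplate, not part of the patch): an op feeds its
    // inputs/outputs through Placeholders; view detection is now internal
    // to the Placeholder constructor.
    Placeholder selfPlaceholder   = Placeholder(cachedGraph->selfTensor_, self);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    NSDictionary* feeds = @{
      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
    };
    NSDictionary* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };
    // Stream/graph helpers modeled on the surrounding MPS code.
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);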