diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 352a0be3692e4..3e96c99fd408b 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -237,10 +237,10 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
 
   // If dst is contiguous and there is no byte offset, we can save directly the result of
   // gather into dst. This reduces the overhead of doing an additional blit for most cases
-  bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset);
+  bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset && src_.dtype() == dst_.dtype());
   Tensor src;
 
-  if (!src_.is_contiguous()) {
+  if (src_.is_view() || !src_.is_contiguous()) {
     Tensor emptyShell = Tensor();
     src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);
diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm
index 5c9dccf4aa068..bfb2852e0a51a 100644
--- a/aten/src/ATen/native/mps/operations/View.mm
+++ b/aten/src/ATen/native/mps/operations/View.mm
@@ -192,13 +192,13 @@
       NSUInteger targetDimLength = currDimLength;
       NSUInteger currReshapeSize = 1;
       NSUInteger innerStride = srcStride;
-      do {
+
+      while (currReshapeSize != targetDimLength && srcDim >= 0) {
         NSUInteger srcDimLength = [[inputTensor shape][srcDim] integerValue];
         currReshapeSize *= srcDimLength;
         srcStride *= srcDimLength;
-        srcDim--;
-      } while(currReshapeSize != targetDimLength && srcDim >= 0);
+        srcDim--;
+      };
 
       isValidReshape &= (currReshapeSize == targetDimLength && currStride == innerStride);
     }
@@ -516,6 +516,24 @@
   return outputTensor;
 }
 
+static IntArrayRef updateTensorBaseShape(const Tensor& self)
+{
+  IntArrayRef base_shape = get_buffer_shape(self.storage().data());
+  // if there's no base_shape stored in MPSAllocator, then infer it from the tensor's size and store it
+  if (base_shape.size() == 0) {
+    // IntArrayRef wouldn't own the data, so we use a static storage
+    static const int64_t shape_1d = 1;
+    // self.sizes().size() could be zero
+    base_shape = self.sizes().size() ? self.sizes() :
+                  ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
+
+    // base_shape will be retained in MPSAllocator until the buffer gets recycled
+    if (self.storage().data())
+      set_buffer_shape(self.storage().data(), base_shape);
+  }
+  return base_shape;
+}
+
 // There are few cases we need to consider:
 // Here nodes are the Tensors and the edges are the operations performed on the
 // Tensor. As a result of the operation performed we can have result as View
@@ -535,22 +553,11 @@
 // NonView T      NonView T
 static ViewCachedGraph* createViewGraph(const Tensor& self, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter)
 {
-  IntArrayRef base_shape = get_buffer_shape(self.storage().data());
-  if (base_shape.size() == 0) {
-    // IntArrayRef wouldn't own the data, so we use a static storage
-    static const int64_t shape_1d = 1;
-    // self.sizes().size() could be zero
-    base_shape = self.sizes().size() ? self.sizes() :
-                    self.is_view() ? self._base().sizes() : IntArrayRef(&shape_1d, 1);
-
-    // base_shape will be retained in MPSAllocator until buffer gets recycled
-    if (self.storage().data())
-      set_buffer_shape(self.storage().data(), base_shape);
-  }
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  IntArrayRef base_shape = updateTensorBaseShape(self);
 
   @autoreleasepool {
     string key = getStridedKey(self.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
+    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
     ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph*>(cache_->LookUp(key));
 
     if (!cachedGraph) {
@@ -586,26 +593,17 @@
 
 Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
 {
-  ViewCachedGraph* cachedGraph = nullptr;
-
-  const IntArrayRef& base_shape = get_buffer_shape(src.storage().data());
-  if (base_shape.size() > 0) {
-    string key = getStridedKey(src.scalar_type(), base_shape, src.sizes(), src.strides(), src.storage_offset(), /*is_scatter*/ false);
-    cachedGraph = static_cast<ViewCachedGraph*>(MPSGraphCache::getInstance()->LookUp(key));
-  }
-  // there are cases where gatherViewTensor() is called without having as_strided() called beforehand.
-  // this typically may come from copy_mps variants. In such cases, when the base_shape isn't found the
-  // callers would resort to make the tensor contiguous in an alternative code path.
-  if (!cachedGraph) {
+  if (src.sizes().size() == 0) {
     return Tensor();
   }
-
   bool requires_sync = false;
   Tensor output;
   if (!dst.has_storage()) {
     output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS);
     requires_sync = true;
   }
+  ViewCachedGraph* cachedGraph = createViewGraph(src, src.sizes(), src.strides(),
+                                                 src.storage_offset(), /*needsScatter*/ false);
   return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync);
 }
@@ -625,9 +623,11 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayR
   auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
   setStrided(result, size, stride, storage_offset);
 
-  // 0 sizes won't result in any change in the shape of the Tensor so we can skip it.
-  if (size.size() > 0)
-    mps::createViewGraph(self, size, stride, storage_offset, /*needsScatter*/ false);
+  // Creating the view graph is deferred until gatherViewTensor() or scatterViewTensor() is called.
+  // In as_strided, we only update the base shape of the buffer so it can be retrieved later
+  // when the view graph is created/run.
+  IntArrayRef base_shape = mps::updateTensorBaseShape(self);
+  TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
   return result;
 }
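
For reference, here is a minimal usage sketch (not taken from this PR) of the kind of view/copy pattern these code paths handle. It assumes a PyTorch build with MPS available; the tensor shapes and names are illustrative only.

```python
# Sketch of the MPS view/copy paths touched by this diff (illustrative, not a test from the PR).
import torch

if torch.backends.mps.is_available():
    mps = torch.device("mps")

    # A transposed tensor is a non-contiguous view, so copy_() goes through the
    # gatherViewTensor() branch in Copy.mm.
    src = torch.arange(12.0, device=mps).reshape(3, 4).t()  # non-contiguous view
    dst = torch.empty(4, 3, device=mps)                     # contiguous destination
    dst.copy_(src)

    # as_strided() on MPS now only records the buffer's base shape; the view graph
    # is built lazily when the strided result is actually gathered (e.g. by .cpu()).
    strided = torch.as_strided(torch.arange(9.0, device=mps), (2, 2), (3, 1))
    print(dst.cpu())
    print(strided.cpu())
```

Previously the gather path relied on a cached graph created eagerly in as_strided(); with this change, gatherViewTensor() calls createViewGraph() itself on first use, so the copy_mps variants no longer need the make-contiguous fallback when no cached graph exists.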