From e3d58738d89d05e1321026250f9f9ad0085af204 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Fri, 15 Jul 2022 10:32:24 -0700 Subject: [PATCH] Use 'computeStorageNbytesContiguous' to get the size just when src is a view --- aten/src/ATen/native/mps/operations/Copy.mm | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index e2d6ea521a66e..a27d424f39418 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -155,21 +155,23 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, id device = MPSDevice::getInstance()->device(); auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize(); id destBuffer = getMTLBufferStorage(dst_); + uint64_t size = 0; if (src_.is_view()) { src = src_.to(dst_.dtype()).expand_as(dst_).contiguous(); + // Get the actual size of a View (takes into account the storage offset) + // For View tensors, the storage offset can be bigger than what's being reported by nbytes + size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()); } else { src = src_; if (src.dtype() != dst_.dtype()) { // In case of dtype change, perform conversion on source device src = src.to(dst_.dtype()); } + size = src.nbytes(); } const void* host_src = src.storage().data(); - // Get the actual size of a View (takes into account the storage offset) - // For View tensors, the storage offset can be bigger than what's being reported by nbytes - uint64_t size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()); NSUInteger sourceOffset = 0; @autoreleasepool {