Skip to content

Commit

Permalink
Use 'computeStorageNbytesContiguous' to get the size just when src is a view
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 authored and Denis Vieriu committed Jul 15, 2022
1 parent 72e217e commit e3d5873
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,23 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
uint64_t size = 0;

if (src_.is_view()) {
src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
// Get the actual size of a View (takes into account the storage offset)
// For View tensors, the storage offset can be bigger than what's being reported by nbytes
size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
} else {
src = src_;
if (src.dtype() != dst_.dtype()) {
// In case of dtype change, perform conversion on source device
src = src.to(dst_.dtype());
}
size = src.nbytes();
}

const void* host_src = src.storage().data();
// Get the actual size of a View (takes into account the storage offset)
// For View tensors, the storage offset can be bigger than what's being reported by nbytes
uint64_t size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());

NSUInteger sourceOffset = 0;
@autoreleasepool {
Expand Down

0 comments on commit e3d5873

Please sign in to comment.