diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 69983ed4ac84c..3c2ab0d6c2f8b 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -113,30 +113,37 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
     src = src_;
   }
   id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
-  const size_t src_size = src.nbytes();
-  // if there's anything wrong with source, we shouldn't return dst_ silently and must error out.
-  TORCH_CHECK(sourceBuffer && src_size > 0);
+  size_t src_total_size = src_.is_view() ? at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()) :
+                          src.nbytes();
+  size_t size_to_copy = src.nbytes();
 
   // In case of dtype change, first convert src inplace
   if (src_.dtype() != dst_.dtype()) {
     copy_cast_mps(dst, src, sourceBuffer, sourceBuffer);
+    // Use the element size of dst to calculate the total size after casting
+    size_to_copy = (size_to_copy / src.element_size()) * dst.element_size();
   }
+  // If there's anything wrong with source, we shouldn't return dst_ silently and must error out.
+  TORCH_INTERNAL_ASSERT(sourceBuffer && size_to_copy > 0);
+  TORCH_INTERNAL_ASSERT(src_total_size >= storage_byte_offset);
+  TORCH_INTERNAL_ASSERT(dst.nbytes() >= (dst.storage_offset() * dst.element_size()));
+
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
 
     void* host_dst = dst.storage().data();
-    void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_total_size, &alignedLength);
     id<MTLBuffer> destBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                                          length:alignedLength
                                                         options:options
                                                     deallocator:nil];
 
     NSUInteger destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr);
     // 4 bytes alignment required on macos for blits.
-    TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request");
+    TORCH_INTERNAL_ASSERT(destOffset % 4 == 0, "Unaligned blit request");
 
-    stream->copy_and_sync(sourceBuffer, destBuffer, src_size, storage_byte_offset, destOffset, non_blocking);
+    stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking);
     [destBuffer release];
   }
   if (!dst.is_same(dst_)) {
@@ -155,26 +162,33 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
   id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
+  uint64_t src_total_size = 0;
 
   if (src_.is_view()) {
     src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
+    // Get the actual size of a View (takes into account the storage offset)
+    // For View tensors, the storage offset can be bigger than what's being reported by nbytes
+    src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
   } else {
     src = src_;
     if (src.dtype() != dst_.dtype()) {
       // In case of dtype change, perform conversion on source device
       src = src.to(dst_.dtype());
     }
+    src_total_size = src.nbytes();
   }
 
+  const size_t size_to_copy = src.nbytes();
   const void* host_src = src.storage().data();
-  uint64_t size = src.nbytes();
+  TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+  TORCH_INTERNAL_ASSERT(dst_.nbytes() >= dst_byte_offset);
 
   NSUInteger sourceOffset = 0;
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
 
-    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
     id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                                            length:alignedLength
                                                           options:options
@@ -183,7 +197,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
     if (src_.is_view() || !src_.is_contiguous())
       sourceOffset += src_.storage_offset() * src_.itemsize();
 
-    stream->copy_and_sync(sourceBuffer, destBuffer, size, sourceOffset, dst_byte_offset, non_blocking);
+    stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
     [sourceBuffer release];
   }
 
diff --git a/test/test_mps.py b/test/test_mps.py
index b98e88bc33b82..62db3e7700fbd 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1490,6 +1490,58 @@ def test_stride_of_strides(self) -> None:
         z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
         self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)
 
+    def test_type_casting(self):
+        # https://github.com/pytorch/pytorch/issues/81567
+        def helper(data, to_dtype):
+            a_cpu = torch.tensor(data)
+            a_mps = a_cpu.to(torch.device('mps'))
+
+            res_cpu = a_cpu.type(to_dtype)
+            res_mps = a_mps.type(to_dtype)
+            self.assertEqual(res_cpu, res_mps)
+
+        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)
+
+    def test_to_casting(self):
+        # https://github.com/pytorch/pytorch/issues/81567
+        def helper(data, to_dtype):
+            a_cpu = torch.tensor(data)
+            a_mps = a_cpu.to(torch.device('mps'))
+
+            res_cpu = a_cpu.to(to_dtype)
+            res_mps = a_mps.to(to_dtype)
+            self.assertEqual(res_cpu, res_mps)
+
+        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
+        helper([9.0, 3.0, 5.0, 4.0], torch.float)
+        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
+        helper([9.0, 3.0, 5.0, 4.0], torch.short)
+        helper([9.0, 3.0, 5.0, 4.0], torch.half)
+        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
+        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)
+
+    def test_storage_offset_greater_than_src_nbytes(self):
+        # https://github.com/pytorch/pytorch/issues/80844
+        n_tensors = 100
+        n_tensor_elems = 784
+        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)
+
+        tensor_list = []
+        for i in range(0, n_tensors - 1):
+            # create a list of contiguous view tensors (view tensor created by the slice op)
+            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
+            tensor_list.append(t)
+
+        for i in range(0, n_tensors - 1):
+            t = tensor_list[i].view(1, 784)
+            t_mps = t.to("mps")
+            self.assertEqual(t, t_mps.cpu())
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):