Skip to content

Commit

Permalink
Get the correct size of the view tensor when copying from cpu to mps (#…
Browse files Browse the repository at this point in the history
…64)

* Get the correct size of the view tensor when copying from cpu to mps

* Use 'computeStorageNbytesContiguous' to get the size just when src is a view

* Add asserts and tests to check for storage_offset (fixes pytorch#81567, pytorch#80844)

* Add testcase for pytorch#80844

* Replace assert_allclose with assertEqual

* Replace TORCH_CHECK with TORCH_INTERNAL_ASSERT
  • Loading branch information
DenisVieriu97 authored and kulinseth committed Jul 19, 2022
1 parent cc295ba commit 8d6c1a3
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
32 changes: 23 additions & 9 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,37 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
src = src_;
}
id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
const size_t src_size = src.nbytes();
// if there's anything wrong with source, we shouldn't return dst_ silently and must error out.
TORCH_CHECK(sourceBuffer && src_size > 0);
size_t src_total_size = src_.is_view() ? at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()) :
src.nbytes();
size_t size_to_copy = src.nbytes();

// In case of dtype change, first convert src inplace
if (src_.dtype() != dst_.dtype()) {
copy_cast_mps(dst, src, sourceBuffer, sourceBuffer);
// Use the element size of dst to calculate the total size after casting
size_to_copy = (size_to_copy / src.element_size()) * dst.element_size();
}

// If there's anything wrong with source, we shouldn't return dst_ silently and must error out.
TORCH_INTERNAL_ASSERT(sourceBuffer && size_to_copy > 0);
TORCH_INTERNAL_ASSERT(src_total_size >= storage_byte_offset);
TORCH_INTERNAL_ASSERT(dst.nbytes() >= (dst.storage_offset() * dst.element_size()));

@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;

void* host_dst = dst.storage().data();
void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_size, &alignedLength);
void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_total_size, &alignedLength);
id<MTLBuffer> destBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
NSUInteger destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr);
// 4 bytes alignment required on macos for blits.
TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request");
TORCH_INTERNAL_ASSERT(destOffset % 4 == 0, "Unaligned blit request");

stream->copy_and_sync(sourceBuffer, destBuffer, src_size, storage_byte_offset, destOffset, non_blocking);
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking);
[destBuffer release];
}
if (!dst.is_same(dst_)) {
Expand All @@ -155,26 +162,33 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
uint64_t src_total_size = 0;

if (src_.is_view()) {
src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
// Get the actual size of a View (takes into account the storage offset)
// For View tensors, the storage offset can be bigger than what's being reported by nbytes
src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
} else {
src = src_;
if (src.dtype() != dst_.dtype()) {
// In case of dtype change, perform conversion on source device
src = src.to(dst_.dtype());
}
src_total_size = src.nbytes();
}

const size_t size_to_copy = src.nbytes();
const void* host_src = src.storage().data();
uint64_t size = src.nbytes();
TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
TORCH_INTERNAL_ASSERT(dst_.nbytes() >= dst_byte_offset);

NSUInteger sourceOffset = 0;
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;

void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size, &alignedLength);
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
Expand All @@ -183,7 +197,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
if (src_.is_view() || !src_.is_contiguous())
sourceOffset += src_.storage_offset() * src_.itemsize();

stream->copy_and_sync(sourceBuffer, destBuffer, size, sourceOffset, dst_byte_offset, non_blocking);
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}

Expand Down
52 changes: 52 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,58 @@ def test_stride_of_strides(self) -> None:
z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)

def test_type_casting(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/81567:
    # tensor.type(...) on an MPS tensor must produce the same values as on CPU.
    def compare_cast(values, tensor_type):
        cpu_tensor = torch.tensor(values)
        mps_tensor = cpu_tensor.to(torch.device('mps'))
        self.assertEqual(cpu_tensor.type(tensor_type),
                         mps_tensor.type(tensor_type))

    sample = [9.0, 3.0, 5.0, 4.0]
    for tensor_type in (torch.LongTensor, torch.FloatTensor, torch.IntTensor,
                        torch.ShortTensor, torch.HalfTensor, torch.CharTensor,
                        torch.ByteTensor):
        compare_cast(sample, tensor_type)

def test_to_casting(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/81567:
    # tensor.to(dtype) on an MPS tensor must produce the same values as on CPU.
    def compare_cast(values, dtype):
        cpu_tensor = torch.tensor(values)
        mps_tensor = cpu_tensor.to(torch.device('mps'))
        self.assertEqual(cpu_tensor.to(dtype), mps_tensor.to(dtype))

    sample = [9.0, 3.0, 5.0, 4.0]
    for dtype in (torch.int64, torch.float, torch.int32, torch.short,
                  torch.half, torch.int8, torch.uint8):
        compare_cast(sample, dtype)

def test_storage_offset_greater_than_src_nbytes(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/80844:
    # copying a contiguous view tensor to MPS must account for its storage
    # offset, which can exceed the byte count reported by nbytes().
    #
    # Fix vs. previous version: the loops used range(0, n_tensors - 1),
    # which skipped the final slice — the tensor with the LARGEST storage
    # offset, i.e. exactly the case this regression is about. All slices
    # are now exercised, and the hard-coded 784 in view() is replaced with
    # the named constant.
    n_tensors = 100
    n_tensor_elems = 784
    elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

    # Contiguous view tensors created by the slice op; each successive
    # slice has a larger storage offset into `elems`.
    tensor_list = [
        elems[n_tensor_elems * i: n_tensor_elems * (i + 1)]
        for i in range(n_tensors)
    ]

    for t in tensor_list:
        # view() preserves the underlying storage and its offset.
        t_view = t.view(1, n_tensor_elems)
        t_mps = t_view.to("mps")
        self.assertEqual(t_view, t_mps.cpu())

class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
Expand Down

0 comments on commit 8d6c1a3

Please sign in to comment.