Copy fixes.
kulinseth committed Nov 3, 2022
1 parent f243162 commit 58747c2
Showing 2 changed files with 7 additions and 29 deletions.
aten/src/ATen/native/mps/operations/BitwiseOps.mm: 5 changes (0 additions & 5 deletions)
@@ -313,11 +313,6 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
                                     getMetalType(self),
                                     getMetalType(self),
                                     "bitwise_not");
-  uint32_t length = output.numel();
-  if (length == 0) {
-    return output_;
-  }
-
   dispatch_sync(stream->queue(), ^(){
     id<MTLCommandBuffer> buffer = stream->commandBuffer();
     id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
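The five deleted lines were the zero-numel early return for this kernel; such a guard matters because encoding a Metal dispatch for an empty tensor is wasted work at best, so it normally runs before any command encoding. A minimal sketch of the pattern, using a hypothetical wrapper name that is not taken from this commit:

  // Hypothetical sketch: an empty-tensor guard at the operator entry point,
  // so no Metal commands are encoded when there is nothing to compute.
  static at::Tensor& bitwise_not_out_mps(const at::Tensor& self, at::Tensor& output) {
    if (output.numel() == 0) {
      return output;  // empty tensor: skip pipeline setup and dispatch entirely
    }
    // ... create the pipeline state and encode the kernel as in the hunk above ...
    return output;
  }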
aten/src/ATen/native/mps/operations/Copy.mm: 31 changes (7 additions & 24 deletions)
@@ -192,27 +192,10 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
 {
   MPSStream* stream = getCurrentMPSStream();
   id<MTLDevice> device = MPSDevice::getInstance()->device();
-  auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
-  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
-  uint64_t src_total_size = 0;
-
-  // This is weird, but sometimes this function can be called
-  // with contiguous destination and non-contiguous source
-  if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
-    src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
-    // Get the actual size of a View (takes into account the storage offset)
-    // For View tensors, the storage offset can be bigger than what's being reported by nbytes
-    src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
-  } else {
-    TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
-    src = src_;
-    if (src.dtype() != dst_.dtype()) {
-      // In case of dtype change, perform conversion on source device
-      src = src.to(dst_.dtype());
-    }
-    src_total_size = src.nbytes();
-  }
-
+  auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
   auto src_byte_offset = src.storage_offset() * src.itemsize();
+  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
+  const size_t size_to_copy = src.nbytes();
   const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;

@@ -223,15 +206,15 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
   NSUInteger alignedLength = 0;
   NSUInteger sourceOffset = 0;

-  void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
+  void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
   sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);

   id<MTLBuffer> sourceBuffer = nil;
   // If the destination is a strided MPS tensor, we cannot perform a blit directly to copy the
   // memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices
-  bool doScatter = (!dst_.is_contiguous() && src.is_contiguous());
+  bool doScatter = (!dst.is_contiguous() && src.is_contiguous());
   if (doScatter) {
-    sourceBuffer = [device newBufferWithBytes:(void*)((uint8_t*)host_src + (src_.storage_offset() * src_.itemsize()))
+    sourceBuffer = [device newBufferWithBytes:(void*)((uint8_t*)host_src + (size_to_copy))
                                        length:size_to_copy
                                       options:options];
   }
@@ -243,7 +226,7 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
   }

   if (doScatter) {
-    scatterViewTensor(src, dst_, sourceBuffer);
+    scatterViewTensor(src, dst, sourceBuffer);
   } else {
     stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
   }
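
A note on the alignment logic retained above: wrapping existing host memory in an MTLBuffer without copying (newBufferWithBytesNoCopy:length:options:deallocator:) requires a page-aligned base pointer and length, which is why host_src is rounded down to a page boundary and the leftover sourceOffset is carried into the blit. A sketch of what a helper like pageAlignedBlockPtr plausibly computes; this body is an illustration, not copied from Copy.mm:

  #include <stdint.h>
  #include <mach/vm_page_size.h>
  #include <Foundation/Foundation.h>

  // Sketch: round ptr down to a page boundary and round the covered length up,
  // so [aligned, aligned + *alignedLength) fully contains [ptr, ptr + size).
  static void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedLength) {
    const uintptr_t address = (uintptr_t)ptr;
    const uintptr_t aligned = address & ~(uintptr_t)(vm_page_size - 1);
    *alignedLength = ((address + size + vm_page_size - 1) & ~(uintptr_t)(vm_page_size - 1)) - aligned;
    return (void*)aligned;
  }

The caller then derives sourceOffset = host_src - alignedPtr, so the blit still reads exactly the original bytes inside the larger aligned block.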
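On the scatter branch: a blit copies a single flat byte range, which is only correct when the destination's memory is contiguous. For a strided destination, each logical element must land at an offset computed from the destination's strides, which is what scatterViewTensor does on the GPU. A conceptual CPU-side sketch of that addressing, as an illustration only and not the MPS implementation:

  #include <cstdint>
  #include <vector>

  // Conceptual sketch: place each element of a contiguous source at the
  // stride-computed offset in a strided destination (strides in elements).
  static void scatterContiguousToStrided(const float* src, float* dstBase,
                                         const std::vector<int64_t>& sizes,
                                         const std::vector<int64_t>& strides) {
    int64_t numel = 1;
    for (int64_t s : sizes) numel *= s;
    for (int64_t i = 0; i < numel; ++i) {
      int64_t offset = 0;
      int64_t rem = i;
      for (int64_t d = (int64_t)sizes.size() - 1; d >= 0; --d) {
        offset += (rem % sizes[d]) * strides[d];
        rem /= sizes[d];
      }
      dstBase[offset] = src[i];
    }
  }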
