[RUNTIME] NDArray CopyFrom/To Bytes always synchronize (apache#6586)
* [RUNTIME] NDArray CopyFrom/To Bytes always synchronize

The previous non-synchronizing behavior can be unsafe for GPU devices.
In particular, requiring an explicit synchronization can
lead to confusing behavior, e.g. asnumpy does not immediately return
the right content for Vulkan.
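
To illustrate the hazard described above, here is a minimal sketch that models a device stream as a queue of deferred tasks. `FakeStream`, `CopyToBytesAsync`, and `CopyToBytesSync` are hypothetical names made up for this illustration; they are not part of the TVM runtime, and a real device API behaves differently in detail.

```cpp
#include <cstddef>
#include <cstring>
#include <functional>
#include <queue>
#include <utility>

// Hypothetical stand-in for a device stream: work is queued and only
// executed when Sync() is called.
class FakeStream {
 public:
  void Enqueue(std::function<void()> task) { pending_.push(std::move(task)); }

  // Run every queued task in order, then return.
  void Sync() {
    while (!pending_.empty()) {
      pending_.front()();
      pending_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> pending_;
};

// Models the old behavior: the copy is only enqueued, so `host` is not
// valid until the caller synchronizes the stream.
void CopyToBytesAsync(FakeStream& stream, const int* device, int* host, std::size_t n) {
  stream.Enqueue([=] { std::memcpy(host, device, n * sizeof(int)); });
}

// Models the new behavior: enqueue the copy, then synchronize before
// returning, so `host` is valid as soon as the call completes.
void CopyToBytesSync(FakeStream& stream, const int* device, int* host, std::size_t n) {
  CopyToBytesAsync(stream, device, host, n);
  stream.Sync();
}
```

In this toy model, reading `host` right after `CopyToBytesAsync` still yields the old contents, which is the kind of surprise the commit removes by always synchronizing.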

Also adds the requirement that the array be contiguous.
Right now we encourage compact arrays since they are easier to optimize.
We can consider adding support for non-contiguous arrays later by introducing a compactify PackedFunc (which might need to be jitted).
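
As a rough illustration of what the contiguity requirement means, the sketch below checks whether a strided view is compact row-major. `TensorView` and `IsCompactRowMajor` are hypothetical names chosen for this example; they only mirror the shape/strides idea of a `DLTensor` and are not the runtime's actual `IsContiguous` implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the shape/strides fields of a DLTensor.
// Strides are counted in elements; an empty vector means "compact row-major".
struct TensorView {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// Returns true when the view occupies one dense row-major block of memory,
// i.e. each stride equals the product of all later dimensions.
bool IsCompactRowMajor(const TensorView& t) {
  if (t.strides.empty()) return true;
  assert(t.strides.size() == t.shape.size());
  int64_t expected = 1;
  // Walk from the innermost dimension outwards.
  for (std::size_t i = t.shape.size(); i-- > 0;) {
    if (t.strides[i] != expected) return false;
    expected *= t.shape[i];
  }
  return true;
}
```

Under these assumptions, a 2x3 array with strides {3, 1} passes the check, while its transposed view with strides {1, 2} does not and would be rejected by a byte-wise copy that expects a single dense block.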
tqchen authored and Tushar Dey committed Oct 15, 2020
1 parent 0262366 commit 24512bd
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 2 additions & 4 deletions include/tvm/runtime/ndarray.h
@@ -80,8 +80,7 @@ class NDArray : public ObjectRef {
* \param data The source bytes to be copied from.
* \param nbytes The size of the buffer in bytes
* Must be equal to the size of the NDArray.
- * \note The copy may happen asynchronously if it involves a GPU context.
- * TVMSynchronize is necessary.
+ * \note The copy always triggers a TVMSynchronize.
*/
TVM_DLL void CopyFromBytes(const void* data, size_t nbytes);
/*!
@@ -97,8 +96,7 @@ class NDArray : public ObjectRef {
* \param data The source bytes to be copied from.
* \param nbytes The size of the data buffer.
* Must be equal to the size of the NDArray.
- * \note The copy may happen asynchronously if it involves a GPU context.
- * TVMSynchronize is necessary.
+ * \note The copy always triggers a TVMSynchronize.
*/
TVM_DLL void CopyToBytes(void* data, size_t nbytes) const;
/*!
6 changes: 6 additions & 0 deletions src/runtime/ndarray.cc
@@ -70,9 +70,12 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) {
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch";
+ CHECK(IsContiguous(*handle)) << "ArrayCopyFromBytes only support contiguous array for now";
DeviceAPI::Get(handle->ctx)
->CopyDataFromTo(data, 0, handle->data, static_cast<size_t>(handle->byte_offset), nbytes,
cpu_ctx, handle->ctx, handle->dtype, nullptr);
+ // Synchronize in case data become unavailable later.
+ DeviceAPI::Get(handle->ctx)->StreamSync(handle->ctx, nullptr);
}

void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) {
@@ -81,9 +84,12 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) {
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch";
+ CHECK(IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now";
DeviceAPI::Get(handle->ctx)
->CopyDataFromTo(handle->data, static_cast<size_t>(handle->byte_offset), data, 0, nbytes,
handle->ctx, cpu_ctx, handle->dtype, nullptr);
+ // Synchronize in case data become unavailable later.
+ DeviceAPI::Get(handle->ctx)->StreamSync(handle->ctx, nullptr);
}

struct NDArray::Internal {