Skip to content

Commit

Permalink
Implemented kDLCPUPinned (cudaMallocHost) (apache#4985)
Browse files Browse the repository at this point in the history
* implement kDLCPUPinned

* Fix line endings

* Fix whitespace for linter

* clean up AllocDataSpace method
  • Loading branch information
jmorrill authored and Trevor Morris committed Apr 16, 2020
1 parent 505039b commit c0eeab1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLCPUPinned: return "cpu_pinned";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
Expand Down
39 changes: 34 additions & 5 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,25 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t nbytes,
size_t alignment,
DLDataType type_hint) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
<< "CUDA space is aligned at 256 bytes";
void *ret;
CUDA_CALL(cudaMalloc(&ret, nbytes));
if (ctx.device_type == kDLCPUPinned) {
CUDA_CALL(cudaMallocHost(&ret, nbytes));
} else {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaMalloc(&ret, nbytes));
}
return ret;
}

/*!
 * \brief Free a memory region previously returned by AllocDataSpace.
 * \param ctx The context the memory was allocated for; a device_type of
 *        kDLCPUPinned denotes page-locked host memory (cudaMallocHost).
 * \param ptr The pointer to free.
 *
 * Fix: the span contained a stale unconditional cudaSetDevice/cudaFree pair
 * ahead of the conditional block, which would free `ptr` twice (and call
 * cudaFree on a cudaMallocHost pointer). Exactly one free path runs now.
 */
void FreeDataSpace(TVMContext ctx, void* ptr) final {
  if (ctx.device_type == kDLCPUPinned) {
    // Pinned host memory came from cudaMallocHost (see AllocDataSpace),
    // so it must be released with cudaFreeHost, not cudaFree.
    CUDA_CALL(cudaFreeHost(ptr));
  } else {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaFree(ptr));
  }
}

void CopyDataFromTo(const void* from,
Expand All @@ -137,6 +145,21 @@ class CUDADeviceAPI final : public DeviceAPI {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;

if (ctx_from.device_type == kDLCPUPinned) {
ctx_from.device_type = kDLCPU;
}

if (ctx_to.device_type == kDLCPUPinned) {
ctx_to.device_type = kDLCPU;
}

// In case there is a copy from host memory to host memory
if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) {
memcpy(to, from, size);
return;
}

if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) {
Expand Down Expand Up @@ -235,5 +258,11 @@ TVM_REGISTER_GLOBAL("device_api.gpu")
*rv = static_cast<void*>(ptr);
});

// Register the CUDA DeviceAPI singleton under "cpu_pinned" as well: pinned
// (page-locked) host allocations go through cudaMallocHost, so the CUDA
// backend services this device type too.
TVM_REGISTER_GLOBAL("device_api.cpu_pinned")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = static_cast<void*>(CUDADeviceAPI::Global().get());
  });

} // namespace runtime
} // namespace tvm
4 changes: 3 additions & 1 deletion src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ void NDArray::CopyFromTo(const DLTensor* from,

CHECK(from->ctx.device_type == to->ctx.device_type
|| from->ctx.device_type == kDLCPU
|| to->ctx.device_type == kDLCPU)
|| to->ctx.device_type == kDLCPU
|| from->ctx.device_type == kDLCPUPinned
|| to->ctx.device_type == kDLCPUPinned)
<< "Can not copy across different ctx types directly";

// Use the context that is *not* a cpu context to get the correct device
Expand Down

0 comments on commit c0eeab1

Please sign in to comment.