diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 00508a11b042..1a5eeebbdb6f 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -207,6 +207,7 @@ inline const char* DeviceName(int type) {
   switch (type) {
     case kDLCPU: return "cpu";
     case kDLGPU: return "gpu";
+    case kDLCPUPinned: return "cpu_pinned";
     case kDLOpenCL: return "opencl";
     case kDLSDAccel: return "sdaccel";
     case kDLAOCL: return "aocl";
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 86b52518354c..d9f03e773bc9 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -112,17 +112,25 @@ class CUDADeviceAPI final : public DeviceAPI {
                        size_t nbytes,
                        size_t alignment,
                        DLDataType type_hint) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
     CHECK_EQ(256 % alignment, 0U)
-        << "CUDA space is aligned at 256 bytes";
+      << "CUDA space is aligned at 256 bytes";
     void *ret;
-    CUDA_CALL(cudaMalloc(&ret, nbytes));
+    if (ctx.device_type == kDLCPUPinned) {
+      CUDA_CALL(cudaMallocHost(&ret, nbytes));
+    } else {
+      CUDA_CALL(cudaSetDevice(ctx.device_id));
+      CUDA_CALL(cudaMalloc(&ret, nbytes));
+    }
     return ret;
   }

   void FreeDataSpace(TVMContext ctx, void* ptr) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
-    CUDA_CALL(cudaFree(ptr));
+    if (ctx.device_type == kDLCPUPinned) {
+      CUDA_CALL(cudaFreeHost(ptr));
+    } else {
+      CUDA_CALL(cudaSetDevice(ctx.device_id));
+      CUDA_CALL(cudaFree(ptr));
+    }
   }

   void CopyDataFromTo(const void* from,
@@ -137,6 +145,21 @@ class CUDADeviceAPI final : public DeviceAPI {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
+
+    if (ctx_from.device_type == kDLCPUPinned) {
+      ctx_from.device_type = kDLCPU;
+    }
+
+    if (ctx_to.device_type == kDLCPUPinned) {
+      ctx_to.device_type = kDLCPU;
+    }
+
+    // In case there is a copy from host mem to host mem
+    if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) {
+      memcpy(to, from, size);
+      return;
+    }
+
     if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
@@ -235,5 +258,11 @@ TVM_REGISTER_GLOBAL("device_api.gpu")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     *rv = static_cast<void*>(ptr);
   });
+TVM_REGISTER_GLOBAL("device_api.cpu_pinned")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    DeviceAPI* ptr = CUDADeviceAPI::Global().get();
+    *rv = static_cast<void*>(ptr);
+  });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index ff2f34ee6e4e..99594ee50f93 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -233,7 +233,9 @@ void NDArray::CopyFromTo(const DLTensor* from,
   CHECK(from->ctx.device_type == to->ctx.device_type
         || from->ctx.device_type == kDLCPU
-        || to->ctx.device_type == kDLCPU)
+        || to->ctx.device_type == kDLCPU
+        || from->ctx.device_type == kDLCPUPinned
+        || to->ctx.device_type == kDLCPUPinned)
     << "Can not copy across different ctx types directly";

   // Use the context that is *not* a cpu context to get the correct device
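
Not part of the patch: a minimal C++ sketch of how the new cpu_pinned device could be exercised through the runtime API once the registration above is in place. The helper function name, shape, and sizes are illustrative only; it assumes the existing NDArray::Empty and NDArray::CopyTo APIs from tvm/runtime/ndarray.h.

// Illustrative only -- not part of the diff above. Assumes the patched
// runtime, where kDLCPUPinned maps to "cpu_pinned" and is served by
// CUDADeviceAPI (cudaMallocHost / cudaFreeHost).
#include <tvm/runtime/ndarray.h>
#include <vector>

void PinnedHostToGpuCopySketch() {
  using tvm::runtime::NDArray;

  DLContext pinned_ctx{kDLCPUPinned, 0};  // page-locked host memory
  DLContext gpu_ctx{kDLGPU, 0};           // device memory
  DLDataType dtype{kDLFloat, 32, 1};
  std::vector<int64_t> shape = {1024};

  // AllocDataSpace sees kDLCPUPinned and allocates this buffer with cudaMallocHost.
  NDArray host = NDArray::Empty(shape, dtype, pinned_ctx);
  // Regular cudaMalloc-backed device buffer.
  NDArray device = NDArray::Empty(shape, dtype, gpu_ctx);

  // CopyFromTo now accepts kDLCPUPinned on either side; inside CUDADeviceAPI the
  // pinned context is treated as kDLCPU for the purpose of the transfer.
  host.CopyTo(device);
}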