Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions unified-runtime/source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,41 @@ urUSMPoolTrimToExp(ur_context_handle_t hContext, ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMContextMemcpyExp(ur_context_handle_t,
void *pDst,
const void *pSrc,
size_t Size) {
// Copies Size bytes from pSrc to pDst with cuMemcpy.
//
// The CUDA memory type of both pointers is queried first:
//   * If either pointer is not CU_MEMORYTYPE_DEVICE, the copy is issued and
//     the function returns without an explicit synchronization call.
//   * If both pointers are CU_MEMORYTYPE_DEVICE, the device ordinal of pDst is
//     used to select a device from hContext, the copy is issued under a
//     ScopedContext for that device, and cuCtxSynchronize() is called before
//     returning.
//
// Returns UR_RESULT_SUCCESS when the end of either path is reached, or
// UR_RESULT_ERROR_INVALID_CONTEXT when pDst's device ordinal is not a valid
// index into hContext->getDevices().
// NOTE(review): how UR_CHECK_ERROR reports a failing CUDA call (throw vs.
// early return) is defined outside this view — confirm before relying on it.
UR_APIEXPORT ur_result_t UR_APICALL urUSMContextMemcpyExp(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a performance-critical function. Can we just always synchronize at the end? Would make the code simpler.

ur_context_handle_t hContext, void *pDst, const void *pSrc, size_t Size) {

// Classify the destination pointer by CUDA memory type.
CUmemorytype memTypeDst;
UR_CHECK_ERROR(cuPointerGetAttribute(
&memTypeDst, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)pDst));

// Classify the source pointer the same way.
CUmemorytype memTypeSrc;
UR_CHECK_ERROR(cuPointerGetAttribute(
&memTypeSrc, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)pSrc));

// Any copy that is not device-to-device takes this early-return path, with no
// explicit synchronization afterwards.
// NOTE(review): no ScopedContext is created on this path inside this
// function — confirm that a suitable CUDA context is already current here.
if (memTypeDst != CU_MEMORYTYPE_DEVICE ||
memTypeSrc != CU_MEMORYTYPE_DEVICE) {
UR_CHECK_ERROR(cuMemcpy((CUdeviceptr)pDst, (CUdeviceptr)pSrc, Size));
return UR_RESULT_SUCCESS;
}

// For device to device copy, we need to synchronize the host with the memcpy
// operation to ensure the copy is completed.
// For details, see:
// https://docs.nvidia.com/cuda/cuda-driver-api/api-sync-behavior.html.
unsigned int devIdx = 0;
UR_CHECK_ERROR(cuPointerGetAttribute(
&devIdx, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)pDst));
// Bounds check: reject an ordinal that cannot index the context's device list.
if (devIdx >= hContext->getDevices().size()) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}

// Device ordinal obtained using cuPointerGetAttribute corresponds to the
// device index in the context's device list.
// NOTE(review): that mapping is asserted here but not demonstrated in this
// function — verify it for contexts that hold only a subset of the devices.
ur_device_handle_t owningDev = hContext->getDevices()[devIdx];

// Order matters: Active is constructed before the copy, and the
// synchronization call comes after the copy has been issued.
// NOTE(review): ScopedContext presumably makes owningDev's CUDA context
// current for the rest of this scope — confirm against its definition.
ScopedContext Active(owningDev);
UR_CHECK_ERROR(cuMemcpy((CUdeviceptr)pDst, (CUdeviceptr)pSrc, Size));
UR_CHECK_ERROR(cuCtxSynchronize());

return UR_RESULT_SUCCESS;
}
Loading