diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 81e4066ee..d3df74dc8 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -222,6 +222,7 @@ void cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, // reset the error before the next API call cudaGetLastError(); cudaCheck_(cudaMallocManaged(out, bytes), file, line); + cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line); } else { cudaCheck_(err, file, line); }