Skip to content

Commit

Permalink
just try to allocate on device; fallback if that fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ngc92 committed Aug 15, 2024
1 parent 0d52d2a commit c845757
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions llmc/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,20 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud

// allocate memory, preferrably on the
void cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
size_t free, total;
cudaCheck(cudaMemGetInfo(&free, &total));
// check if we have enough space to pin the memory to device (with 1% slack)
if(100 * free < 99 * bytes) {
cudaCheck_(cudaMalloc((void**)out, bytes), file, line);
} else {
// if not, fallback to a managed allocation. It will be slower, but at least
// try to allocate `bytes` on device
cudaError_t err = cudaMalloc(out, bytes);
if(err == cudaErrorMemoryAllocation) {
// if that fails, fallback to a managed allocation. It will be slower, but at least
// it won't crash.
fprintf(stderr, "[WARN] Not enough space to allocate %zu bytes on device.\n"
" Falling back to managed allocation.\n Speed may be negatively affected.",
bytes);
cudaCheck_(cudaMallocManaged((void**)out, bytes), file, line);
fprintf(stderr, "[WARN] Not enough space to allocate %zu MiB on device.\n"
" Falling back to managed allocation.\n"
" Speed may be negatively affected.\n",
bytes / 1024 / 1024);
// reset the error before the next API call
cudaGetLastError();
cudaCheck_(cudaMallocManaged(out, bytes), file, line);
} else {
cudaCheck_(err, file, line);
}
}

Expand Down

0 comments on commit c845757

Please sign in to comment.