Skip to content

Commit 0d52d2a

Browse files
committed
fall back to cudaMallocManaged for optimizer states if we're out of memory
1 parent 4c84bc7 commit 0d52d2a

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

llmc/cuda_common.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ constexpr std::bool_constant<true> False;
4949
// Error checking
5050

5151
// CUDA error checking
52-
inline void cudaCheck(cudaError_t error, const char *file, int line) {
52+
inline void cudaCheck_(cudaError_t error, const char *file, int line) {
5353
if (error != cudaSuccess) {
5454
printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
5555
exit(EXIT_FAILURE);
5656
}
5757
};
58-
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
58+
#define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))
5959

6060
// like cudaFree, but checks for errors _and_ resets the pointer.
6161
template<class T>

llmc/cuda_utils.cuh

+23
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,29 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud
205205
cudaCheck(cudaGetLastError());
206206
}
207207

208+
// ----------------------------------------------------------------------------
209+
// memory management
210+
211+
// allocate memory, preferrably on the
212+
void cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
213+
size_t free, total;
214+
cudaCheck(cudaMemGetInfo(&free, &total));
215+
// check if we have enough space to pin the memory to device (with 1% slack)
216+
if(100 * free < 99 * bytes) {
217+
cudaCheck_(cudaMalloc((void**)out, bytes), file, line);
218+
} else {
219+
// if not, fallback to a managed allocation. It will be slower, but at least
220+
// it won't crash.
221+
fprintf(stderr, "[WARN] Not enough space to allocate %zu bytes on device.\n"
222+
" Falling back to managed allocation.\n Speed may be negatively affected.",
223+
bytes);
224+
cudaCheck_(cudaMallocManaged((void**)out, bytes), file, line);
225+
}
226+
}
227+
228+
#define cudaMallocConditionallyManaged(out, bytes)\
229+
(cudaMallocConditionallyManaged((void**)out, bytes, __FILE__, __LINE__))
230+
208231
// ----------------------------------------------------------------------------
209232
// Random Number Generation used in Stochastic Rounding
210233

train_gpt2.cu

+3-3
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,13 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) {
393393
printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
394394
assert(model->m_memory == nullptr);
395395
assert(model->v_memory == nullptr);
396-
cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float)));
397-
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
396+
cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));
397+
cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));
398398

399399
if (model->use_master_weights == 1) {
400400
assert(model->master_weights == nullptr);
401401
printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
402-
cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float)));
402+
cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));
403403
}
404404

405405
size_t free, total;

0 commit comments

Comments
 (0)