@@ -205,6 +205,29 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud
205
205
cudaCheck (cudaGetLastError ());
206
206
}
207
207
208
+ // ----------------------------------------------------------------------------
209
+ // memory management
210
+
211
+ // allocate memory, preferrably on the
212
+ void cudaMallocConditionallyManaged (void ** out, size_t bytes, const char *file, int line) {
213
+ size_t free , total;
214
+ cudaCheck (cudaMemGetInfo (&free , &total));
215
+ // check if we have enough space to pin the memory to device (with 1% slack)
216
+ if (100 * free < 99 * bytes) {
217
+ cudaCheck_ (cudaMalloc ((void **)out, bytes), file, line);
218
+ } else {
219
+ // if not, fallback to a managed allocation. It will be slower, but at least
220
+ // it won't crash.
221
+ fprintf (stderr, " [WARN] Not enough space to allocate %zu bytes on device.\n "
222
+ " Falling back to managed allocation.\n Speed may be negatively affected." ,
223
+ bytes);
224
+ cudaCheck_ (cudaMallocManaged ((void **)out, bytes), file, line);
225
+ }
226
+ }
227
+
228
+ #define cudaMallocConditionallyManaged (out, bytes )\
229
+ (cudaMallocConditionallyManaged((void **)out, bytes, __FILE__, __LINE__))
230
+
208
231
// ----------------------------------------------------------------------------
209
232
// Random Number Generation used in Stochastic Rounding
210
233
0 commit comments