Skip to content

Commit

Permalink
CallbackMemoryResource: handle exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 22, 2022
1 parent 7ec6304 commit cc86be3
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 4 deletions.
48 changes: 46 additions & 2 deletions python/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import os
import warnings
from collections import defaultdict

from cpython.exc cimport PyErr_Occurred
from cython.operator cimport dereference as deref
from libc.stdint cimport int8_t, int64_t, uintptr_t
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast
from libcpp.memory cimport make_shared, make_unique, shared_ptr, unique_ptr
from libcpp.pair cimport pair
from libcpp.string cimport string

from cuda.cudart import cudaError_t
Expand All @@ -35,6 +37,42 @@ from rmm._cuda.gpu import (

from rmm._lib.cuda_stream_view cimport cuda_stream_view

# Transparent handle of a C++ exception
ctypedef pair[int, string] CppExcept

cdef CppExcept translate_python_except_to_cpp(err: BaseException):
"""Translate a Python exception into a C++ exception handle
The returned exception handle can then be thrown by `throw_cpp_except()`,
which MUST be done without holding the GIL.
This is useful when C++ calls a Python function and needs to catch or
propagate exceptions.
"""
if isinstance(err, MemoryError):
return CppExcept(0, str.encode(str(err)))
if isinstance(err, BaseException):
return CppExcept(-1, str.encode(str(err)))

# Implementation of `throw_cpp_except()`, which throws a given `CppExcept`.
# This function MUST be called without the GIL otherwise the thrown C++
# exception are tanslated back into a Python exception.
cdef extern from *:
"""
#include <stdexcept>
#include <utility>
void throw_cpp_except(std::pair<int, std::string> res) {
switch(res.first) {
case 0:
throw rmm::out_of_memory(res.second);
default:
throw std::runtime_error(res.second);
}
}
"""
void throw_cpp_except(CppExcept) nogil


# NOTE: Keep extern declarations in .pyx file as much as possible to avoid
# leaking dependencies when importing RMM Cython .pxd files
Expand Down Expand Up @@ -523,8 +561,14 @@ cdef void* _allocate_callback_wrapper(
size_t nbytes,
cuda_stream_view stream,
void* ctx
) with gil:
return <void*><uintptr_t>((<object>ctx)(nbytes))
) nogil:
cdef CppExcept err
with gil:
try:
return <void*><uintptr_t>((<object>ctx)(nbytes))
except BaseException as e:
err = translate_python_except_to_cpp(e)
throw_cpp_except(err)

cdef void _deallocate_callback_wrapper(
void* ptr,
Expand Down
29 changes: 27 additions & 2 deletions python/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,13 +760,38 @@ def deallocate_func(ptr, size):
rmm.mr.CallbackMemoryResource(allocate_func, deallocate_func)
)

dbuf = rmm.DeviceBuffer(size=256)
del dbuf
rmm.DeviceBuffer(size=256)

captured = capsys.readouterr()
assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n"


@pytest.mark.parametrize(
"err_raise,err_catch",
[
(MemoryError, MemoryError),
(RuntimeError, RuntimeError),
(Exception, RuntimeError),
(BaseException, RuntimeError),
],
)
def test_callback_mr_error(err_raise, err_catch):
base_mr = rmm.mr.CudaMemoryResource()

def allocate_func(size):
raise err_raise("My alloc error")

def deallocate_func(ptr, size):
return base_mr.deallocate(ptr, size)

rmm.mr.set_current_device_resource(
rmm.mr.CallbackMemoryResource(allocate_func, deallocate_func)
)

with pytest.raises(err_catch, match="My alloc error"):
rmm.DeviceBuffer(size=256)


@pytest.fixture
def make_reinit_hook():
funcs = []
Expand Down

0 comments on commit cc86be3

Please sign in to comment.