merge with main for ruff
ksimpson-work committed Nov 29, 2024
1 parent fd71ced commit 319a372
Showing 2 changed files with 45 additions and 4 deletions.
12 changes: 9 additions & 3 deletions cuda_core/cuda/core/experimental/_device.py
@@ -7,7 +7,7 @@

 from cuda import cuda, cudart
 from cuda.core.experimental._context import Context, ContextOptions
-from cuda.core.experimental._memory import Buffer, MemoryResource, _DefaultAsyncMempool
+from cuda.core.experimental._memory import Buffer, MemoryResource, _AsyncMemoryResource, _DefaultAsyncMempool
 from cuda.core.experimental._stream import Stream, StreamOptions, default_stream
 from cuda.core.experimental._utils import ComputeCapability, CUDAError, handle_return, precondition

@@ -62,15 +62,21 @@ def __new__(cls, device_id=None):
             for dev_id in range(total):
                 dev = super().__new__(cls)
                 dev._id = dev_id
-                dev._mr = _DefaultAsyncMempool(dev_id)
+                # If the device is in TCC mode, or does not support memory pools for some other reason,
+                # use the AsyncMemoryResource which does not use memory pools.
+                if (handle_return(cudart.cudaGetDeviceProperties(dev_id))).memoryPoolsSupported == 0:
+                    dev._mr = _AsyncMemoryResource(dev_id)
+                else:
+                    dev._mr = _DefaultAsyncMempool(dev_id)
+
                 dev._has_inited = False
                 _tls.devices.append(dev)

         return _tls.devices[device_id]

     def _check_context_initialized(self, *args, **kwargs):
         if not self._has_inited:
-            raise CUDAError("the device is not yet initialized, perhaps you forgot to call .set_current() first?")
+            raise CUDAError("the device is not yet initialized, " "perhaps you forgot to call .set_current() first?")

     @property
     def device_id(self) -> int:
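As a side note, the memory-pool check that drives the new branch above can be reproduced on its own with the cuda-python runtime bindings. The sketch below is illustrative only; _supports_memory_pools is a hypothetical helper name, and the explicit error check stands in for cuda.core's handle_return:

from cuda import cudart

def _supports_memory_pools(dev_id: int) -> bool:
    # The cuda-python binding returns an (error, cudaDeviceProp) tuple;
    # memoryPoolsSupported is 0 on TCC-mode devices and other devices
    # without memory-pool support.
    err, prop = cudart.cudaGetDeviceProperties(dev_id)
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaGetDeviceProperties failed: {err}")
    return prop.memoryPoolsSupported != 0

Device.__new__ then assigns _AsyncMemoryResource(dev_id) when this check fails and _DefaultAsyncMempool(dev_id) otherwise, which is the branch added in the hunk above.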
37 changes: 36 additions & 1 deletion cuda_core/cuda/core/experimental/_memory.py
@@ -42,7 +42,11 @@ class Buffer:
"""

# TODO: handle ownership? (_mr could be None)
__slots__ = ("_ptr", "_size", "_mr")
__slots__ = (
"_ptr",
"_size",
"_mr",
)

def __init__(self, ptr, size, mr: MemoryResource = None):
self._ptr = ptr
@@ -286,3 +290,34 @@ def is_host_accessible(self) -> bool:
     @property
     def device_id(self) -> int:
         raise RuntimeError("the pinned memory resource is not bound to any GPU")
+
+
+class _AsyncMemoryResource(MemoryResource):
+    __slots__ = ("_dev_id",)
+
+    def __init__(self, dev_id):
+        self._handle = None
+        self._dev_id = dev_id
+
+    def allocate(self, size, stream=None) -> Buffer:
+        if stream is None:
+            stream = default_stream()
+        ptr = handle_return(cuda.cuMemAllocAsync(size, stream._handle))
+        return Buffer(ptr, size, self)
+
+    def deallocate(self, ptr, size, stream=None):
+        if stream is None:
+            stream = default_stream()
+        handle_return(cuda.cuMemFreeAsync(ptr, stream._handle))
+
+    @property
+    def is_device_accessible(self) -> bool:
+        return True
+
+    @property
+    def is_host_accessible(self) -> bool:
+        return False
+
+    @property
+    def device_id(self) -> int:
+        return self._dev_id
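End to end, the fallback is transparent to callers that go through Device. A rough usage sketch follows, assuming Device.allocate() delegates to the device's dev._mr and Buffer.close() releases the allocation through the owning resource (neither is shown in this diff):

from cuda.core.experimental import Device

dev = Device()               # __new__ now picks _AsyncMemoryResource when
                             # memoryPoolsSupported == 0 (e.g. a TCC-mode GPU)
dev.set_current()

buf = dev.allocate(1 << 20)  # 1 MiB, backed by cuMemAllocAsync
# ... use the buffer ...
buf.close()                  # freed via cuMemFreeAsync through the same resource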
