Skip to content

Commit

Permalink
Systematically replace __del__ with weakref.finalize() (#246)
Browse files Browse the repository at this point in the history
* Systematically replace `__del__` with `weakref.finalize()`

* Event._finalize() approach with self._finalizer.Detach()

* Stream._MembersNeededForFinalize() approach.

Corresponding demonstration of finalize behavior (immediate cleanup):

https://github.com/rwgk/stuff/blob/f6fbd670b8376003c7767c96538d8ab0b1f49d96/random_attic/weakref_finalize_toy_example.py

* Buffer._MembersNeededForFinalize() approach.

* Apply _MembersNeededForFinalize pattern to _event.py

* _module.py: simply keep TODO comment only

* Apply _MembersNeededForFinalize pattern to _program.py
  • Loading branch information
rwgk authored Dec 2, 2024
1 parent fd71ced commit c3077da
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 108 deletions.
35 changes: 20 additions & 15 deletions cuda_core/cuda/core/experimental/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import weakref
from dataclasses import dataclass
from typing import Optional

Expand Down Expand Up @@ -50,19 +51,29 @@ class Event:
"""

__slots__ = ("_handle", "_timing_disabled", "_busy_waited")
class _MembersNeededForFinalize:
    """Holder for the state the Event finalizer needs.

    Kept separate from Event so that weakref.finalize() captures only this
    small object, not the Event itself — capturing the Event would keep it
    alive and the finalizer would never fire.
    """

    __slots__ = ("handle",)

    def __init__(self, event_obj, handle):
        # handle: the underlying CUevent; None until Event._init assigns it.
        self.handle = handle
        # Runs close() when event_obj is garbage-collected. close() is
        # idempotent, so an earlier explicit Event.close() is harmless.
        weakref.finalize(event_obj, self.close)

    def close(self):
        # Destroy the CUDA event at most once; subsequent calls are no-ops.
        if self.handle is not None:
            handle_return(cuda.cuEventDestroy(self.handle))
            self.handle = None

__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited")

def __init__(self):
self._handle = None
raise NotImplementedError(
"directly creating an Event object can be ambiguous. Please call call Stream.record()."
)

@staticmethod
def _init(options: Optional[EventOptions] = None):
self = Event.__new__(Event)
# minimal requirements for the destructor
self._handle = None
self._mnff = Event._MembersNeededForFinalize(self, None)

options = check_or_create_options(EventOptions, options, "Event options")
flags = 0x0
Expand All @@ -76,18 +87,12 @@ def _init(options: Optional[EventOptions] = None):
self._busy_waited = True
if options.support_ipc:
raise NotImplementedError("TODO")
self._handle = handle_return(cuda.cuEventCreate(flags))
self._mnff.handle = handle_return(cuda.cuEventCreate(flags))
return self

def __del__(self):
"""Return close(self)"""
self.close()

def close(self):
"""Destroy the event."""
if self._handle:
handle_return(cuda.cuEventDestroy(self._handle))
self._handle = None
self._mnff.close()

@property
def is_timing_disabled(self) -> bool:
Expand All @@ -114,12 +119,12 @@ def sync(self):
has been completed.
"""
handle_return(cuda.cuEventSynchronize(self._handle))
handle_return(cuda.cuEventSynchronize(self._mnff.handle))

@property
def is_done(self) -> bool:
"""Return True if all captured works have been completed, otherwise False."""
(result,) = cuda.cuEventQuery(self._handle)
(result,) = cuda.cuEventQuery(self._mnff.handle)
if result == cuda.CUresult.CUDA_SUCCESS:
return True
elif result == cuda.CUresult.CUDA_ERROR_NOT_READY:
Expand All @@ -130,4 +135,4 @@ def is_done(self) -> bool:
@property
def handle(self) -> int:
"""Return the underlying cudaEvent_t pointer address as Python int."""
return int(self._handle)
return int(self._mnff.handle)
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ cdef class ParamHolder:
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
# we need the address of where the actual buffer address is stored
self.data_addresses[i] = <void*><intptr_t>(arg._ptr.getPtr())
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
# Here's the dilemma: We want to have a fast path to pass in Python
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/experimental/_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def launch(kernel, config, *kernel_args):
drv_cfg = cuda.CUlaunchConfig()
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
drv_cfg.hStream = config.stream._handle
drv_cfg.hStream = config.stream.handle
drv_cfg.sharedMemBytes = config.shmem_size
drv_cfg.numAttrs = 0 # TODO
handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
Expand Down
69 changes: 38 additions & 31 deletions cuda_core/cuda/core/experimental/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import abc
import weakref
from typing import Optional, Tuple, TypeVar

from cuda import cuda
Expand Down Expand Up @@ -41,17 +42,28 @@ class Buffer:
"""

class _MembersNeededForFinalize:
    """Holder for the state the Buffer finalizer needs.

    weakref.finalize() must not reference the Buffer itself (that would keep
    the Buffer alive indefinitely), so only this small holder is captured.
    """

    __slots__ = ("ptr", "size", "mr")

    def __init__(self, buffer_obj, ptr, size, mr):
        # ptr: pointer to the allocation; size: allocation size in bytes;
        # mr: owning MemoryResource (may be None, making close() a no-op).
        self.ptr = ptr
        self.size = size
        self.mr = mr
        weakref.finalize(buffer_obj, self.close)

    def close(self, stream=None):
        # Deallocate at most once. When no stream is given, the deallocation
        # is ordered on the default stream.
        if self.ptr and self.mr is not None:
            if stream is None:
                stream = default_stream()
            self.mr.deallocate(self.ptr, self.size, stream)
            self.ptr = 0
            self.mr = None

# TODO: handle ownership? (_mr could be None)
__slots__ = ("_ptr", "_size", "_mr")
__slots__ = ("__weakref__", "_mnff")

def __init__(self, ptr, size, mr: MemoryResource = None):
self._ptr = ptr
self._size = size
self._mr = mr

def __del__(self):
"""Return close(self)."""
self.close()
self._mnff = Buffer._MembersNeededForFinalize(self, ptr, size, mr)

def close(self, stream=None):
"""Deallocate this buffer asynchronously on the given stream.
Expand All @@ -67,47 +79,42 @@ def close(self, stream=None):
the default stream.
"""
if self._ptr and self._mr is not None:
if stream is None:
stream = default_stream()
self._mr.deallocate(self._ptr, self._size, stream)
self._ptr = 0
self._mr = None
self._mnff.close(stream)

@property
def handle(self):
"""Return the buffer handle object."""
return self._ptr
return self._mnff.ptr

@property
def size(self):
"""Return the memory size of this buffer."""
return self._size
return self._mnff.size

@property
def memory_resource(self) -> MemoryResource:
"""Return the memory resource associated with this buffer."""
return self._mr
return self._mnff.mr

@property
def is_device_accessible(self) -> bool:
"""Return True if this buffer can be accessed by the GPU, otherwise False."""
if self._mr is not None:
return self._mr.is_device_accessible
if self._mnff.mr is not None:
return self._mnff.mr.is_device_accessible
raise NotImplementedError

@property
def is_host_accessible(self) -> bool:
"""Return True if this buffer can be accessed by the CPU, otherwise False."""
if self._mr is not None:
return self._mr.is_host_accessible
if self._mnff.mr is not None:
return self._mnff.mr.is_host_accessible
raise NotImplementedError

@property
def device_id(self) -> int:
"""Return the device ordinal of this buffer."""
if self._mr is not None:
return self._mr.device_id
if self._mnff.mr is not None:
return self._mnff.mr.device_id
raise NotImplementedError

def copy_to(self, dst: Buffer = None, *, stream) -> Buffer:
Expand All @@ -129,12 +136,12 @@ def copy_to(self, dst: Buffer = None, *, stream) -> Buffer:
if stream is None:
raise ValueError("stream must be provided")
if dst is None:
if self._mr is None:
if self._mnff.mr is None:
raise ValueError("a destination buffer must be provided")
dst = self._mr.allocate(self._size, stream)
if dst._size != self._size:
dst = self._mnff.mr.allocate(self._mnff.size, stream)
if dst._mnff.size != self._mnff.size:
raise ValueError("buffer sizes mismatch between src and dst")
handle_return(cuda.cuMemcpyAsync(dst._ptr, self._ptr, self._size, stream._handle))
handle_return(cuda.cuMemcpyAsync(dst._mnff.ptr, self._mnff.ptr, self._mnff.size, stream.handle))
return dst

def copy_from(self, src: Buffer, *, stream):
Expand All @@ -151,9 +158,9 @@ def copy_from(self, src: Buffer, *, stream):
"""
if stream is None:
raise ValueError("stream must be provided")
if src._size != self._size:
if src._mnff.size != self._mnff.size:
raise ValueError("buffer sizes mismatch between src and dst")
handle_return(cuda.cuMemcpyAsync(self._ptr, src._ptr, self._size, stream._handle))
handle_return(cuda.cuMemcpyAsync(self._mnff.ptr, src._mnff.ptr, self._mnff.size, stream.handle))

def __dlpack__(
self,
Expand Down Expand Up @@ -242,13 +249,13 @@ def __init__(self, dev_id):
def allocate(self, size, stream=None) -> Buffer:
if stream is None:
stream = default_stream()
ptr = handle_return(cuda.cuMemAllocFromPoolAsync(size, self._handle, stream._handle))
ptr = handle_return(cuda.cuMemAllocFromPoolAsync(size, self._handle, stream.handle))
return Buffer(ptr, size, self)

def deallocate(self, ptr, size, stream=None):
if stream is None:
stream = default_stream()
handle_return(cuda.cuMemFreeAsync(ptr, stream._handle))
handle_return(cuda.cuMemFreeAsync(ptr, stream.handle))

@property
def is_device_accessible(self) -> bool:
Expand Down
4 changes: 1 addition & 3 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping

def __del__(self):
# TODO: do we want to unload? Probably not..
pass
# TODO: do we want to unload in a finalizer? Probably not..

def get_kernel(self, name):
"""Return the :obj:`Kernel` of a specified name from this object code.
Expand Down
45 changes: 27 additions & 18 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import weakref

from cuda import nvrtc
from cuda.core.experimental._module import ObjectCode
from cuda.core.experimental._utils import handle_return
Expand All @@ -24,12 +26,25 @@ class Program:
"""

__slots__ = ("_handle", "_backend")
class _MembersNeededForFinalize:
    """Holder for the state the Program finalizer needs.

    Separated from Program so weakref.finalize() does not hold a reference
    to the Program instance and thereby prevent its collection.
    """

    __slots__ = ("handle",)

    def __init__(self, program_obj, handle):
        # handle: the nvrtcProgram handle; None until Program.__init__
        # creates the program.
        self.handle = handle
        # close() is idempotent, so finalization after an explicit
        # Program.close() is safe.
        weakref.finalize(program_obj, self.close)

    def close(self):
        # Destroy the NVRTC program at most once; later calls are no-ops.
        if self.handle is not None:
            handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
            self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend")
_supported_code_type = ("c++",)
_supported_target_type = ("ptx", "cubin", "ltoir")

def __init__(self, code, code_type):
self._handle = None
self._mnff = Program._MembersNeededForFinalize(self, None)

if code_type not in self._supported_code_type:
raise NotImplementedError

Expand All @@ -38,20 +53,14 @@ def __init__(self, code, code_type):
raise TypeError
# TODO: support pre-loaded headers & include names
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
self._handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
self._backend = "nvrtc"
else:
raise NotImplementedError

def __del__(self):
"""Return close(self)."""
self.close()

def close(self):
"""Destroy this program."""
if self._handle is not None:
handle_return(nvrtc.nvrtcDestroyProgram(self._handle))
self._handle = None
self._mnff.close()

def compile(self, target_type, options=(), name_expressions=(), logs=None):
"""Compile the program with a specific compilation type.
Expand Down Expand Up @@ -84,29 +93,29 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
if self._backend == "nvrtc":
if name_expressions:
for n in name_expressions:
handle_return(nvrtc.nvrtcAddNameExpression(self._handle, n.encode()), handle=self._handle)
handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
options = list(o.encode() for o in options)
handle_return(nvrtc.nvrtcCompileProgram(self._handle, len(options), options), handle=self._handle)
handle_return(nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), handle=self._mnff.handle)

size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
size = handle_return(size_func(self._handle), handle=self._handle)
size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
data = b" " * size
handle_return(comp_func(self._handle, data), handle=self._handle)
handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)

symbol_mapping = {}
if name_expressions:
for n in name_expressions:
symbol_mapping[n] = handle_return(
nvrtc.nvrtcGetLoweredName(self._handle, n.encode()), handle=self._handle
nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
)

if logs is not None:
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle), handle=self._handle)
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
if logsize > 1:
log = b" " * logsize
handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log), handle=self._handle)
handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
logs.write(log.decode())

# TODO: handle jit_options for ptx?
Expand All @@ -121,4 +130,4 @@ def backend(self):
@property
def handle(self):
"""Return the program handle object."""
return self._handle
return self._mnff.handle
Loading

0 comments on commit c3077da

Please sign in to comment.