Skip to content

Commit 5622608

Browse files
committed
Adding context to the stream hash implementation
1 parent 4227651 commit 5622608

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,20 @@ cdef class Event:
233233
def from_ipc_descriptor(cls, ipc_descriptor: IPCEventDescriptor) -> Event:
234234
"""Import an event that was exported from another process."""
235235
cdef cydriver.CUipcEventHandle data
236+
cdef cydriver.CUcontext curr_ctx
237+
cdef cydriver.CUdevice dev
236238
memcpy(data.reserved, <const void*><const char*>(ipc_descriptor._reserved), sizeof(data.reserved))
237239
cdef Event self = Event.__new__(cls)
238240
with nogil:
239241
HANDLE_RETURN(cydriver.cuIpcOpenEventHandle(&self._handle, data))
242+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&curr_ctx))
243+
HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
240244
self._timing_disabled = True
241245
self._busy_waited = ipc_descriptor._busy_waited
242246
self._ipc_enabled = True
243247
self._ipc_descriptor = ipc_descriptor
244-
self._device_id = -1 # ??
245-
self._ctx_handle = None # ??
248+
self._device_id = <int>dev
249+
self._ctx_handle = driver.CUcontext(<uintptr_t>curr_ctx)
246250
return self
247251

248252
@property

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,23 @@ cdef class Stream:
126126
@classmethod
127127
def _legacy_default(cls):
128128
cdef Stream self = Stream.__new__(cls)
129+
cdef cydriver.CUcontext ctx
129130
self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_LEGACY)
130131
self._builtin = True
132+
with nogil:
133+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
134+
self._ctx_handle = ctx
131135
return self
132136

133137
@classmethod
134138
def _per_thread_default(cls):
135139
cdef Stream self = Stream.__new__(cls)
140+
cdef cydriver.CUcontext ctx
136141
self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_PER_THREAD)
137142
self._builtin = True
143+
with nogil:
144+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
145+
self._ctx_handle = ctx
138146
return self
139147

140148
@classmethod
@@ -144,9 +152,16 @@ cdef class Stream:
144152
if obj is not None and options is not None:
145153
raise ValueError("obj and options cannot be both specified")
146154
if obj is not None:
155+
cdef cydriver.CUcontext ctx
156+
cdef cydriver.CUresult err
147157
self._handle = _try_to_get_stream_ptr(obj)
148158
# TODO: check if obj is created under the current context/device
149159
self._owner = obj
160+
with nogil:
161+
err = cydriver.cuStreamGetCtx(self._handle, &ctx)
162+
if err != cydriver.CUresult.CUDA_SUCCESS:
163+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
164+
self._ctx_handle = ctx
150165
return self
151166

152167
cdef StreamOptions opts = check_or_create_options(StreamOptions, options, "Stream options")
@@ -167,12 +182,15 @@ cdef class Stream:
167182
prio = high
168183

169184
cdef cydriver.CUstream s
185+
cdef cydriver.CUcontext ctx
170186
with nogil:
171187
HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(&s, flags, prio))
188+
HANDLE_RETURN(cydriver.cuStreamGetCtx(s, &ctx))
172189
self._handle = s
173190
self._nonblocking = int(nonblocking)
174191
self._priority = prio
175192
self._device_id = device_id if device_id is not None else self._device_id
193+
self._ctx_handle = ctx
176194
return self
177195

178196
def __dealloc__(self):
@@ -198,7 +216,7 @@ cdef class Stream:
198216
return (0, <uintptr_t>(self._handle))
199217

200218
def __hash__(self) -> int:
201-
"""Return hash based on the underlying CUstream handle address.
219+
"""Return hash based on the underlying CUstream handle address and context.
202220

203221
This enables Stream objects to be used as dictionary keys and in sets.
204222
Two Stream objects wrapping the same underlying CUDA stream will hash
@@ -207,15 +225,24 @@ cdef class Stream:
207225
Returns
208226
-------
209227
int
210-
Hash value based on the stream handle address.
228+
Hash value based on the stream handle address and context handle.
229+
230+
Notes
231+
-----
232+
Includes the context handle in the hash to prevent collisions when
233+
handles are reused across different contexts. While handles are
234+
context-scoped and typically not reused across contexts, including
235+
the context provides defense-in-depth against hash collisions.
236+
The context is fetched and cached during Stream construction.
211237

212238
Warning
213239
-------
214240
Using a closed or destroyed stream as a dictionary key or in a set
215241
results in undefined behavior. The stream handle may be reused by
216242
the CUDA driver for new streams.
217243
"""
218-
return hash((type(self), <uintptr_t>(self._handle)))
244+
# Context should always be set as a post-condition of Stream construction
245+
return hash((type(self), <uintptr_t>(self._ctx_handle), <uintptr_t>(self._handle)))
219246

220247
def __eq__(self, other) -> bool:
221248
"""Check equality based on the underlying CUstream handle address.

0 commit comments

Comments
 (0)