@@ -126,15 +126,23 @@ cdef class Stream:
126126 @classmethod
127127 def _legacy_default (cls ):
128128 cdef Stream self = Stream.__new__ (cls )
129+ cdef cydriver.CUcontext ctx
129130 self ._handle = < cydriver.CUstream> (cydriver.CU_STREAM_LEGACY)
130131 self ._builtin = True
132+ with nogil:
133+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
134+ self ._ctx_handle = ctx
131135 return self
132136
133137 @classmethod
134138 def _per_thread_default (cls ):
135139 cdef Stream self = Stream.__new__ (cls )
140+ cdef cydriver.CUcontext ctx
136141 self ._handle = < cydriver.CUstream> (cydriver.CU_STREAM_PER_THREAD)
137142 self ._builtin = True
143+ with nogil:
144+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
145+ self ._ctx_handle = ctx
138146 return self
139147
140148 @classmethod
@@ -144,9 +152,16 @@ cdef class Stream:
144152 if obj is not None and options is not None :
145153 raise ValueError (" obj and options cannot be both specified" )
146154 if obj is not None :
155+ cdef cydriver.CUcontext ctx
156+ cdef cydriver.CUresult err
147157 self ._handle = _try_to_get_stream_ptr(obj)
148158 # TODO: check if obj is created under the current context/device
149159 self ._owner = obj
160+ with nogil:
161+ err = cydriver.cuStreamGetCtx(self ._handle, & ctx)
162+ if err != cydriver.CUresult.CUDA_SUCCESS:
163+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
164+ self ._ctx_handle = ctx
150165 return self
151166
152167 cdef StreamOptions opts = check_or_create_options(StreamOptions, options, " Stream options" )
@@ -167,12 +182,15 @@ cdef class Stream:
167182 prio = high
168183
169184 cdef cydriver.CUstream s
185+ cdef cydriver.CUcontext ctx
170186 with nogil:
171187 HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(& s, flags, prio))
188+ HANDLE_RETURN(cydriver.cuStreamGetCtx(s, & ctx))
172189 self ._handle = s
173190 self ._nonblocking = int (nonblocking)
174191 self ._priority = prio
175192 self ._device_id = device_id if device_id is not None else self ._device_id
193+ self ._ctx_handle = ctx
176194 return self
177195
178196 def __dealloc__ (self ):
@@ -198,7 +216,7 @@ cdef class Stream:
198216 return (0, <uintptr_t>(self._handle ))
199217
200218 def __hash__(self ) -> int:
201- """Return hash based on the underlying CUstream handle address.
219+ """Return hash based on the underlying CUstream handle address and context .
202220
203221 This enables Stream objects to be used as dictionary keys and in sets.
204222 Two Stream objects wrapping the same underlying CUDA stream will hash
@@ -207,15 +225,24 @@ cdef class Stream:
207225 Returns
208226 -------
209227 int
210- Hash value based on the stream handle address.
228+ Hash value based on the stream handle address and context handle.
229+
230+ Notes
231+ -----
232+ Includes the context handle in the hash to prevent collisions when
233+ handles are reused across different contexts. While handles are
234+ context-scoped and typically not reused across contexts , including
235+ the context provides defense-in-depth against hash collisions.
236+ The context is fetched and cached during Stream construction.
211237
212238 Warning
213239 -------
214240 Using a closed or destroyed stream as a dictionary key or in a set
215241 results in undefined behavior. The stream handle may be reused by
216242 the CUDA driver for new streams.
217243 """
218- return hash((type(self ), <uintptr_t>(self._handle )))
244+ # Context should always be set as a post-condition of Stream construction
245+ return hash((type(self ), <uintptr_t>(self._ctx_handle ), <uintptr_t>(self._handle )))
219246
220247 def __eq__(self , other ) -> bool:
221248 """Check equality based on the underlying CUstream handle address.
0 commit comments