Skip to content

Commit

Permalink
Merge pull request #203 from NVIDIA/ksimpson/fix_device_from_ctx
Browse files Browse the repository at this point in the history
Fix _util.device_from_ctx
  • Loading branch information
ksimpson-work authored Nov 12, 2024
2 parents a855762 + 8f357c7 commit 92aa731
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cuda_core/cuda/core/experimental/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def inner(*args, **kwargs):

def get_device_from_ctx(ctx_handle) -> int:
"""Get device ID from the given ctx."""
prev_ctx = Device().context.handle
if ctx_handle != prev_ctx:
from cuda.core.experimental._device import Device # avoid circular import
prev_ctx = Device().context._handle
if int(ctx_handle) != int(prev_ctx):
switch_context = True
else:
switch_context = False
Expand Down
10 changes: 10 additions & 0 deletions cuda_core/tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def test_stream_context():
context = stream.context
assert context is not None

def test_stream_from_foreign_stream():
device = Device()
other_stream = device.create_stream(options=StreamOptions())
stream = device.create_stream(obj=other_stream)
assert other_stream.handle == stream.handle
device = stream.device
assert isinstance(device, Device)
context = stream.context
assert context is not None

def test_stream_from_handle():
stream = Stream.from_handle(0)
assert isinstance(stream, Stream)
Expand Down

0 comments on commit 92aa731

Please sign in to comment.