|
6 | 6 | # this software and related documentation outside the terms of the EULA |
7 | 7 | # is strictly prohibited. |
8 | 8 |
|
9 | | -from cuda.core.experimental._stream import Stream, StreamOptions, LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream |
10 | | -from cuda.core.experimental._event import Event, EventOptions |
11 | | -from cuda.core.experimental._device import Device |
| 9 | +from cuda.core.experimental import Device, Stream, StreamOptions |
| 10 | +from cuda.core.experimental._stream import LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream |
| 11 | +from cuda.core.experimental._event import Event |
12 | 12 | import pytest |
13 | 13 |
|
14 | 14 | def test_stream_init(): |
15 | 15 | with pytest.raises(NotImplementedError): |
16 | 16 | Stream() |
17 | 17 |
|
18 | | -def test_stream_init_with_options(): |
19 | | - stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0)) |
| 18 | +def test_stream_init_with_options(init_cuda): |
| 19 | + stream = Device().create_stream(options=StreamOptions(nonblocking=True, priority=0)) |
20 | 20 | assert stream.is_nonblocking is True |
21 | 21 | assert stream.priority == 0 |
22 | 22 |
|
23 | | -def test_stream_handle(): |
24 | | - stream = Stream._init(options=StreamOptions()) |
| 23 | +def test_stream_handle(init_cuda): |
| 24 | + stream = Device().create_stream(options=StreamOptions()) |
25 | 25 | assert isinstance(stream.handle, int) |
26 | 26 |
|
27 | | -def test_stream_is_nonblocking(): |
28 | | - stream = Stream._init(options=StreamOptions(nonblocking=True)) |
| 27 | +def test_stream_is_nonblocking(init_cuda): |
| 28 | + stream = Device().create_stream(options=StreamOptions(nonblocking=True)) |
29 | 29 | assert stream.is_nonblocking is True |
30 | 30 |
|
31 | | -def test_stream_priority(): |
32 | | - stream = Stream._init(options=StreamOptions(priority=0)) |
| 31 | +def test_stream_priority(init_cuda): |
| 32 | + stream = Device().create_stream(options=StreamOptions(priority=0)) |
33 | 33 | assert stream.priority == 0 |
34 | | - stream = Stream._init(options=StreamOptions(priority=-1)) |
| 34 | + stream = Device().create_stream(options=StreamOptions(priority=-1)) |
35 | 35 | assert stream.priority == -1 |
36 | 36 | with pytest.raises(ValueError): |
37 | | - stream = Stream._init(options=StreamOptions(priority=1)) |
| 37 | + stream = Device().create_stream(options=StreamOptions(priority=1)) |
38 | 38 |
|
39 | | -def test_stream_sync(): |
40 | | - stream = Stream._init(options=StreamOptions()) |
| 39 | +def test_stream_sync(init_cuda): |
| 40 | + stream = Device().create_stream(options=StreamOptions()) |
41 | 41 | stream.sync() # Should not raise any exceptions |
42 | 42 |
|
43 | | -def test_stream_record(): |
44 | | - stream = Stream._init(options=StreamOptions()) |
| 43 | +def test_stream_record(init_cuda): |
| 44 | + stream = Device().create_stream(options=StreamOptions()) |
45 | 45 | event = stream.record() |
46 | 46 | assert isinstance(event, Event) |
47 | 47 |
|
48 | | -def test_stream_record_invalid_event(): |
49 | | - stream = Stream._init(options=StreamOptions()) |
| 48 | +def test_stream_record_invalid_event(init_cuda): |
| 49 | + stream = Device().create_stream(options=StreamOptions()) |
50 | 50 | with pytest.raises(TypeError): |
51 | 51 | stream.record(event="invalid_event") |
52 | 52 |
|
53 | | -def test_stream_wait_event(): |
54 | | - stream = Stream._init(options=StreamOptions()) |
55 | | - event = Event._init() |
56 | | - stream.record(event) |
57 | | - stream.wait(event) # Should not raise any exceptions |
| 53 | +def test_stream_wait_event(init_cuda): |
| 54 | + s1 = Device().create_stream() |
| 55 | + s2 = Device().create_stream() |
| 56 | + e1 = s1.record() |
| 57 | + s2.wait(e1) # Should not raise any exceptions |
| 58 | + s2.sync() |
58 | 59 |
|
59 | | -def test_stream_wait_invalid_event(): |
60 | | - stream = Stream._init(options=StreamOptions()) |
| 60 | +def test_stream_wait_invalid_event(init_cuda): |
| 61 | + stream = Device().create_stream(options=StreamOptions()) |
61 | 62 | with pytest.raises(ValueError): |
62 | 63 | stream.wait(event_or_stream="invalid_event") |
63 | 64 |
|
64 | | -def test_stream_device(): |
65 | | - stream = Stream._init(options=StreamOptions()) |
| 65 | +def test_stream_device(init_cuda): |
| 66 | + stream = Device().create_stream(options=StreamOptions()) |
66 | 67 | device = stream.device |
67 | 68 | assert isinstance(device, Device) |
68 | 69 |
|
69 | | -def test_stream_context(): |
70 | | - stream = Stream._init(options=StreamOptions()) |
| 70 | +def test_stream_context(init_cuda): |
| 71 | + stream = Device().create_stream(options=StreamOptions()) |
71 | 72 | context = stream.context |
72 | 73 | assert context is not None |
73 | 74 |
|
|
0 commit comments