|
17 | 17 | """ Defines unit test cases for the SyclQueue class. |
18 | 18 | """ |
19 | 19 |
|
| 20 | +import ctypes |
| 21 | +import sys |
| 22 | + |
20 | 23 | import pytest |
21 | 24 |
|
22 | 25 | import dpctl |
@@ -395,22 +398,22 @@ def test_hashing_of_queue(): |
395 | 398 | assert queue_dict |
396 | 399 |
|
397 | 400 |
|
398 | | -def test_channeling_device_properties(): |
| 401 | +def test_channeling_device_properties(capsys): |
399 | 402 | try: |
400 | 403 | q = dpctl.SyclQueue() |
401 | 404 | dev = q.sycl_device |
402 | 405 | except dpctl.SyclQueueCreationError: |
403 | 406 | pytest.fail("Failed to create device from default selector") |
404 | | - import io |
405 | | - from contextlib import redirect_stdout |
406 | | - |
407 | | - f1 = io.StringIO() |
408 | | - with redirect_stdout(f1): |
409 | | - q.print_device_info() # should execute without raising |
410 | | - f2 = io.StringIO() |
411 | | - with redirect_stdout(f2): |
412 | | - dev.print_device_info() |
413 | | - assert f1.getvalue() == f2.getvalue(), "Mismatch in print_device_info" |
| 407 | + |
| 408 | + q.print_device_info() # should execute without raising |
| 409 | + q_captured = capsys.readouterr() |
| 410 | + q_output = q_captured.out |
| 411 | + dev.print_device_info() |
| 412 | + d_captured = capsys.readouterr() |
| 413 | + d_output = d_captured.out |
| 414 | + assert q_output, "No output captured" |
| 415 | + assert q_output == d_output, "Mismatch in print_device_info" |
| 416 | + assert q_captured.err == "" and d_captured.err == "" |
414 | 417 | for pr in ["backend", "name", "driver_version"]: |
415 | 418 | assert getattr(q, pr) == getattr( |
416 | 419 | dev, pr |
@@ -468,9 +471,6 @@ def test_queue_capsule(): |
468 | 471 |
|
469 | 472 |
|
470 | 473 | def test_cpython_api(): |
471 | | - import ctypes |
472 | | - import sys |
473 | | - |
474 | 474 | q = dpctl.SyclQueue() |
475 | 475 | mod = sys.modules[q.__class__.__module__] |
476 | 476 | # get capsule storign get_context_ref function ptr |
|
0 commit comments