diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 10789856..f1c24b3e 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -13,7 +13,6 @@ from cuda.core.experimental._module import Kernel, ObjectCode -@pytest.fixture def can_load_generated_ptx(): _, driver_ver = cuda.cuDriverGetVersion() _, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion() @@ -42,7 +41,7 @@ def test_program_init_invalid_code_format(): # TODO: incorporate this check in Program -@pytest.mark.xfail(not can_load_generated_ptx, reason="PTX version too new") +@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new") def test_program_compile_valid_target_type(): code = 'extern "C" __global__ void my_kernel() {}' program = Program(code, "c++")