diff --git a/conftest.py b/conftest.py index a65fa04..2030d4f 100644 --- a/conftest.py +++ b/conftest.py @@ -11,6 +11,8 @@ FW_STRS = ["numpy", "jax", "tensorflow", "torch"] + + @pytest.fixture(autouse=True) def run_around_tests(dev_str, f, compile_graph, fw): if "gpu" in dev_str and fw == "numpy":