diff --git a/conftest.py b/conftest.py index 3c04f0680a11..78f028052c0e 100644 --- a/conftest.py +++ b/conftest.py @@ -16,10 +16,14 @@ # under the License. import hashlib import pytest +import sys import os -from collections import OrderedDict + +from pathlib import Path pytest_plugins = ["tvm.testing.plugin"] +IS_IN_CI = os.getenv("CI", "") == "true" +REPO_ROOT = Path(__file__).resolve().parent # These are long running tests (manually curated and extracted from CI logs) @@ -96,3 +100,12 @@ def pytest_collection_modifyitems(config, items): reason=f"Test running on shard {item_shard_index} of {num_shards}", ) ) + + +def pytest_sessionstart(): + if IS_IN_CI: + hook_script_dir = REPO_ROOT / "tests" / "scripts" / "request_hook" + sys.path.append(str(hook_script_dir)) + import request_hook # pylint: disable=import-outside-toplevel + + request_hook.init()