diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index bf8182796c86..b86447c9fbb7 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -5,6 +5,7 @@ import os import time +import pickle import torch import pytest import itertools @@ -65,7 +66,13 @@ ] # Get a list of all models and mapping from task to supported models -_hf_models = list(HfApi().list_models()) +try: + with open("hf_models.pkl", "rb") as fp: + _hf_models = pickle.load(fp) +except FileNotFoundError: + _hf_models = list(HfApi().list_models()) + with open("hf_models.pkl", "wb") as fp: + pickle.dump(_hf_models, fp) _hf_model_names = [m.modelId for m in _hf_models] _hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}