2020
2121import pytest
2222
23+ from tests .utils .constants import TEST_MODELS
2324from tests .utils .managed_process import ManagedProcess
2425
25- # Custom format inspired by your example
26+
27+ def pytest_configure (config ):
28+ # Defining model morker to avoid `'model' not found in `markers` configuration option`
29+ # error when pyproject.toml is not available in the container
30+ config .addinivalue_line ("markers" , "model: model id used by a test or parameter" )
31+
32+
2633LOG_FORMAT = "[TEST] %(asctime)s %(levelname)s %(name)s: %(message)s"
2734DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"
2835
29- # Configure logging
3036logging .basicConfig (
3137 level = logging .INFO ,
3238 format = LOG_FORMAT ,
3339 datefmt = DATE_FORMAT , # ISO 8601 UTC format
3440)
3541
36- # List of models used in tests
37- TEST_MODELS = [
38- "Qwen/Qwen3-0.6B" ,
39- "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" ,
40- "llava-hf/llava-1.5-7b-hf" ,
41- ]
42-
4342
4443def download_models (model_list = None , ignore_weights = False ):
4544 """Download models - can be called directly or via fixture
@@ -107,16 +106,34 @@ def download_models(model_list=None, ignore_weights=False):
107106
108107
109108@pytest .fixture (scope = "session" )
110- def predownload_models ():
111- """Fixture wrapper around download_models for all TEST_MODELS"""
112- download_models ()
109+ def predownload_models (pytestconfig ):
110+ """Fixture wrapper around download_models for models used in collected tests"""
111+ # Get models from pytest config if available, otherwise fall back to TEST_MODELS
112+ models = getattr (pytestconfig , "models_to_download" , None )
113+ if models :
114+ logging .info (
115+ f"Downloading { len (models )} models needed for collected tests\n Models: { models } "
116+ )
117+ download_models (model_list = list (models ))
118+ else :
119+ # Fallback to original behavior if extraction failed
120+ download_models ()
113121 yield
114122
115123
116124@pytest .fixture (scope = "session" )
117- def predownload_tokenizers ():
118- """Fixture wrapper around download_models for all TEST_MODELS"""
119- download_models (ignore_weights = True )
125+ def predownload_tokenizers (pytestconfig ):
126+ """Fixture wrapper around download_models for tokenizers used in collected tests"""
127+ # Get models from pytest config if available, otherwise fall back to TEST_MODELS
128+ models = getattr (pytestconfig , "models_to_download" , None )
129+ if models :
130+ logging .info (
131+ f"Downloading tokenizers for { len (models )} models needed for collected tests\n Models: { models } "
132+ )
133+ download_models (model_list = list (models ), ignore_weights = True )
134+ else :
135+ # Fallback to original behavior if extraction failed
136+ download_models (ignore_weights = True )
120137 yield
121138
122139
@@ -135,42 +152,26 @@ def logger(request):
135152 logger .removeHandler (handler )
136153
137154
155+ @pytest .hookimpl (trylast = True )
138156def pytest_collection_modifyitems (config , items ):
139157 """
140158 This function is called to modify the list of tests to run.
141- It is used to skip tests that are not supported on all environments.
142159 """
143-
144- # Tests marked with trtllm requires specific environment with tensorrtllm
145- # installed. Hence, we skip them if the user did not explicitly ask for them.
146- if config .getoption ("-m" ) and "trtllm_marker" in config .getoption ("-m" ):
147- return
148- skip_trtllm = pytest .mark .skip (reason = "need -m trtllm_marker to run" )
160+ # Collect models via explicit pytest mark from final filtered items only
161+ models_to_download = set ()
149162 for item in items :
150- if "trtllm_marker" in item .keywords :
151- item .add_marker (skip_trtllm )
152-
153- # Auto-inject predownload_models fixture for serve tests only (not router tests)
154- # Skip items that don't have fixturenames (like MypyFileItem)
155- if hasattr (item , "fixturenames" ):
156- # Guard clause: skip if already has the fixtures
157- if (
158- "predownload_models" in item .fixturenames
159- or "predownload_tokenizers" in item .fixturenames
160- ):
161- continue
162-
163- # Guard clause: skip if marked with skip_model_download
164- if item .get_closest_marker ("skip_model_download" ):
165- continue
166-
167- # Add appropriate fixture based on test path
168- if "serve" in str (item .path ):
169- item .fixturenames = list (item .fixturenames )
170- item .fixturenames .append ("predownload_models" )
171- elif "router" in str (item .path ):
172- item .fixturenames = list (item .fixturenames )
173- item .fixturenames .append ("predownload_tokenizers" )
163+ # Only collect from items that are not skipped
164+ if any (
165+ getattr (m , "name" , "" ) == "skip" for m in getattr (item , "own_markers" , [])
166+ ):
167+ continue
168+ model_mark = item .get_closest_marker ("model" )
169+ if model_mark and model_mark .args :
170+ models_to_download .add (model_mark .args [0 ])
171+
172+ # Store models to download in pytest config for fixtures to access
173+ if models_to_download :
174+ config .models_to_download = models_to_download
174175
175176
176177class EtcdServer (ManagedProcess ):
0 commit comments