11from dataclasses import dataclass , field
2- from typing import AbstractSet , Mapping , Optional
2+ from typing import AbstractSet , Any , Literal , Mapping , Optional
3+
4+ import pytest
5+ from packaging .version import Version
6+ from transformers import __version__ as TRANSFORMERS_VERSION
37
48
59@dataclass (frozen = True )
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
3842 trust_remote_code : bool = False
3943 """The ``trust_remote_code`` level required to load the model."""
4044
45+ hf_overrides : dict [str , Any ] = field (default_factory = dict )
46+ """The ``hf_overrides`` required to load the model."""
47+
48+ def check_transformers_version (
49+ self ,
50+ * ,
51+ on_fail : Literal ["error" , "skip" ],
52+ ) -> None :
53+ """
54+ If the installed transformers version does not meet the requirements,
55+ perform the given action.
56+ """
57+ if self .min_transformers_version is None :
58+ return
59+
60+ current_version = TRANSFORMERS_VERSION
61+ required_version = self .min_transformers_version
62+ if Version (current_version ) < Version (required_version ):
63+ msg = (
64+ f"You have `transformers=={ current_version } ` installed, but "
65+ f"`transformers>={ required_version } ` is required to run this "
66+ "model" )
67+
68+ if on_fail == "error" :
69+ raise RuntimeError (msg )
70+ else :
71+ pytest .skip (msg )
72+
73+ def check_available_online (
74+ self ,
75+ * ,
76+ on_fail : Literal ["error" , "skip" ],
77+ ) -> None :
78+ """
79+ If the model is not available online, perform the given action.
80+ """
81+ if not self .is_available_online :
82+ msg = "Model is not available online"
83+
84+ if on_fail == "error" :
85+ raise RuntimeError (msg )
86+ else :
87+ pytest .skip (msg )
88+
4189
4290# yapf: disable
4391_TEXT_GENERATION_EXAMPLE_MODELS = {
@@ -48,8 +96,6 @@ class _HfExamplesInfo:
4896 trust_remote_code = True ),
4997 "ArcticForCausalLM" : _HfExamplesInfo ("Snowflake/snowflake-arctic-instruct" ,
5098 trust_remote_code = True ),
51- "AriaForConditionalGeneration" : _HfExamplesInfo ("rhymes-ai/Aria" ,
52- trust_remote_code = True ),
5399 "BaiChuanForCausalLM" : _HfExamplesInfo ("baichuan-inc/Baichuan-7B" ,
54100 trust_remote_code = True ),
55101 "BaichuanForCausalLM" : _HfExamplesInfo ("baichuan-inc/Baichuan2-7B-chat" ,
@@ -176,14 +222,17 @@ class _HfExamplesInfo:
176222
177223_MULTIMODAL_EXAMPLE_MODELS = {
178224 # [Decoder-only]
225+ "AriaForConditionalGeneration" : _HfExamplesInfo ("rhymes-ai/Aria" ,
226+ min_transformers_version = "4.48" ),
179227 "Blip2ForConditionalGeneration" : _HfExamplesInfo ("Salesforce/blip2-opt-2.7b" ), # noqa: E501
180228 "ChameleonForConditionalGeneration" : _HfExamplesInfo ("facebook/chameleon-7b" ), # noqa: E501
181229 "ChatGLMModel" : _HfExamplesInfo ("THUDM/glm-4v-9b" ,
182230 extras = {"text_only" : "THUDM/chatglm3-6b" },
183231 trust_remote_code = True ),
184232 "ChatGLMForConditionalGeneration" : _HfExamplesInfo ("chatglm2-6b" ,
185233 is_available_online = False ),
186- "DeepseekVLV2ForCausalLM" : _HfExamplesInfo ("deepseek-ai/deepseek-vl2-tiny" ), # noqa: E501
234+ "DeepseekVLV2ForCausalLM" : _HfExamplesInfo ("deepseek-ai/deepseek-vl2-tiny" , # noqa: E501
235+ hf_overrides = {"architectures" : ["DeepseekVLV2ForCausalLM" ]}), # noqa: E501
187236 "FuyuForCausalLM" : _HfExamplesInfo ("adept/fuyu-8b" ),
188237 "H2OVLChatModel" : _HfExamplesInfo ("h2oai/h2ovl-mississippi-800m" ),
189238 "InternVLChatModel" : _HfExamplesInfo ("OpenGVLab/InternVL2-1B" ,
@@ -194,7 +243,8 @@ class _HfExamplesInfo:
194243 "LlavaNextForConditionalGeneration" : _HfExamplesInfo ("llava-hf/llava-v1.6-mistral-7b-hf" ), # noqa: E501
195244 "LlavaNextVideoForConditionalGeneration" : _HfExamplesInfo ("llava-hf/LLaVA-NeXT-Video-7B-hf" ), # noqa: E501
196245 "LlavaOnevisionForConditionalGeneration" : _HfExamplesInfo ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf" ), # noqa: E501
197- "MantisForConditionalGeneration" : _HfExamplesInfo ("TIGER-Lab/Mantis-8B-siglip-llama3" ), # noqa: E501
246+ "MantisForConditionalGeneration" : _HfExamplesInfo ("TIGER-Lab/Mantis-8B-siglip-llama3" , # noqa: E501
247+ hf_overrides = {"architectures" : ["MantisForConditionalGeneration" ]}), # noqa: E501
198248 "MiniCPMV" : _HfExamplesInfo ("openbmb/MiniCPM-Llama3-V-2_5" ,
199249 trust_remote_code = True ),
200250 "MolmoForCausalLM" : _HfExamplesInfo ("allenai/Molmo-7B-D-0924" ,
@@ -247,5 +297,12 @@ def get_supported_archs(self) -> AbstractSet[str]:
247297 def get_hf_info (self , model_arch : str ) -> _HfExamplesInfo :
248298 return self .hf_models [model_arch ]
249299
300+ def find_hf_info (self , model_id : str ) -> _HfExamplesInfo :
301+ for info in self .hf_models .values ():
302+ if info .default == model_id :
303+ return info
304+
305+ raise ValueError (f"No example model defined for { model_id } " )
306+
250307
251308HF_EXAMPLE_MODELS = HfExampleModels (_EXAMPLE_MODELS )
0 commit comments