Skip to content

Commit 6e51ac3

Browse files
[timm_wrapper] better handling of "Unknown model" exception in timm (#40951)
* fix(timm): Add exception handling for unknown Gemma3n model * nit: Let’s cater to this specific issue * nit: Simplify error handling
1 parent 9378f87 commit 6e51ac3

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,28 @@ class TimmWrapperModelOutput(ModelOutput):
5555
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
5656

5757

58+
def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_kwargs):
59+
"""
60+
Creates a timm model and provides a clear error message if the model is not found,
61+
suggesting a library update.
62+
"""
63+
try:
64+
model = timm.create_model(
65+
config.architecture,
66+
pretrained=False,
67+
**model_kwargs,
68+
)
69+
return model
70+
except RuntimeError as e:
71+
if "Unknown model" in str(e):
72+
# A good general check for unknown models.
73+
raise ImportError(
74+
f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). "
75+
"Please upgrade timm to a more recent version with `pip install -U timm`."
76+
) from e
77+
raise e
78+
79+
5880
@auto_docstring
5981
class TimmWrapperPreTrainedModel(PreTrainedModel):
6082
main_input_name = "pixel_values"
@@ -138,7 +160,7 @@ def __init__(self, config: TimmWrapperConfig):
138160
super().__init__(config)
139161
# using num_classes=0 to avoid creating classification head
140162
extra_init_kwargs = config.model_args or {}
141-
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs)
163+
self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
142164
self.post_init()
143165

144166
@auto_docstring
@@ -254,8 +276,8 @@ def __init__(self, config: TimmWrapperConfig):
254276
)
255277

256278
extra_init_kwargs = config.model_args or {}
257-
self.timm_model = timm.create_model(
258-
config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs
279+
self.timm_model = _create_timm_model_with_error_handling(
280+
config, num_classes=config.num_labels, **extra_init_kwargs
259281
)
260282
self.num_labels = config.num_labels
261283
self.post_init()

0 commit comments

Comments
 (0)