@@ -55,6 +55,28 @@ class TimmWrapperModelOutput(ModelOutput):
5555    attentions : Optional [tuple [torch .FloatTensor , ...]] =  None 
5656
5757
58+ def  _create_timm_model_with_error_handling (config : "TimmWrapperConfig" , ** model_kwargs ):
59+     """ 
60+     Creates a timm model and provides a clear error message if the model is not found, 
61+     suggesting a library update. 
62+     """ 
63+     try :
64+         model  =  timm .create_model (
65+             config .architecture ,
66+             pretrained = False ,
67+             ** model_kwargs ,
68+         )
69+         return  model 
70+     except  RuntimeError  as  e :
71+         if  "Unknown model"  in  str (e ):
72+             # A good general check for unknown models. 
73+             raise  ImportError (
74+                 f"The model architecture '{ config .architecture } { timm .__version__ }  
75+                 "Please upgrade timm to a more recent version with `pip install -U timm`." 
76+             ) from  e 
77+         raise  e 
78+ 
79+ 
5880@auto_docstring  
5981class  TimmWrapperPreTrainedModel (PreTrainedModel ):
6082    main_input_name  =  "pixel_values" 
@@ -138,7 +160,7 @@ def __init__(self, config: TimmWrapperConfig):
138160        super ().__init__ (config )
139161        # using num_classes=0 to avoid creating classification head 
140162        extra_init_kwargs  =  config .model_args  or  {}
141-         self .timm_model  =  timm . create_model (config . architecture ,  pretrained = False , num_classes = 0 , ** extra_init_kwargs )
163+         self .timm_model  =  _create_timm_model_with_error_handling (config , num_classes = 0 , ** extra_init_kwargs )
142164        self .post_init ()
143165
144166    @auto_docstring  
@@ -254,8 +276,8 @@ def __init__(self, config: TimmWrapperConfig):
254276            )
255277
256278        extra_init_kwargs  =  config .model_args  or  {}
257-         self .timm_model  =  timm . create_model (
258-             config . architecture ,  pretrained = False , num_classes = config .num_labels , ** extra_init_kwargs 
279+         self .timm_model  =  _create_timm_model_with_error_handling (
280+             config , num_classes = config .num_labels , ** extra_init_kwargs 
259281        )
260282        self .num_labels  =  config .num_labels 
261283        self .post_init ()
0 commit comments