diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index dfb14dfccec4c6..47e8944583b4ca 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -82,6 +82,9 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     config_class = TimmWrapperConfig
     _no_split_modules = []
 
+    # used in Trainer to avoid passing `loss_kwargs` to model forward
+    accepts_loss_kwargs = False
+
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5957f8025d2a0b..87479805e70670 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -622,7 +622,15 @@ def __init__(
             else unwrapped_model.get_base_model().forward
         )
         forward_params = inspect.signature(model_forward).parameters
-        self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
         self.neftune_noise_alpha = args.neftune_noise_alpha
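
For context, here is a minimal standalone sketch (not part of the patch) of the detection logic introduced above: the Trainer first honors an explicit `accepts_loss_kwargs` flag when the model class defines one, and only falls back to inspecting `forward` for a `**kwargs` parameter. The names `DummyModel`, `WrappedTimmLike`, and `model_accepts_loss_kwargs` are hypothetical stand-ins for illustration.

import inspect


class DummyModel:
    # No explicit flag; `forward` accepts **kwargs, so loss kwargs can be passed through.
    def forward(self, pixel_values, labels=None, **kwargs):
        return None


class WrappedTimmLike:
    # Explicit opt-out, mirroring the new TimmWrapperPreTrainedModel.accepts_loss_kwargs attribute.
    accepts_loss_kwargs = False

    def forward(self, pixel_values, **kwargs):
        return None


def model_accepts_loss_kwargs(model):
    # Prefer an explicit class-level flag when the model defines one;
    # otherwise fall back to checking whether `forward` takes **kwargs.
    if hasattr(model, "accepts_loss_kwargs"):
        return model.accepts_loss_kwargs
    params = inspect.signature(model.forward).parameters
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())


print(model_accepts_loss_kwargs(DummyModel()))      # True: **kwargs found in forward
print(model_accepts_loss_kwargs(WrappedTimmLike()))  # False: explicit flag takes precedence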