From 8086f2f18666af7fc20726e3c0d8af59eaa5067c Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Fri, 13 Dec 2024 11:43:14 +0000
Subject: [PATCH 1/2] Fix for timm model

---
 .../models/timm_wrapper/modeling_timm_wrapper.py |  1 +
 src/transformers/trainer.py                      | 10 +++++++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index dfb14dfccec4c6..1dc0e8728cf6ee 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -80,6 +80,7 @@ class TimmWrapperModelOutput(ModelOutput):
 class TimmWrapperPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     config_class = TimmWrapperConfig
+    accepts_loss_kwargs = False
     _no_split_modules = []
 
     def __init__(self, *args, **kwargs):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a708d8deb4efcc..a9680ec980eb1b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -623,7 +623,15 @@ def __init__(
             else unwrapped_model.get_base_model().forward
         )
         forward_params = inspect.signature(model_forward).parameters
-        self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
 
         self.neftune_noise_alpha = args.neftune_noise_alpha
 

From 2ced051fa777e5931b91956b2a6e535b3e1c9a0b Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Fri, 13 Dec 2024 11:50:37 +0000
Subject: [PATCH 2/2] Add comment

---
 src/transformers/models/timm_wrapper/modeling_timm_wrapper.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index 1dc0e8728cf6ee..47e8944583b4ca 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -80,9 +80,11 @@ class TimmWrapperModelOutput(ModelOutput):
 class TimmWrapperPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     config_class = TimmWrapperConfig
-    accepts_loss_kwargs = False
     _no_split_modules = []
 
+    # used in Trainer to avoid passing `loss_kwargs` to model forward
+    accepts_loss_kwargs = False
+
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)