Fix model_accepts_loss_kwargs for timm model (#35257)
* Fix for timm model

* Add comment
qubvel authored Dec 27, 2024
1 parent 3b0a94e commit 5c75087
Showing 2 changed files with 12 additions and 1 deletion.
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py (3 additions, 0 deletions)

```diff
@@ -82,6 +82,9 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     config_class = TimmWrapperConfig
     _no_split_modules = []
 
+    # used in Trainer to avoid passing `loss_kwargs` to model forward
+    accepts_loss_kwargs = False
+
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)
```
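The wrapper needs an explicit opt-out because its `forward` accepts `**kwargs` (presumably to pass them through to the underlying timm model), so the Trainer's signature-based detection in the trainer.py hunk below reports a false positive. A minimal sketch of that false positive, using a toy `forward` assumed to mirror the wrapper's shape rather than copied from the repo:

```python
import inspect

# Toy stand-in (an assumption, not the real TimmWrapper forward): it takes
# `**kwargs` only to forward them to the backbone, not because its loss
# consumes kwargs such as `num_items_in_batch`.
def forward(pixel_values, labels=None, **kwargs):
    pass

# Signature inspection alone sees the VAR_KEYWORD parameter and reports
# "accepts loss kwargs", which is why the explicit flag above is needed.
params = inspect.signature(forward).parameters
print(any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()))
# True -- without the flag, Trainer would pass loss kwargs to this forward
```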
src/transformers/trainer.py (9 additions, 1 deletion)

```diff
@@ -622,7 +622,15 @@ def __init__(
             else unwrapped_model.get_base_model().forward
         )
         forward_params = inspect.signature(model_forward).parameters
-        self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
 
         self.neftune_noise_alpha = args.neftune_noise_alpha
 
```
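Pulled out of `Trainer.__init__`, the new precedence is easy to see in isolation. A self-contained sketch (the helper name and toy classes are illustrative, not transformers API):

```python
import inspect

def accepts_loss_kwargs(model) -> bool:
    # An explicit `accepts_loss_kwargs` attribute takes precedence;
    # otherwise fall back to looking for `**kwargs` in `forward`.
    if hasattr(model, "accepts_loss_kwargs"):
        return model.accepts_loss_kwargs
    params = inspect.signature(model.forward).parameters
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

class TimmLike:
    accepts_loss_kwargs = False  # explicit opt-out, as in the first hunk
    def forward(self, pixel_values, **kwargs): ...

class PlainKwargsModel:
    def forward(self, input_ids, **kwargs): ...

print(accepts_loss_kwargs(TimmLike()))          # False: attribute wins
print(accepts_loss_kwargs(PlainKwargsModel()))  # True: inferred from **kwargs
```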
