diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py index 215e77ec83..f7a007fd80 100644 --- a/src/transformers/adapters/context.py +++ b/src/transformers/adapters/context.py @@ -72,10 +72,10 @@ class ForwardContext: # thread-local storage that holds a stack of active contexts storage = threading.local() - def __init__(self, model): + def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. if hasattr(model, "forward_context"): - model.forward_context(self) + model.forward_context(self, *args, **kwargs) @classmethod def wrap(cls, f): @@ -85,7 +85,7 @@ def wrap(cls, f): @functools.wraps(f) def wrapper_func(self, *args, **kwargs): - context = cls(self) + context = cls(self, *args, **kwargs) cls.get_contexts().append(context) results = f(self, *args, **kwargs) cls.get_contexts().pop() diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 980d83ccaa..9ab0aa2cc3 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -534,7 +534,7 @@ def freeze_model(self, freeze=True): param.requires_grad = not freeze self.model_freezed = freeze - def forward_context(self, context: ForwardContext): + def forward_context(self, context: ForwardContext, *args, **kwargs): """ This method is called by the ``ForwardContext`` at the beginning of the forward pass. """