From 09d57073d83b3590a0a56660b0ec394b4c837744 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Fri, 7 Jan 2022 15:32:20 +0100 Subject: [PATCH] Pass forward pass args to context --- src/transformers/adapters/context.py | 6 +++--- src/transformers/adapters/model_mixin.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py index 215e77ec8..f7a007fd8 100644 --- a/src/transformers/adapters/context.py +++ b/src/transformers/adapters/context.py @@ -72,10 +72,10 @@ class ForwardContext: # thread-local storage that holds a stack of active contexts storage = threading.local() - def __init__(self, model): + def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. if hasattr(model, "forward_context"): - model.forward_context(self) + model.forward_context(self, *args, **kwargs) @classmethod def wrap(cls, f): @@ -85,7 +85,7 @@ def wrap(cls, f): @functools.wraps(f) def wrapper_func(self, *args, **kwargs): - context = cls(self) + context = cls(self, *args, **kwargs) cls.get_contexts().append(context) results = f(self, *args, **kwargs) cls.get_contexts().pop() diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 058cca869..82eb1f641 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -563,7 +563,7 @@ def freeze_model(self, freeze=True): param.requires_grad = not freeze self.model_freezed = freeze - def forward_context(self, context: ForwardContext): + def forward_context(self, context: ForwardContext, *args, **kwargs): """ This method is called by the ``ForwardContext`` at the beginning of the forward pass. """