Skip to content

Commit

Permalink
Pass forward pass args to context
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Jan 7, 2022
1 parent 1d39ec4 commit 09d5707
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/transformers/adapters/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

def __init__(self, model):
def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
if hasattr(model, "forward_context"):
model.forward_context(self)
model.forward_context(self, *args, **kwargs)

@classmethod
def wrap(cls, f):
Expand All @@ -85,7 +85,7 @@ def wrap(cls, f):

@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
context = cls(self)
context = cls(self, *args, **kwargs)
cls.get_contexts().append(context)
results = f(self, *args, **kwargs)
cls.get_contexts().pop()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def freeze_model(self, freeze=True):
param.requires_grad = not freeze
self.model_freezed = freeze

def forward_context(self, context: ForwardContext):
def forward_context(self, context: ForwardContext, *args, **kwargs):
"""
This method is called by the ``ForwardContext`` at the beginning of the forward pass.
"""
Expand Down

0 comments on commit 09d5707

Please sign in to comment.