Skip to content

Commit

Permalink
Debugged use_cache flag for MLP model (which doesn't take in the `u…
Browse files Browse the repository at this point in the history
…se_cache` argument in its forward call)
  • Loading branch information
AmirZur committed May 1, 2024
1 parent 38de854 commit 4d7eabd
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,10 +1440,14 @@ def forward(
)

# run intervened forward
if labels is not None:
counterfactual_outputs = self.model(**base, labels=labels, use_cache=use_cache)
else:
counterfactual_outputs = self.model(**base, use_cache=use_cache)
model_kwargs = {}
if labels is not None: # for training
model_kwargs["labels"] = labels
if 'use_cache' in self.model.config.to_dict(): # for transformer models
model_kwargs["use_cache"] = use_cache

counterfactual_outputs = self.model(**base, **model_kwargs)

set_handlers_to_remove.remove()

self._output_validation()
Expand Down

0 comments on commit 4d7eabd

Please sign in to comment.