Merge pull request #152 from amirzur2023/main
[Minor] Adding `use_cache` flag to intervenable model forward call
frankaging authored May 2, 2024
2 parents 9b3e296 + 4d7eabd commit 1dc9243
Showing 1 changed file with 9 additions and 4 deletions.

pyvene/models/intervenable_base.py
```diff
@@ -1320,6 +1320,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = True,
     ):
         """
         Main forward function that serves a wrapper to
@@ -1439,10 +1440,14 @@ def forward(
         )
 
         # run intervened forward
-        if labels is not None:
-            counterfactual_outputs = self.model(**base, labels=labels)
-        else:
-            counterfactual_outputs = self.model(**base)
+        model_kwargs = {}
+        if labels is not None: # for training
+            model_kwargs["labels"] = labels
+        if 'use_cache' in self.model.config.to_dict(): # for transformer models
+            model_kwargs["use_cache"] = use_cache
+
+        counterfactual_outputs = self.model(**base, **model_kwargs)
 
         set_handlers_to_remove.remove()
 
         self._output_validation()
```
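
For context on what the new flag enables downstream, here is a minimal sketch of a caller toggling the wrapped transformer's key/value cache through `IntervenableModel.forward`. It assumes pyvene's README-style API; the model choice, intervention config, and prompts below are illustrative, not part of this commit:

```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative wrapped model; any HF transformer whose config exposes
# `use_cache` exercises the new code path.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A simple interchange intervention: copy the layer-0 block output from a
# source prompt into the base prompt (config values are illustrative).
pv_model = pv.IntervenableModel(
    {
        "layer": 0,
        "component": "block_output",
        "unit": "pos",
        "intervention_type": pv.VanillaIntervention,
    },
    model=model,
)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

# `use_cache` is now accepted by the wrapper and forwarded to the wrapped
# model only when its config has a `use_cache` entry (see the diff above).
_, counterfactual_outputs = pv_model(
    base,
    sources,
    unit_locations={"sources->base": 3},  # intervene at token position 3
    use_cache=False,  # disable the KV cache for the intervened forward pass
)
```

Because the flag is only forwarded when `use_cache` appears in the wrapped model's config, backends without that key are unaffected, and the default of `True` preserves the previous behavior for transformer models.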
