From 38de8544ac64a523e09aa6a4d0be9e65f33a0844 Mon Sep 17 00:00:00 2001
From: Amir Zur
Date: Wed, 1 May 2024 09:36:04 -0700
Subject: [PATCH 1/2] Added `use_cache` flag to intervenable model forward
 call

---
 pyvene/models/intervenable_base.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index 62ba64b..21eb650 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -1320,6 +1320,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = True,
     ):
         """
         Main forward function that serves a wrapper to
@@ -1440,9 +1441,9 @@ def forward(
 
         # run intervened forward
         if labels is not None:
-            counterfactual_outputs = self.model(**base, labels=labels)
+            counterfactual_outputs = self.model(**base, labels=labels, use_cache=use_cache)
         else:
-            counterfactual_outputs = self.model(**base)
+            counterfactual_outputs = self.model(**base, use_cache=use_cache)
 
         set_handlers_to_remove.remove()
         self._output_validation()

From 4d7eabd16c9e548fa22d882b80216537b8400e6b Mon Sep 17 00:00:00 2001
From: Amir Zur
Date: Wed, 1 May 2024 14:09:44 -0700
Subject: [PATCH 2/2] Pass `use_cache` only to models that accept it (the
 MLP model's forward call takes no `use_cache` argument)

---
 pyvene/models/intervenable_base.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index 21eb650..68aea2f 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -1440,10 +1440,14 @@ def forward(
         )
 
         # run intervened forward
-        if labels is not None:
-            counterfactual_outputs = self.model(**base, labels=labels, use_cache=use_cache)
-        else:
-            counterfactual_outputs = self.model(**base, use_cache=use_cache)
+        model_kwargs = {}
+        if labels is not None:  # labels are only needed for training
+            model_kwargs["labels"] = labels
+        if "use_cache" in self.model.config.to_dict():  # only transformer configs define use_cache
+            model_kwargs["use_cache"] = use_cache
+
+        counterfactual_outputs = self.model(**base, **model_kwargs)
+
         set_handlers_to_remove.remove()
         self._output_validation()
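
For reviewers, here is a minimal sketch of how the new flag would be exercised from user code. It assumes the GPT-2 quickstart from the pyvene README (`pv.create_gpt2`, `pv.IntervenableModel`); the specific intervention config and prompt are illustrative, not part of this patch:

```python
# Minimal sketch (assumes the pyvene GPT-2 quickstart; the intervention
# config and prompt below are illustrative, not part of this patch).
import torch
import pyvene as pv

# Load GPT-2 and its tokenizer via pyvene's helper.
_, tokenizer, gpt2 = pv.create_gpt2()

# Wrap the model with a single zeroing intervention on layer 0's MLP output.
pv_gpt2 = pv.IntervenableModel(
    {
        "layer": 0,
        "component": "mlp_output",
        "source_representation": torch.zeros(gpt2.config.n_embd),
    },
    model=gpt2,
)

# With this patch, use_cache is forwarded to the wrapped transformer,
# so the KV cache can be disabled for the intervened forward pass.
outputs = pv_gpt2(
    base=tokenizer("The capital of Spain is", return_tensors="pt"),
    unit_locations={"base": 3},
    use_cache=False,
)
```

Because the second commit adds `use_cache` to `model_kwargs` only when the wrapped model's config defines that key, the same call path still works for models such as the MLP model, whose `forward` accepts no `use_cache` argument: the flag is silently dropped instead of raising a `TypeError`.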