Add model inputs' marginal effects flag #432
Conversation
captum/_utils/common.py (outdated):

```diff
@@ -23,17 +23,17 @@ class ExpansionTypes(Enum):
 def safe_div(
-    denom: Tensor, quotient: Union[Tensor, float], default_value: Tensor
+    nominator: Tensor, denom: Union[Tensor, float], default_value: Tensor
```
Thought that nom / denom are more common names
nit: Haven't heard of nominator, maybe numerator might be better?
It's called numerator :D Thank you 👍
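For illustration, a minimal scalar sketch of what a `safe_div` helper with the agreed naming might look like (a hypothetical simplification; the real helper operates elementwise on tensors):

```python
def safe_div(numerator, denom, default_value):
    # Fall back to default_value instead of dividing by zero.
    if denom == 0:
        return default_value
    return numerator / denom
```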
Looks great 👍 ! Just a few nits, and some thoughts on naming.
captum/attr/_core/deep_lift.py (outdated):

```diff
@@ -103,7 +103,7 @@ class DeepLift(GradientAttribution):
     https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
     """

-    def __init__(self, model: Module) -> None:
+    def __init__(self, model: Module, use_input_marginal_effects: bool = True) -> None:
```
Would be good to add documentation here and also particularly highlight that this flag has no effect on attributions when custom_attribution_func is provided.
captum/attr/_core/gradient_shap.py (outdated):

```diff
 class InputBaselineXGradient(GradientAttribution):
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(self, forward_func: Callable, use_input_marginal_effects=True) -> None:
```
nit: Setting this to False is equivalent to Saliency, do we need this option here or can we set this to always be True and use Saliency instead?
Yes, `use_input_marginal_effects=False` will result in Saliency if the baseline is zero, but GradientShap internally calls this class explicitly. This class isn't exposed externally. In order to have a local version of GradientShap we need to keep this option here, but we can turn it off for input_x_gradient.py.
```diff
@@ -16,14 +16,17 @@ class InputXGradient(GradientAttribution):
     https://arxiv.org/abs/1611.07270
     """

-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(
+        self, forward_func: Callable, use_input_marginal_effects: bool = True,
```
nit: Setting this to False is also equivalent to Saliency, do we need this option here or can we set this to always be True and use Saliency instead?
I think we can disable it here. use_input_marginal_effects=False is equivalent to Saliency
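A toy sketch of the equivalence being discussed, using plain Python lists as stand-ins for tensors (illustrative only, not captum code):

```python
def input_x_gradient(grads, inputs, use_input_marginal_effects=True):
    # With the flag on, each attribution is gradient * input (the "global"
    # form). With it off, only the raw gradient is returned, which is what
    # Saliency computes (up to Saliency's optional abs()).
    if use_input_marginal_effects:
        return [g * x for g, x in zip(grads, inputs)]
    return list(grads)
```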
```diff
     r"""
     Args:

         forward_func (callable): The forward function of the model or any
                     modification of it
+        use_input_marginal_effects (bool): Indicates whether to factor
```
Not sure what the best choice for naming would be. I think the paper on local vs global attributions describes global attributions as including marginal effects, but to me, a marginal effect intuitively seems to be the effect per unit of input change (essentially like the gradient). In this view, the marginal effect would correspond to the local attribution (or change per unit of input), and multiplying the marginal attribution by input (or input - baseline) would give the global attribution. I think the paper considers it a marginal effect in a binary space where 0 is the baseline and 1 is the input, so I feel like "marginal effects" could be confusing.
What about something like include_input_multiplier, or something that specifically clarifies the multiplier that is or isn't factored in?
The paper refers to the marginal effects of the input features but it looks like it can be misinterpreted. Marginal effects for continuous inputs are also less interpretable, since the unit change is being scaled. We can give it a simpler name such as: multiply_by_inputs
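The local/global distinction behind the proposed name can be sketched with scalars (a hypothetical helper for illustration, not captum code):

```python
def attribution(delta_out, inp, baseline, multiply_by_inputs=True):
    # The multiplier (output change per unit of input change between
    # baseline and input) is the local attribution; scaling it by
    # (inp - baseline) yields the global attribution.
    multiplier = delta_out / (inp - baseline)
    return multiplier * (inp - baseline) if multiply_by_inputs else multiplier
```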
```diff
-        DeepLiftShap.__init__(self, model)
+        LayerDeepLift.__init__(self, model, layer, use_input_marginal_effects)
+        DeepLiftShap.__init__(self, model, use_input_marginal_effects)
+        self._use_input_marginal_effects = use_input_marginal_effects
```
nit: Isn't this already set in the LayerDeepLift or DeepLiftShap constructors, is this line also necessary? Also could consider adding this to the base Attribution class and passing it there to avoid setting the attribute in every child class.
Yes, line 396 can be removed, and there's also no need to pass the argument in line 395.
For the default cases the property is hard-coded, and for some algorithms such as Saliency this flag is unnecessary in the constructor. That's why I avoided passing this as a constructor argument to the parent Attribution class.
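A simplified sketch of the constructor chaining in question (class names mirror the diff, but this is an illustration of the redundancy, not the actual hierarchy):

```python
class Attribution:
    pass

class LayerDeepLift(Attribution):
    def __init__(self, model, use_input_marginal_effects=True):
        self.model = model
        # The flag is stored once here, in the parent constructor ...
        self._use_input_marginal_effects = use_input_marginal_effects

class LayerDeepLiftShap(LayerDeepLift):
    def __init__(self, model, use_input_marginal_effects=True):
        # ... so the child only needs to forward it; assigning the
        # attribute again afterwards would be redundant.
        LayerDeepLift.__init__(self, model, use_input_marginal_effects)
```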
captum/metrics/_core/infidelity.py (outdated):

```
    This attribution scores can be computed using the implementations
    provided in the `captum.attr` package. Some of those attribution
    approaches are so called global methods, which means that
    they foctor in inputs' marginal effects, as described in:
```
nit: factor
captum/metrics/_core/infidelity.py (outdated):

```diff
@@ -145,20 +159,26 @@ def my_perturb_func(inputs):
     `infidelity_perturb_func_decorator` decorator such as:

     from captum.metrics import infidelity_perturb_func_decorator
-    @infidelity_perturb_func_decorator
+    @infidelity_perturb_func_decorator(<use_input_marginal_effects flag>)
```
nit: Could consider making this signature take the attribution method for consistency with sensitivity? Not particularly related to this PR though, just something to consider. It would also be nice if there was a way to easily switch between attribution algorithms when computing infidelity without changing this flag on the perturbation decorator (since it can be obtained from the attribution method), but not sure if it's possible with the decorator approach.
That's a good point! Originally, I thought that users could use any attribution / explanation scores, not necessarily the ones that we provide in Captum. We can consider trying the other option out and seeing how it turns out.
In terms of `infidelity_perturb_func_decorator`, we might be able to detect whether the function is decorated and pass the local / global flag from the caller. I can try to play with it before we make the release, and also include the type hints in a separate PR.
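For reference, the parametrized-decorator pattern looks roughly like this (a sketch under assumed semantics; `wrapper`'s perturbation logic is an illustration, and the real decorator operates on tensors and baselines):

```python
def infidelity_perturb_func_decorator(use_input_marginal_effects=True):
    # The outer call captures the flag, so the decorator is applied as
    # @infidelity_perturb_func_decorator(flag) rather than bare.
    def decorator(perturb_func):
        def wrapper(inputs):
            perturbed = perturb_func(inputs)
            # Assumed behavior for illustration: derive the perturbation
            # from the difference when factoring in marginal effects.
            if use_input_marginal_effects:
                perturbation = [i - p for i, p in zip(inputs, perturbed)]
            else:
                perturbation = [1.0] * len(inputs)
            return perturbation, perturbed
        return wrapper
    return decorator

@infidelity_perturb_func_decorator(True)
def my_perturb_func(inputs):
    # Toy perturbation: shrink every input by 10%.
    return [0.9 * i for i in inputs]
```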
```diff
@@ -34,6 +35,17 @@ def test_compare_with_emb_patching(self) -> None:
             input1, baseline1, additional_args=(input2, input3)
         )

+    def test_compare_with_emb_patching_wo_inp_marginal_effects(self) -> None:
```
This test seems identical to the previous one with marginal effects, maybe needs to pass an option for marginal effects?
```diff
+            use_input_marginal_effects=False,
+        )
+
+    def test_simple_conductance_input_linear2_wo_(self) -> None:
```
This looks like the same test as test_simple_conductance_input_linear2 above?
Addressed the comments. It looks like the upgrade doesn't work as expected: "You are using pip version 10.0.1, however version 20.2.1 is available." The builds keep failing, cc: @vivekmig. Update: I had to make changes in config.yml related to isort and rerun isort.
@NarineK has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Adding a flag that allows switching inputs' marginal effects on and off in some attribution algorithms. If the `use_input_marginal_effects` flag is True, model inputs' marginal effects are factored in; otherwise they are not. The default behavior factors those effects in. In addition, I made the `infidelity_perturb_func_decorator` decorator parametrized, which allows us to take inputs' marginal effects into account. Fixed the documentation for infidelity and added example documentation in IG for the `use_input_marginal_effects` argument. I'll copy that documentation to all affected arguments once it is reviewed. Added test cases for all methods affected by this change. Pull Request resolved: pytorch#432 Reviewed By: edward-io, vivekmig Differential Revision: D23010668 Pulled By: NarineK fbshipit-source-id: 69aa8835da6cf815176a552d0006a28d599b28c7