support normalize in metric infidelity #639
Conversation
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Looks great, thanks for adding this @aobo-y! Just some minor nits on documentation.
captum/metrics/_core/infidelity.py
Outdated
@@ -345,6 +346,15 @@ def infidelity(
            `input batch size * n_perturb_samples`.
            Default: None
        normalize (bool, optional): Normalize the dot product of the input
            perturbation and the attribution so the infedelity value is invariant
nit: infidelity
captum/metrics/_core/infidelity.py
Outdated
            perturbation and the attribution so the infedelity value is invariant
            to constant scaling of the attribution values. The normalization factor
            is defined as the ratio of two mean values across all perturbations:
            `mean(dot product * func value diff) / mean(dot product * dot product)`.
nit: Would it be possible to make this line a little more detailed to explain this factor for new users? E.g., it may not be immediately clear that the func value diff is the same as the difference between the predictor function evaluated at the input and at the perturbed input, described above.
That's a good point. I'd recommend using paper notation with latex.
Sure, I can use LaTeX here.
captum/metrics/_core/infidelity.py
Outdated
    max_examples_per_batch=max_examples_per_batch,
)
return metrics_sum * 1 / n_perturb_samples

attr_times_perturb_sums, perturbed_fwd_diffs = metrics_sum
nit: Can this be included above as the return of _divide_and_aggregate_metrics? metrics_sum isn't really applicable as a name anymore.
Thank you for working on this PR @aobo-y!

If I understand this implementation correctly, we aggregate all inputs and all perturbations in memory, which can lead to out-of-memory errors very quickly. The idea of using torch.add or torch.mean as the aggregate function was to help scale the implementation and avoid out of memory as much as possible.

In _next_infidelity_tensors you have access to the normalize input argument (I've just verified it: in Python you are still running in the same context, so you have access to the normalize argument there). You can keep track of np.mean(pdt_diff * exp_sum) and np.mean(exp_sum * exp_sum) in _next_infidelity_tensors too, and ultimately apply the final beta per example at the end. Let me know what you think.
@NarineK if you mean that I keep all the perturbation results in memory, then yes, I am aware of the memory consumption. If the perturbation number is too large, we will surely end up out of memory. In Python, I can access the context in the nested aggregation function and know whether we need to normalize. Unfortunately, that is not the reason I have to keep all perturbation results. If I understand correctly, the author's normalization applies a single scaling factor beta per example, and beta depends on every perturbation. What you suggested does help calculate the two means that define beta. Mathematically, what we need is the mean of (beta * a - b)^2 over all perturbations, where a is the attribution dot product and b is the forward-output difference, so beta is only known after all perturbations have been seen.
I agree that maintaining just the means necessary for beta would not be sufficient to avoid maintaining results per perturbation sample. I think the approach to avoid storing the sample-wise results would be to maintain running sums based on the expansion of (a - b)^2. This approach would avoid the additional memory, at the tradeoff of a potentially trickier formulation to follow.

The additional memory used here should be on the order of batch_size * n_perturb_samples (full input perturbations wouldn't be maintained), so this shouldn't be a large issue or lead to OOMs with typical use cases. But to be on the safer side, if we expect potentially larger values, it might be worth considering the alternative approach. What do you think @NarineK, @aobo-y?
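For concreteness, here is the expansion being referred to, written out with $a$ the attribution dot product and $b$ the forward-output difference per perturbation (notation borrowed from this thread, not from the PR):

$$ \sum_i (\beta a_i - b_i)^2 = \beta^2 \sum_i a_i^2 - 2 \beta \sum_i a_i b_i + \sum_i b_i^2 $$

So only the three running sums $\sum_i a_i^2$, $\sum_i a_i b_i$, and $\sum_i b_i^2$ need to be aggregated, and $\beta$ can be applied after all perturbations are processed.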
I saw that beta is a scalar and it looked to me that we should be able to do that sample-based; sorry for the confusion. sum(a - b) = sum(a) - sum(b), so if we have sum(a) and sum(b) we can compute sum(beta * a - b) = beta * sum(a) - sum(b).
Force-pushed from 1029503 to c6d8e54
$$ \beta = \frac{
    \mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ]
}{
    \mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ]
} $$
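In this notation (spelling out the thread's discussion, not quoting the PR docstring): $I$ is a random perturbation drawn from $\mu_I$, $\Phi(f, x)$ is the attribution for model $f$ at input $x$, and $\beta$ is the per-example scaling factor. The normalized infidelity is then

$$ \text{INFD}(\Phi, f, x) = \mathbb{E}_{I \sim \mu_I} \big[ \big( \beta \, I^T \Phi(f, x) - (f(x) - f(x - I)) \big)^2 \big] $$

which is invariant to rescaling $\Phi$ by a constant, since $\beta$ rescales inversely.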
@NarineK @vivekmig The unittest failed, so I tried to study where the differences come from, and found an inconsistent behavior across PyTorch versions, just in case this was unknown before. For this PR, this means that due to the precision of float32, the difference between computing (a - b)^2 directly and computing the expanded a^2 - 2ab + b^2 can be noticeable on older PyTorch versions.
Here is the detailed example: https://github.com/pytorch/captum/blob/master/tests/metrics/test_infidelity.py#L268. But in the latest PyTorch, both forms produce the same result.
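As a generic illustration of the float32 effect under discussion (a standalone sketch, not the captum test itself):

```python
import torch

# With float32, expanding the square can lose a lot of precision
# when a and b are large and nearly equal.
a = torch.tensor([10000.0])  # float32 by default
b = torch.tensor([10000.1])

direct = (a - b).pow(2)                     # (a - b)^2
expanded = a.pow(2) - 2 * a * b + b.pow(2)  # a^2 - 2ab + b^2

print(direct)    # close to the true value 0.01
print(expanded)  # can come out wildly off, even negative
```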
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
)
else:
    # returns (a-b)^2 if no need to normalize
    return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)
Due to the above issue, I keep two ways of aggregation based on normalize, instead of always using a^2 - 2ab + b^2:

- (a - b)^2 if not normalize
- a^2, ab, b^2 if normalize

This change allows me to pass the tests. But it is still worth noting that in older versions of PyTorch, when normalize is set, a^2 - 2ab + b^2 will lose some precision compared with the direct (a - b)^2.
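A condensed sketch of the two aggregation modes described above, reusing the variable names from the diff excerpts; the actual implementation is in captum/metrics/_core/infidelity.py, so treat this as an approximation rather than the PR's code:

```python
import torch

def next_tensors_sketch(attr_times_perturb_sums, perturbed_fwd_diffs, normalize):
    # Inputs assumed to be shaped (batch_size, n_perturb_samples):
    # a = attribution-perturbation dot products, b = forward-output diffs.
    a, b = attr_times_perturb_sums, perturbed_fwd_diffs
    if normalize:
        # Keep the three sums needed to expand (beta*a - b)^2 later.
        return (a.pow(2).sum(-1), (a * b).sum(-1), b.pow(2).sum(-1))
    # Without normalization, accumulate (a - b)^2 directly, which is
    # more precise in float32 than the expanded form.
    return ((a - b).pow(2).sum(-1),)

def finalize_sketch(agg_tensors, normalize, n_perturb_samples):
    if normalize:
        a2_sum, ab_sum, b2_sum = agg_tensors
        beta = ab_sum / a2_sum.clamp(min=1e-10)  # per-example scaling factor
        # Expanded sum((beta*a - b)^2) over the perturbation dimension.
        infid = beta * beta * a2_sum - 2 * beta * ab_sum + b2_sum
    else:
        infid = agg_tensors[0]
    return infid / n_perturb_samples
```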
Makes sense, sounds good, this fix looks good to me!
Thank you for working on this, @aobo-y! Looks good to me! I left one small nit comment.
captum/metrics/_core/infidelity.py
Outdated
beta_num = agg_tensors[1]
beta_denorm = agg_tensors[0]

beta_denorm[beta_denorm == 0] += 1e-10  # safe divide
It's better to use a common function for safe divide: safe_div
https://github.com/pytorch/captum/blob/master/captum/_utils/common.py#L26
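For reference, a minimal sketch of what such a helper can look like; this is an assumption about the utility's shape, not captum's actual safe_div signature:

```python
import torch

def safe_div_sketch(numerator: torch.Tensor,
                    denom: torch.Tensor,
                    default_denom: float = 1e-10) -> torch.Tensor:
    # Replace zero denominators before dividing to avoid inf/nan results.
    denom = torch.where(denom == 0, torch.full_like(denom, default_denom), denom)
    return numerator / denom
```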
@aobo-y has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Updated the PR with `safe_div`.
#613
Support normalizing the infidelity metric like the author's implementation: https://github.com/chihkuanyeh/saliency_evaluation/blob/44a66e2531f30b803be3bf5b0786971b7e7f72a1/infid_sen_utils.py#L295
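A hypothetical usage sketch of the new flag; the toy model, perturbation function, and attributions below are placeholders, not part of the PR:

```python
import torch
from captum.metrics import infidelity

def forward_func(inputs):
    # Toy model: per-example scalar output.
    return inputs.sum(dim=-1)

def perturb_func(inputs):
    # Returns (perturbation, perturbed input), as infidelity expects.
    noise = torch.randn_like(inputs) * 0.01
    return noise, inputs - noise

inputs = torch.randn(4, 3)
attributions = torch.ones_like(inputs)  # placeholder attributions

# normalize=True makes the score invariant to constant scaling
# of the attributions, matching the reference implementation.
score = infidelity(forward_func, perturb_func, inputs, attributions,
                   normalize=True)
print(score)  # expected: one infidelity value per example
```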