refactor feature ablation #1047
```diff
@@ -286,22 +286,20 @@ def attribute(
         if show_progress:
             attr_progress.update()

+        # number of elements in the output of forward_func
+        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+        # flatten eval outputs into 1D (n_outputs)
+        # add the leading dim for n_feature_perturbed
+        if isinstance(initial_eval, Tensor):
+            initial_eval = initial_eval.reshape(1, -1)
+
         agg_output_mode = FeatureAblation._find_output_mode(
             perturbations_per_eval, feature_mask
         )

-        # get as a 2D tensor (if it is not a scalar)
-        if isinstance(initial_eval, torch.Tensor):
-            initial_eval = initial_eval.reshape(1, -1)
-            num_outputs = initial_eval.shape[1]
-        else:
-            num_outputs = 1
-
         if not agg_output_mode:
-            assert (
-                isinstance(initial_eval, torch.Tensor)
-                and num_outputs == num_examples
-            ), (
+            assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, (
                 "expected output of `forward_func` to have "
                 + "`batch_size` elements for perturbations_per_eval > 1 "
                 + "and all feature_mask.shape[0] > 1"
```
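A minimal sketch of what the refactored prologue computes, assuming a toy per-example `forward_func` output (not Captum code, just the two new statements in isolation):

```python
import torch
from torch import Tensor

# toy forward output: one value per example for a batch of 3
initial_eval = torch.tensor([0.9, 0.4, 0.7])

# number of elements in the output of forward_func
n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1

# flatten eval outputs into 1D and add a leading dim of 1,
# reserved for the n_feature_perturbed axis used later on
if isinstance(initial_eval, Tensor):
    initial_eval = initial_eval.reshape(1, -1)

print(n_outputs)           # 3
print(initial_eval.shape)  # torch.Size([1, 3])
```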
```diff
@@ -316,8 +314,9 @@ def attribute(
             )

         total_attrib = [
+            # attribute w.r.t each output element
             torch.zeros(
-                (num_outputs,) + input.shape[1:],
+                (n_outputs,) + input.shape[1:],
                 dtype=attrib_type,
                 device=input.device,
             )
```

Review comment on `(n_outputs,) + input.shape[1:]`:

I think this functionality is nice if used for multiple per-example outputs, but it might cause some confusion, since the output attribution shape may no longer always be aligned with the input shape. In particular, some models have outputs with some dimension sizes of 1: for example, with input `6 x 4` and output `6 x 1` (or even `6 x 1 x 1`), this would currently produce a `6 x 4` attribution rather than `6 x 1 x 4` or `6 x 1 x 1 x 4` with this change. Maybe we can squeeze all dimensions of 1 and then utilize the output shape?

Reply:

Do you mean the number of dimensions would no longer be the same between the input and the returned attribution? My suggestion tries to respect the output of the forward more: the dimensions (or the entire shape) of the returned attribution depend on the input and output shapes of the forward. I can see the reasons behind both designs, but I would suggest not squeezing if we go the second way, because we do not understand the exact meaning of the dims of 1. I am afraid of weird edge cases where keeping the 1 can be very convenient (e.g., users want to reuse the selection indices created for the forward output). I would trust the users to know their own model output 😆. If they feel the dims of 1 are really annoying, they can always control that by wrapping their own model.
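To make the shape discussion above concrete, here is a hedged sketch (plain PyTorch, independent of Captum) contrasting the two candidate attribution shapes for a `6 x 4` input with a `6 x 1` output:

```python
import torch

inputs = torch.randn(6, 4)   # batch of 6, 4 features each
output = torch.randn(6, 1)   # per-example output with a trailing dim of 1

# behavior in this PR: outputs are flattened, so the buffer is
# (n_outputs,) + input.shape[1:] = (6, 4), which only happens to
# coincide with the input shape here
n_outputs = output.numel()
attrib_flattened = torch.zeros((n_outputs,) + tuple(inputs.shape[1:]))
print(attrib_flattened.shape)  # torch.Size([6, 4])

# alternative raised in review: keep the raw output shape, giving
# output.shape + input.shape[1:] = (6, 1, 4), so attribution dims
# track the forward output instead of collapsing its dims of 1
attrib_output_shaped = torch.zeros(tuple(output.shape) + tuple(inputs.shape[1:]))
print(attrib_output_shaped.shape)  # torch.Size([6, 1, 4])
```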
```diff
@@ -328,7 +327,7 @@ def attribute(
         if self.use_weights:
             weights = [
                 torch.zeros(
-                    (num_outputs,) + input.shape[1:], device=input.device
+                    (n_outputs,) + input.shape[1:], device=input.device
                 ).float()
                 for input in inputs
             ]
```
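For the `use_weights` branch, a hedged sketch of why the weight buffer shares the `(n_outputs,) + input.shape[1:]` shape: when ablation groups overlap, each position's accumulated eval differences are later normalized by how often it was perturbed (the final division is my reading of Captum's `use_weights` behavior; the toy masks and diffs below are made up):

```python
import torch

n_outputs, n_feat = 2, 4
total_attrib = torch.zeros(n_outputs, n_feat)
weights = torch.zeros(n_outputs, n_feat)

# two overlapping ablation masks over the 4 features
masks = [
    torch.tensor([[1.0, 1.0, 0.0, 0.0]]),
    torch.tensor([[0.0, 1.0, 1.0, 1.0]]),
]
# one eval difference per output element, broadcast over features
eval_diffs = [
    torch.tensor([[0.5], [1.0]]),
    torch.tensor([[0.2], [0.4]]),
]

for mask, diff in zip(masks, eval_diffs):
    total_attrib += diff * mask  # accumulate masked differences
    weights += mask              # count perturbations per position

# feature 1 was ablated twice, so its total is averaged over both runs
print(total_attrib / weights.clamp(min=1.0))
```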
```diff
@@ -354,8 +353,11 @@ def attribute(
                 perturbations_per_eval,
                 **kwargs,
             ):
-                # modified_eval dimensions: 1D tensor with length
-                # equal to #num_examples * #features in batch
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                #   agg mode: (*initial_eval.shape)
+                #   non-agg mode:
+                #     (feature_perturbed * batch_size, *initial_eval.shape[1:])
                 modified_eval = _run_forward(
                     self.forward_func,
                     current_inputs,
```
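A hedged illustration of the two shapes the new comment describes, with toy forwards standing in for the user's `forward_func` (the names are illustrative):

```python
import torch

batch_size = 3
x = torch.randn(batch_size, 4)

def forward_agg(t):
    # aggregation mode: a single scalar regardless of batch size,
    # so modified_eval keeps initial_eval's shape
    return t.sum()

def forward_per_example(t):
    # non-agg mode: one value per row; the batch fed to the forward
    # is expanded by the number of features perturbed per eval
    return t.sum(dim=1)

expanded = torch.cat([x, x])  # perturbations_per_eval = 2
print(forward_agg(expanded).shape)          # torch.Size([]) - scalar
print(forward_per_example(expanded).shape)  # torch.Size([6]) = 2 * batch_size
```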
```diff
@@ -366,25 +368,34 @@ def attribute(
                 if show_progress:
                     attr_progress.update()

-                # (contains 1 more dimension than inputs). This adds extra
-                # dimensions of 1 to make the tensor broadcastable with the inputs
-                # tensor.
                 if not isinstance(modified_eval, torch.Tensor):
                     eval_diff = initial_eval - modified_eval
                 else:
                     if not agg_output_mode:
+                        # current_batch_size is not n_examples
+                        # it may get expanded by n_feature_perturbed
+                        current_batch_size = current_inputs[0].shape[0]
                         assert (
-                            modified_eval.numel() == current_inputs[0].shape[0]
+                            modified_eval.numel() == current_batch_size
                         ), """expected output of forward_func to grow with
                         batch_size. If this is not the case for your model
                         please set perturbations_per_eval = 1"""

-                    eval_diff = (
-                        initial_eval - modified_eval.reshape((-1, num_outputs))
-                    ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
+                    # reshape the leading dim for n_feature_perturbed
+                    # flatten each feature's eval outputs into 1D of (n_outputs)
+                    modified_eval = modified_eval.reshape(-1, n_outputs)
+                    # eval_diff in shape (n_feature_perturbed, n_outputs)
+                    eval_diff = initial_eval - modified_eval
+
+                    # append the shape of one input example
+                    # to make it broadcastable to mask
+                    eval_diff = eval_diff.reshape(
+                        eval_diff.shape + (inputs[i].dim() - 1) * (1,)
+                    )
                     eval_diff = eval_diff.to(total_attrib[i].device)
                 if self.use_weights:
                     weights[i] += current_mask.float().sum(dim=0)

                 total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
                     dim=0
                 )
```
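Putting the reshapes together, a hedged, self-contained sketch of the broadcast this hunk sets up (toy shapes; in non-agg mode `n_outputs` equals the batch size, and `current_mask` stands in for one boolean mask per perturbed feature):

```python
import torch

n_feature_perturbed = 2
batch = n_outputs = 3  # non-agg mode: one output per example
n_feat = 4

initial_eval = torch.randn(1, n_outputs)
modified_eval = torch.randn(n_feature_perturbed * n_outputs)

# flatten each perturbed feature's outputs into a row of n_outputs
modified_eval = modified_eval.reshape(-1, n_outputs)
eval_diff = initial_eval - modified_eval  # (2, 3)

# append one dim of 1 per non-batch input dim so eval_diff
# broadcasts against the mask below
inputs_i = torch.randn(batch, n_feat)
eval_diff = eval_diff.reshape(eval_diff.shape + (inputs_i.dim() - 1) * (1,))
print(eval_diff.shape)  # torch.Size([2, 3, 1])

# (n_feature_perturbed, batch, n_feat) mask of ablated positions
current_mask = torch.randint(0, 2, (n_feature_perturbed, batch, n_feat))
contribution = (eval_diff * current_mask).sum(dim=0)
print(contribution.shape)  # torch.Size([3, 4])
```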
Review comment:

Is it common for the `forward` to not return a Tensor? What else do we expect? I would guess `float` or `double`, but those are all very easy to convert into a tensor, and users can do it themselves by wrapping the forward, so we could explicitly define the return type of `forward`. Even if we really want to support non-Tensor returns, can we convert them in `_run_forward` to standardize the logic?

Reply:

Yes, I think this is fairly common: the model could return float / double / int, and there are use cases depending on this. It would be best to maintain support for this use case, although moving the logic that wraps the value as a Tensor into `_run_forward` should be fine. This would require an extra single-element tensor construction, but it would avoid branching here. I'm fine either way.
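A hedged sketch of the standardization being discussed: converting scalar returns to a single-element tensor inside a `_run_forward`-style helper, so the attribution loop never branches on return type (illustrative names, not Captum's actual implementation):

```python
import torch
from torch import Tensor
from typing import Callable, Union

def run_forward_as_tensor(forward_func: Callable, *args) -> Tensor:
    out: Union[Tensor, int, float] = forward_func(*args)
    if isinstance(out, Tensor):
        return out
    # float / double / int returns become a single-element tensor,
    # at the cost of one extra tensor construction per call
    return torch.tensor([out])

# a forward that returns a plain Python float still yields a Tensor
loss = run_forward_as_tensor(lambda t: t.sum().item(), torch.ones(2, 3))
print(loss)  # tensor([6.])
```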