convert forward return to tensor in FeatureAblation #1049
Conversation
# our tests expect int -> torch.int64, float -> torch.float64
# but this may actually depend on the machine
# ref: https://docs.python.org/3.10/library/stdtypes.html#typesnumeric
return torch.tensor(forward_output, dtype=output_type)
I inherited our original logic of passing Python types as the torch dtype, like dtype=float. But this is not an officially documented operation. Existing tests assume it must equal dtype=torch.float64
https://github.com/pytorch/captum/blob/5f878af6a7/tests/attr/test_feature_ablation.py#L429
But this may be machine dependent: https://docs.python.org/3.10/library/stdtypes.html#typesnumeric
"Floating point numbers are usually implemented using double in C; information about the precision and internal representation of floating point numbers for the machine on which your program is running is available in sys.float_info."
Two other alternatives are (illustrated in the sketch below):
- explicitly map the Python type to a torch dtype: float -> torch.float64
- do not set dtype and rely on torch's default dtype (float32)
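To make the options concrete, here is a minimal sketch (assuming a recent PyTorch build; the printed dtypes reflect what our tests currently expect, not a documented guarantee):

import torch

# Current (undocumented) usage: pass built-in Python types as dtype.
print(torch.tensor(1.5, dtype=float).dtype)  # torch.float64 in our tests
print(torch.tensor(2, dtype=int).dtype)      # torch.int64

# Alternative 1: map Python types to torch dtypes explicitly.
explicit_dtype = {float: torch.float64, int: torch.int64, bool: torch.bool}
print(torch.tensor(1.5, dtype=explicit_dtype[float]).dtype)  # torch.float64

# Alternative 2: omit dtype and rely on torch's defaults
# (Python float -> torch.float32, Python int -> torch.int64).
print(torch.tensor(1.5).dtype)  # torch.float32
print(torch.tensor(2).dtype)    # torch.int64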
This is an interesting point! I looked into it a bit; this functionality seems to have been added in this PR: pytorch/pytorch#21215
It looks like the type mapping is done explicitly on the C++ side using the PyObject type, so this shouldn't be affected by the internal representation. This is the mapping logic:
PyObject *obj = args[i];
if (obj == (PyObject*)&PyFloat_Type) {
return at::ScalarType::Double;
}
if (obj == (PyObject*)&PyBool_Type) {
return at::ScalarType::Bool;
}
if (obj == (PyObject*)&PyLong_Type
#if PY_MAJOR_VERSION == 2
|| obj == (PyObject*)&PyInt_Type
#endif
) {
return at::ScalarType::Long;
}
So if float is set as dtype, this would be passed through the Python / C++ bindings as PyFloat_Type, which should always correspond to ScalarType::Double / torch.float64. The tests in the original PR also verify this mapping.
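For reference, a small Python-side check of that mapping (mirroring what the tests in that PR verify on a current build, not a documented contract):

import torch

# Python type objects passed as dtype resolve through the C++ mapping above,
# independent of the machine's C double representation.
assert torch.tensor(0.0, dtype=float).dtype == torch.float64  # PyFloat_Type -> Double
assert torch.tensor(0, dtype=int).dtype == torch.int64        # PyLong_Type  -> Long
assert torch.tensor(True, dtype=bool).dtype == torch.bool     # PyBool_Type  -> Bool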
Thx for the deep dive @vivekmig!
Then I will just add a comment referring to that mapping, also as a caveat.
After all, it is not a documented torch usage and may see breaking changes someday.
Makes sense, sounds good!
@aobo-y has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Looks great, thanks! Just one note on the dtype comments.
@@ -601,3 +593,20 @@ def _find_output_mode(
            feature_mask is None
            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
        )

    def _run_forward(self, *args, **kwargs) -> Tensor:
        forward_output = _run_forward(*args, **kwargs)
nit: It seems a bit confusing to see both the instance method and the original utility named _run_forward; you could consider renaming this one slightly, but either way is fine.
FeatureAblation(Permutation) never clearly defines the type of forward's return. According to the documentation, tensor is the only acceptable type.

However, returning a single int or float is a common use case we have already supported (ref #1047 (comment)). But our code did not explicitly raise an error for unexpected types. Other types like list, tuple, or numpy.array may pass unexpectedly or fail in unexpected places with confusing error messages, e.g. we may use list as torch.dtype:

captum/captum/attr/_core/feature_ablation.py
Line 320 in 5f878af

This PR explicitly asserts the return type and converts everything into tensor. The assertion & conversion is done in a new private _run_forward wrapper instead of the _run_forward from global utils, which is shared by many other classes. I will update the others progressively and eventually push the logic into the shared _run_forward.
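For context, a minimal sketch of the kind of private wrapper described above (the class name and exact checks are illustrative assumptions, not the actual Captum code; the real wrapper delegates to the shared _run_forward utility):

import torch
from torch import Tensor


class FeatureAblationSketch:
    """Illustrative stand-in for FeatureAblation."""

    def __init__(self, forward_func):
        self.forward_func = forward_func

    def _run_forward(self, *args, **kwargs) -> Tensor:
        # In the PR this calls the shared _run_forward utility; calling
        # forward_func directly keeps the sketch self-contained.
        forward_output = self.forward_func(*args, **kwargs)
        if isinstance(forward_output, Tensor):
            return forward_output

        output_type = type(forward_output)
        # Reject list, tuple, numpy.ndarray, etc. early with a clear message
        # instead of failing later in a confusing place.
        assert output_type in (int, float), (
            "the return of forward_func must be a tensor, int, or float; "
            f"received {output_type}"
        )
        # int -> torch.int64, float -> torch.float64 via the dtype mapping
        # discussed in the review above.
        return torch.tensor(forward_output, dtype=output_type)

For example, FeatureAblationSketch(lambda x: float(x.sum()))._run_forward(torch.ones(3)) would return a float64 scalar tensor, while a forward_func returning a list would fail the assertion immediately.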