LayerDeepLift Hook Fix #415
Conversation
Thank you! One minor naming suggestion:
captum/attr/_core/deep_lift.py (Outdated)

@@ -496,8 +496,10 @@ def _can_register_hook(self, module: Module) -> bool:
             or not self._is_non_linear(module)
         )

-    def _register_hooks(self, module: Module) -> None:
-        if not self._can_register_hook(module):
+    def _register_hooks(self, module: Module, skip_target_layer: bool = False) -> None:
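For context, a minimal, self-contained sketch of how such a flag might steer hook registration. This is an illustration under assumptions, not Captum's implementation: the class name, the `_forward_hook` body, the `forward_handles` bookkeeping, and the `_is_non_linear` check are placeholders; only the `skip_target_layer` idea mirrors the diff above.

```python
import torch
import torch.nn as nn


class HookRegistrarSketch:
    """Illustrative only: registers forward hooks on non-linearities,
    optionally skipping the target layer (mirroring skip_target_layer)."""

    def __init__(self, model: nn.Module, layer: nn.Module) -> None:
        self.model = model
        self.layer = layer          # the layer being attributed to
        self.forward_handles = []   # kept so hooks can be removed later

    def _is_non_linear(self, module: nn.Module) -> bool:
        # Stand-in check; Captum keys off a table of supported non-linearities.
        return isinstance(module, (nn.ReLU, nn.Tanh, nn.Sigmoid))

    def _forward_hook(self, module, inputs, output):
        # Placeholder hook body; DeepLift's real hooks rewrite activations.
        return output

    def _register_hooks(self, module: nn.Module, skip_target_layer: bool = False) -> None:
        for child in module.children():
            # Avoid hooking the target layer's non-linearity when attributing
            # with respect to its output (the in-place-module case this PR fixes).
            if skip_target_layer and child is self.layer:
                continue
            if self._is_non_linear(child):
                self.forward_handles.append(
                    child.register_forward_hook(self._forward_hook)
                )
            self._register_hooks(child, skip_target_layer)

    def remove_hooks(self) -> None:
        for handle in self.forward_handles:
            handle.remove()
```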
nit: "target layer" can be confusing because we also have the target output. Perhaps we could use the same naming convention as for attributing to the layer input or output: attribute_to_layer_input=True.
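For illustration, the renamed signature the comment suggests might look roughly like this (a sketch only; the default value and how the flag is consumed in the merged change may differ):

```python
def _register_hooks(self, module: Module, attribute_to_layer_input: bool = True) -> None:
    ...
```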
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This avoids adding a hook to the target layer non-linearity when using LayerDeepLift if attributing with respect to layer output. This fixes an issue with in-place modules caused by skipping cloning and adds a corresponding test case.

Pull Request resolved: pytorch#415
Reviewed By: edward-io
Differential Revision: D22241851
Pulled By: vivekmig
fbshipit-source-id: 771a7cbb3cf77438bba901237defe937a26c415c
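A hedged usage sketch of the scenario the summary describes: LayerDeepLift attributing to the output of a layer that uses an in-place ReLU. The toy model, tensor shapes, and target index are assumptions for illustration; LayerDeepLift and its attribute(...) method are Captum's public API.

```python
import torch
import torch.nn as nn
from captum.attr import LayerDeepLift

# Toy model whose second module is an in-place ReLU, the kind of layer the
# fix is about (architecture and shapes here are illustrative assumptions).
model = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(inplace=True),
    nn.Linear(4, 2),
)
model.eval()

inputs = torch.randn(5, 3)
baselines = torch.zeros(5, 3)

# Attribute with respect to the ReLU layer's output (the default behavior;
# passing attribute_to_layer_input=True would attribute to its input instead).
layer_dl = LayerDeepLift(model, model[1])
attributions = layer_dl.attribute(inputs, baselines=baselines, target=0)
print(attributions.shape)  # expected: torch.Size([5, 4])
```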