Stop gradient feature #181

Open
Red-Portal opened this issue Jun 13, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@Red-Portal

Hi,

Stopping/dropping gradients is needed in many applications, and it would be great to have it in Tapir as well. ChainRulesCore currently provides ignore_derivatives and @non_differentiable as a unified interface, but pretty much only Zygote supports them. ReverseDiff, on the other hand, provides its own macro: ReverseDiff.@skip.

@willtebbutt willtebbutt added the enhancement New feature or request label Jun 14, 2024
@willtebbutt
Member

This should be straightforward, albeit slightly different from how ChainRules / Zygote do it, because on the forward pass of AD we do a lot of in-place incrementation and propagation of the memory locations at which gradients are to be incremented.

I think this could be done by adding a rule which

  1. replaces the fdata associated with the thing whose gradient we're dropping with a copy that doesn't live at the same memory address, and
  2. always returns zero_rdata.
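A minimal sketch of the two-step rule proposed above, written in Python purely for illustration rather than in Julia. It does not use Tapir's actual API: the names sum_rule, stop_gradient_rule, zero_rdata and grad_of_example are hypothetical. It assumes a simplified model in which an array's fdata is a plain Python list used as a shadow buffer that pullbacks increment in place, and rdata is a float returned by each pullback.

```python
from typing import Callable, List, Tuple

# Hypothetical, simplified model of the fdata/rdata split (NOT Tapir's API):
#   * an array primal is paired with a same-length shadow buffer ("fdata");
#     pullbacks increment gradients into that buffer in place;
#   * each pullback also returns "rdata" (here a float) for immutable data.
Array = List[float]
Pullback = Callable[[float], float]


def zero_rdata() -> float:
    """Hypothetical stand-in for a zero rdata value."""
    return 0.0


def sum_rule(x: Array, dx: Array) -> Tuple[float, Pullback]:
    """Forward pass of y = sum(x); the pullback increments x's fdata in place."""
    y = sum(x)

    def pullback(dy: float) -> float:
        for i in range(len(dx)):
            dx[i] += dy  # in-place incrementation at the memory location dx
        return zero_rdata()  # x is an array, so its gradient lives in dx

    return y, pullback


def stop_gradient_rule(x: Array, dx: Array) -> Tuple[Array, Array, Pullback]:
    """Forward pass of y = stop_gradient(x), following the two proposed steps."""
    # Step 1: pass the primal through, but hand out a copy of the fdata that
    # lives at a different address, so downstream increments never reach dx.
    dx_copy = list(dx)

    # Step 2: the pullback always returns zero rdata.
    def pullback(_dy: float) -> float:
        return zero_rdata()

    return x, dx_copy, pullback


def grad_of_example(x: Array, use_stop_gradient: bool = True) -> Array:
    """Gradient of f(x) = sum(x) + sum(stop_gradient(x)) with respect to x."""
    dx = [0.0] * len(x)  # x's fdata, zero-initialised
    y1, pb1 = sum_rule(x, dx)
    if use_stop_gradient:
        x2, dx2, pb2 = stop_gradient_rule(x, dx)
    else:
        # Plain identity that aliases the fdata: gradients flow back into dx.
        x2, dx2, pb2 = x, dx, (lambda _dy: zero_rdata())
    y3, pb3 = sum_rule(x2, dx2)
    _f = y1 + y3  # primal value of f (unused here)

    # Reverse pass, in the opposite order of the forward pass; df/dy1 = df/dy3 = 1.
    pb3(1.0)
    pb2(0.0)  # x2 is an array, so its incoming rdata is zero
    pb1(1.0)
    return dx


print(grad_of_example([1.0, 2.0, 3.0]))                           # [1.0, 1.0, 1.0]
print(grad_of_example([1.0, 2.0, 3.0], use_stop_gradient=False))  # [2.0, 2.0, 2.0]
```

With the stop-gradient rule only the first sum contributes, giving [1.0, 1.0, 1.0]. With the aliasing identity, both sums increment the same buffer and the result is [2.0, 2.0, 2.0]. The copy is what does the work in this sketch: because dx_copy is a different object from dx, increments made downstream of the stopped value land in a buffer that is simply discarded.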
