scale gradient for backward pass #521
Comments
You can use
Not sure if this is what you want. If not, please provide some Python code in TF, PyTorch, or MXNet so we can take a look.
Thank you very much for your reply. Here is some more information about the use case I am looking at.
Use case: a Java implementation of MuZero based on DJL (with MXNet as the framework).
Need: the MuZero paper comes with Python pseudocode (in the supplementary data). The pseudocode uses this function
to scale down the error backpropagated from the recurrently called dynamics function.
Support in the frameworks
I think I found the function in the MXNet Python API: BlockGrad
It would be great to have it in Java, too.
I'll test this
The stopGradient works well for me: I could remove my workaround and thereby freed up enough GPU memory to double the batch size. Since this is general functionality (used in MuZero, for example), it would be very useful to add it to the Java API as well, e.g. to the NDArray interface and its implementations.
Question or maybe Enhancement
I'm missing a feature to scale the gradient for the backward pass (as used in MuZero, for example), something like
tensor * scale + stop_gradient(tensor) * (1 - scale)
I'm not sure whether the feature is missing or I'm simply not seeing the proper way to do it.
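To make the intent of the expression above concrete, here is a minimal, framework-free sketch in plain Java. It does not use DJL or MXNet; the class `ScaleGradientSketch`, the `Dual` type, and the methods `stopGradient` and `scaleGradient` are hypothetical names invented for this illustration. It tracks a scalar value together with its derivative (forward-mode autodiff) and shows that the expression leaves the forward value unchanged while multiplying the gradient by `scale`.

```java
// Illustration only: hypothetical names, no DJL or MXNet calls involved.
class ScaleGradientSketch {

    /** A scalar value paired with its derivative with respect to one input. */
    static final class Dual {
        final double value;
        final double grad;

        Dual(double value, double grad) {
            this.value = value;
            this.grad = grad;
        }

        /** An input variable: its derivative with respect to itself is 1. */
        static Dual variable(double v) {
            return new Dual(v, 1.0);
        }

        /** Sum rule: values add, derivatives add. */
        Dual add(Dual other) {
            return new Dual(value + other.value, grad + other.grad);
        }

        /** Multiplying by a constant scales both value and derivative. */
        Dual mul(double c) {
            return new Dual(value * c, grad * c);
        }
    }

    /** Keeps the forward value but treats it as a constant (derivative 0). */
    static Dual stopGradient(Dual x) {
        return new Dual(x.value, 0.0);
    }

    /** tensor * scale + stop_gradient(tensor) * (1 - scale) */
    static Dual scaleGradient(Dual x, double scale) {
        return x.mul(scale).add(stopGradient(x).mul(1.0 - scale));
    }

    public static void main(String[] args) {
        Dual x = Dual.variable(3.0);
        Dual y = scaleGradient(x, 0.5);
        // Forward value is still 3.0; the derivative is halved to 0.5.
        System.out.println(y.value + " " + y.grad);
    }
}
```

In a real framework the same idea would be built from its own multiply, add, and stop-gradient operations on tensors; the sketch only shows why the two terms sum to the original value in the forward pass while only the first term contributes to the gradient.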
Workaround
I worked around it by adding an additional forward pass, keeping its output tensors, and feeding them into the training forward pass as "stop_gradient(tensor)" inputs. This works functionally, but comes at the cost of