Required prerequisites
What version of TorchOpt are you using?
0.7.3
System information
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0], Linux
TorchOpt 0.7.3, PyTorch 2.1.0, functorch 2.1.0
Problem description
I tried to recreate the jaxopt example for root finding with implicit differentiation, using a very simple iterative solver. In JAX it just works; with TorchOpt, however, the gradient is zero for scalar inputs.
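A minimal stand-in for such an iterative solver (pure Python, no autograd; an illustration of the pattern, not the actual repro code):

```python
import math

def fixed_point(f, x0, maxiter=100):
    # Plain fixed-point iteration: repeatedly apply f starting from x0.
    # Note that with maxiter == 0 the loop body never runs and x0 is
    # returned unchanged -- the solver silently does nothing.
    x = x0
    for _ in range(maxiter):
        x = f(x)
    return x

# Solve x = cos(x); the fixed point is approximately 0.7390851.
root = fixed_point(math.cos, 1.0)
print(root)
```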
The problem stems from:
torchopt/torchopt/linalg/cg.py, lines 120 to 122 in a4cfc49.
Here, the size becomes zero for a scalar b, and maxiter is consequently (and wrongly) set to 0. The same piece of code in JAX produces a size of 1.
Reproducible example code
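The size mismatch can be reproduced without either library, assuming the size is derived from the tensor's shape tuple: summing the dimensions of a 0-dim shape gives 0, whereas the element count (what JAX's `.size` reports) is the product of the dimensions, which is 1 for a scalar:

```python
import math

scalar_shape = ()                 # shape of a 0-dim (scalar) tensor
sum_of_dims = sum(scalar_shape)   # 0 -> a 10 * size default gives maxiter = 0
numel = math.prod(scalar_shape)   # 1 -> a 10 * size default gives maxiter = 10
print(sum_of_dims, numel)         # prints: 0 1
```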
The Python snippets:
Traceback
No response
Expected behavior
No response
Additional context
It works upon making the tensors 1D:
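Under the same assumption about how the size is computed (and a SciPy-style default of `10 * size`), promoting the scalar to a 1-element vector, e.g. via `torch.atleast_1d` or `.reshape(1)`, makes the summed shape nonzero, so maxiter is no longer 0:

```python
scalar_shape = ()    # 0-dim tensor: sum of dims is 0
vector_shape = (1,)  # after e.g. x.reshape(1): sum of dims is 1

maxiter_scalar = 10 * sum(scalar_shape)  # 0: CG returns without iterating
maxiter_vector = 10 * sum(vector_shape)  # 10: CG actually runs
print(maxiter_scalar, maxiter_vector)    # prints: 0 10
```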
It just works in JAX.