Right! Moving this over was deemed nontrivial because it depends on JAX's own pretty-printing, and I didn't really want to duplicate all of JAX's pretty-printer plus Equinox's into jaxtyping. (Although, granted, that's not that hard either!)
FWIW, I've recently started a PyTorch project for which the dependency on JAX (through Equinox) is undesirable to me as well, so I'm actually hoping to fix this in the next month or two.
For now I'm going to mark this as a feature request, and please feel free to ping this thread if nothing gets fixed after a couple of months!
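As a rough illustration of what a dependency-free formatter could look like (this is not jaxtyping's or Equinox's implementation), the sketch below abbreviates any leaf that has both `.shape` and `.dtype` attributes. That duck-typing check would cover NumPy, JAX, and PyTorch arrays without importing any of them. `summarize_pytree` is a hypothetical name, and unlike a real pytree flattener it only recurses into dicts, lists, and tuples.

```python
from typing import Any


def summarize_pytree(obj: Any, indent: int = 0) -> str:
    """Pretty-print nested dicts/lists/tuples, abbreviating array-like leaves.

    Any leaf with both `.shape` and `.dtype` (NumPy, JAX, PyTorch arrays) is
    shown as e.g. `float32[3,4]` rather than its full contents.
    """
    pad = "  " * indent
    if hasattr(obj, "shape") and hasattr(obj, "dtype"):
        dims = ",".join(str(d) for d in obj.shape)
        # NumPy/JAX dtypes print as "float32"; PyTorch's as "torch.float32".
        dtype = str(obj.dtype).removeprefix("torch.")
        return f"{dtype}[{dims}]"
    if isinstance(obj, dict):
        if not obj:
            return "{}"
        lines = [f"{pad}  {k!r}: {summarize_pytree(v, indent + 1)}," for k, v in obj.items()]
        return "{\n" + "\n".join(lines) + f"\n{pad}}}"
    if isinstance(obj, (list, tuple)):
        open_, close = ("[", "]") if isinstance(obj, list) else ("(", ")")
        if not obj:
            return open_ + close
        lines = [f"{pad}  {summarize_pytree(v, indent + 1)}," for v in obj]
        return open_ + "\n" + "\n".join(lines) + f"\n{pad}{close}"
    # Any other leaf (ints, strings, custom objects) falls back to repr().
    return repr(obj)
```

For a dict holding a 1024x1024 float32 array, this prints a single `'w': float32[1024,1024],` entry instead of a million numbers. The trade-off versus Equinox's printer is that custom pytree nodes (e.g. dataclasses or modules) are not unpacked.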
By default, jaxtyping errors directly print pytree contents (which usually makes the errors long). If one depends on `equinox` (or has it installed), they can opt in to pretty printing, and there is a TODO for cleaning up this dependency: https://github.com/patrick-kidger/jaxtyping/blob/main/jaxtyping/_decorator.py#L767-L770

This issue is meant to “+1” that TODO: our tensors are large, errors are spammy, and I frequently find it necessary to drop into pdb to figure out what’s happening. I think pretty-printing can help a lot, and for now we have opted into making Equinox a dependency.
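The opt-in described above follows the common optional-dependency pattern. The sketch below is not the code at the linked lines. It assumes only that `equinox.tree_pformat` is importable whenever Equinox is installed, and `describe_value` is a hypothetical helper name.

```python
from typing import Any, Callable

try:
    # Prefer Equinox's pytree pretty-printer when it is installed
    # (importing Equinox also imports JAX).
    import equinox as eqx

    _pformat: Callable[[Any], str] = eqx.tree_pformat
except ImportError:
    # Without Equinox, fall back to repr(), which prints full contents.
    _pformat = repr


def describe_value(name: str, value: Any) -> str:
    """Hypothetical helper: format one argument for a type-check error message."""
    return f"{name}={_pformat(value)}"
```

The drawback raised in this thread shows up directly: users without Equinox get the long `repr` output, and users who want the short form must pull in JAX.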
If moving Equinox’s `tree_pformat` into jaxtyping is difficult, one alternative could be to let the user register a pprint function (e.g. through a global jaxtyping config). For my use cases either approach is fine: I didn't see any shortcomings in Equinox’s way of printing pytrees.
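The register-your-own-formatter alternative could look roughly like the following. This is a hypothetical sketch: `register_pformat` and `format_for_error` are invented names, not existing jaxtyping API.

```python
from typing import Any, Callable, Optional

# Hypothetical global hook; not an existing jaxtyping setting.
_user_pformat: Optional[Callable[[Any], str]] = None


def register_pformat(fn: Optional[Callable[[Any], str]]) -> None:
    """Install a custom formatter for values shown in error messages.

    Passing None restores the default repr()-based behaviour.
    """
    global _user_pformat
    _user_pformat = fn


def format_for_error(value: Any) -> str:
    """Format `value` with the registered hook if there is one, else repr()."""
    if _user_pformat is not None:
        try:
            return _user_pformat(value)
        except Exception:
            # A broken formatter should not mask the original type-check error.
            pass
    return repr(value)
```

A user who already depends on Equinox could call `register_pformat(eqx.tree_pformat)` once at startup, while a PyTorch-only project could register something lighter. Swallowing exceptions from the hook is a deliberate choice here, so that a buggy formatter degrades to `repr` instead of replacing the real error.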