Jitted functions can't take keyword arguments named `f` in v0.4.36
#25329
Comments
It also happens if …
Note that the error is a bit different if using static args:

```python
import jax

def myfun(x, f, y=5):
    return x*f + y

jax.jit(myfun, static_argnames=("y",))(1.0, f=2.0, y=2)
```

gives
This also breaks on previous versions of JAX if …
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util. We probably shouldn't use linear_util...
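The kind of collision this quick fix addresses can be sketched in plain Python. This is a minimal illustration, assuming (as the issue title and the note about kwargs in `linear_util` suggest) that an internal wrapper names one of its own parameters `f` while forwarding the user's `**kwargs`. The helpers `call_wrapped` and `call_wrapped_fixed` are hypothetical and are not JAX's actual `linear_util` code. The fixed variant uses positional-only parameters (PEP 570, Python 3.8+), which is one general way to avoid the clash and not necessarily what #25349 does.

```python
def call_wrapped(f, *args, **kwargs):
    # Hypothetical internal helper: its own parameter is named `f`,
    # so a user keyword argument also named `f` binds to the same name.
    return f(*args, **kwargs)


def call_wrapped_fixed(f, /, *args, **kwargs):
    # The `/` makes `f` positional-only, so a keyword named `f`
    # is collected into **kwargs instead of clashing with it.
    return f(*args, **kwargs)


def myfun(x, f, y=5):
    return x * f + y


try:
    call_wrapped(myfun, 1.0, f=2.0)
except TypeError as e:
    # e.g. "call_wrapped() got multiple values for argument 'f'"
    print("collision:", e)

print(call_wrapped_fixed(myfun, 1.0, f=2.0))  # 7.0
```

Renaming the internal parameter to something a user is unlikely to pick also works, but positional-only parameters remove the clash for every name, not just the ones chosen to be unusual.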
Here's a quick fix: #25349. It just switches our internal names to …
Will there be a new release soon? This is a breaking change for …
Actually, yes. I'm aiming to make one on or about Dec 15.
Fixed by #25349
Description
For example:
Gives
This only started happening in 0.4.36; I think the offending commit is 1c9b23c.
System info (python version, jaxlib version, accelerator, etc.)