Jitted functions can't take keyword arguments named f in v0.4.36 #25329

Closed
f0uriest opened this issue Dec 7, 2024 · 6 comments
Labels: bug (Something isn't working)

f0uriest commented Dec 7, 2024

Description

For example:

import jax

def myfun(x, f):
    return x*f

jax.jit(myfun)(1.0, f=2.0)

Gives

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 4
      1 def myfun(x, f):
      2     return x*f
----> 4 jax.jit(myfun)(1.0, f=2.0)

    [... skipping hidden 11 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api_util.py:74, in flatten_fun(f, store, in_tree, *args_flat)
     71 @lu.transformation_with_aux2
     72 def flatten_fun(f, store, in_tree, *args_flat):
     73   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 74   ans = f(*py_args, **py_kwargs)
     75   ans, out_tree = tree_flatten(ans)
     76   store.store(out_tree)

TypeError: result_paths() got multiple values for argument 'f'

This only started happening in 0.4.36; I think the offending commit is 1c9b23c.
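The TypeError can be reproduced without JAX. The following is a simplified sketch, not JAX's actual code: `call_flattened` is a hypothetical stand-in for an internal wrapper (such as the `result_paths` function named in the error) whose own first parameter is called `f`. When user keyword arguments are forwarded through such a wrapper's signature, a user keyword named `f` competes with the wrapper's own `f`.

```python
def call_flattened(f, store, *args, **kwargs):
    # Hypothetical wrapper: its own parameters are named `f` and `store`,
    # and it forwards everything else to the wrapped function.
    return f(*args, **kwargs)


def myfun(x, f):
    return x * f


def reproduce():
    try:
        # `myfun` binds to the wrapper's `f` positionally, and the user's
        # `f=2.0` is then matched to the same parameter by name.
        call_flattened(myfun, None, 1.0, f=2.0)
    except TypeError as err:
        return str(err)
    return "no error"


print(reproduce())
```

The printed message has the same shape as the one in the traceback (`... got multiple values for argument 'f'`). Passing `f` positionally, as in `call_flattened(myfun, None, 1.0, 2.0)`, keeps it out of `**kwargs` and avoids the clash.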

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.36
jaxlib: 0.4.36
numpy:  1.24.4
python: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0]
device info: cpu-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='Discovery', release='5.15.0-125-generic', version='#135~20.04.1-Ubuntu SMP Mon Oct 7 13:56:22 UTC 2024', machine='x86_64')

f0uriest added the bug (Something isn't working) label Dec 7, 2024

f0uriest (author) commented Dec 7, 2024

It also happens if `f` is renamed to `store`, but works fine if it is renamed to something else (e.g. `y`), which seems to confirm it's due to 1c9b23c.
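That pattern (`f` and `store` fail, `y` works) is what one would expect if the clash comes from the named parameters of an internal wrapper. A small helper can list which keyword names would collide with a given wrapper's named parameters. `colliding_names` and the `wrapper` signature below are hypothetical, loosely modeled on the `(f, store, ...)` parameters visible in the traceback; they are not JAX internals.

```python
import inspect


def colliding_names(wrapper, user_kwargs):
    """Return user keyword names that would clash with `wrapper`'s named parameters."""
    reserved = {
        p.name
        for p in inspect.signature(wrapper).parameters.values()
        # Parameters that can be matched by name (positional-or-keyword and
        # keyword-only) are the ones a forwarded user keyword can clash with.
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(reserved & set(user_kwargs))


def wrapper(f, store, *args, **kwargs):  # hypothetical internal signature
    return f(*args, **kwargs)


print(colliding_names(wrapper, {"f": 2.0}))      # ['f']
print(colliding_names(wrapper, {"store": 2.0}))  # ['store']
print(colliding_names(wrapper, {"y": 2.0}))      # []
```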

f0uriest (author) commented Dec 7, 2024

Note that the error is a bit different if using static args:

import jax

def myfun(x, f, y=5):
    return x*f + y

jax.jit(myfun, static_argnames=("y",))(1.0, f=2.0, y=2)

gives

TypeError: _argnames_partial() got multiple values for argument 'f'

This also breaks on previous versions of JAX if `f` is renamed to `fixed_kwargs`, `fixed_args`, or `dyn_argnums` (and possibly others). It seems that any argument name used by some of the functions in `api_util` can't also be used as an argument name of a jitted function?
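On affected versions, one possible user-side workaround is to avoid passing the affected arguments by keyword at all. The sketch below normalizes a call with `inspect.signature(...).bind`, so positional-or-keyword parameters such as `f` come back as positional values. `to_positional` and `call_with_reserved_names` are hypothetical helpers, and the wrapper only mimics the clash; whether this combines cleanly with `static_argnames` in a particular JAX version is an assumption that has not been checked here.

```python
import inspect


def call_with_reserved_names(f, store, *args, **kwargs):
    # Hypothetical wrapper whose own parameter names (`f`, `store`) can clash
    # with keyword arguments it forwards.
    return f(*args, **kwargs)


def to_positional(fun, *args, **kwargs):
    # Bind the call against fun's signature. Positional-or-keyword parameters
    # are returned in bound.args, so they no longer travel as keywords;
    # only keyword-only parameters (if any) remain in bound.kwargs.
    bound = inspect.signature(fun).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.args, bound.kwargs


def myfun(x, f, y=5):
    return x * f + y


args, kwargs = to_positional(myfun, 1.0, f=2.0, y=2)
print(args, kwargs)  # (1.0, 2.0, 2) {}
print(call_with_reserved_names(myfun, None, *args, **kwargs))  # 4.0
```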

dougalm added a commit that referenced this issue Dec 9, 2024
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
dougalm (collaborator) commented Dec 9, 2024

Here's a quick fix: #25349. It just switches our internal names to `_store` instead of `store`, etc. Soon I want to get rid of `linear_util` altogether.
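Renaming internal parameters (e.g. `store` to `_store`) makes a clash unlikely but not impossible, since a user function could still take an argument named `_store`. In plain Python, positional-only parameters (PEP 570, Python 3.8+) avoid the clash entirely: a keyword with the same name is collected into `**kwargs` instead. This is a general Python technique shown for comparison, not what #25349 does; `robust_wrapper` is hypothetical.

```python
def robust_wrapper(f, store, /, *args, **kwargs):
    # `f` and `store` are positional-only, so user keywords with the same
    # names are collected into **kwargs instead of raising TypeError.
    store.append(sorted(kwargs))  # record which keyword names were forwarded
    return f(*args, **kwargs)


def myfun(x, f):
    return x * f


log = []
print(robust_wrapper(myfun, log, 1.0, f=2.0))  # 2.0
print(log)  # [['f']]
```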

f0uriest (author) commented Dec 9, 2024

Will there be a new release soon? This is a breaking change for interpax and likely other packages (`f` is a pretty common variable name for math-y functions).

hawkinsp (collaborator) commented Dec 9, 2024

Actually, yes. I'm aiming to make one on or about Dec 15.

jakevdp (collaborator) commented Dec 9, 2024

Fixed by #25349

jakevdp closed this as completed Dec 9, 2024
hawkinsp pushed a commit that referenced this issue Dec 9, 2024
4 participants