Skip to content

Commit

Permalink
Use private names for args in api_util to avoid shadowing kwargs keys.
Browse files Browse the repository at this point in the history
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
  • Loading branch information
dougalm committed Dec 9, 2024
1 parent 1ac6b76 commit dd74394
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args

@lu.transformation2
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
sentinel = object()
args = [sentinel] * (len(fixed_args) + len(dyn_args))
for i, arg in zip(dyn_argnums, dyn_args):
args = [sentinel] * (len(_fixed_args) + len(dyn_args))
for i, arg in zip(_dyn_argnums, dyn_args):
args[i] = arg
fixed_args_ = iter(fixed_args)
fixed_args_ = iter(_fixed_args)
args = [next(fixed_args_).val if x is sentinel else x for x in args]
assert next(fixed_args_, sentinel) is sentinel
return f(*args, **kwargs)
return _fun(*args, **kwargs)

def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
kwargs: dict[str, Any]):
Expand All @@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs

@lu.transformation2
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
return f(*args, **kwargs)
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs)
return _fun(*args, **kwargs)


@lru_cache(maxsize=4096)
Expand Down Expand Up @@ -438,9 +438,9 @@ def flat_out_axes(
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))

@lu.transformation_with_aux2
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
ans = f(*args, **kwargs)
spec = tree_unflatten(treedef, leaves)
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
ans = _fun(*args, **kwargs)
spec = tree_unflatten(_treedef, _leaves)
try:
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
except ValueError:
Expand All @@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
"pmapped function's output.")
raise ValueError(msg) from None
store.store(spec_flat)
_store.store(spec_flat)
return ans

def check_callable(fun):
Expand Down Expand Up @@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
for path, l in generate_key_paths(x) if l is not static)

@lu.transformation_with_aux2
def result_paths(f, store, *args, **kwargs):
def result_paths(_fun, _store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = f(*args, **kwargs)
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
Expand Down

0 comments on commit dd74394

Please sign in to comment.