Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward all arguments when using nnx.transforms.deprecated.scan as a decorator. #4208

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion flax/nnx/transforms/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,19 @@ def scan(
) -> F | tp.Callable[[F], F]:
if isinstance(f, Missing):
return functools.partial(
scan, length=length, reverse=reverse, unroll=unroll
scan,
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
in_axes=in_axes,
in_axes_kwargs=in_axes_kwargs,
out_axes=out_axes,
carry_argnum=carry_argnum,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
scan_output=scan_output,
) # type: ignore[return-value]

@functools.wraps(f)
Expand Down
Loading