[nnx] improve nnx.scan in_axes/out_axes #4157
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Makes
nnx.scan
'sin_axes
andout_axes
API is more flexible in the following ways:carry, outputs
,out_axes
can useCarry
to control its structure, by default its now(Carry, 0)
to match current behavior.Carry
is no longer required to be present inin_axes
/out_axes
, however ifCarry
is present inin_axes
it must also be present inout_axes
and vice versa.in_axes
is not limited to aSequence
, as withvmap
a single top-level integer orNone
can now be used to configure all arguments. A top-levelCarry
is also accepted byin_axes
but its interpreted as(Carry,)
, it means that the function accepts a single argument that is the Carry.Example: Scan over layers