[nnx] transforms refactor #3927
cf966a1 to 4b5ed58
I did my best to read through, but this PR seems to have too many moving parts to really follow what's going on.
I recommend
- adding more details to the PR description -- what were the bugs, why were they important, what is the high level fix idea etc etc;
- doing smaller more scoped PRs going forward.
If you can, ask one other person to review as well.
flax/experimental/nnx/nnx/graph.py
for ref in self.refmap:
  if isinstance(ref, Variable):
    ref.raw_value = None
Does it make sense to add a clear() method to Variable doing this?
I've removed this code, found a better way.
flax/experimental/nnx/nnx/rnglib.py
@@ -173,7 +160,7 @@ def fork(
   state: State,
   split_filter: filterlib.Filter,
   split_pattern: SplitPattern,
-) -> tuple[State, State]:
+) -> tuple[State, State, State, State]:
I would recommend making this a typing.NamedTuple or a dataclass so that the caller doesn't have to remember what each State component corresponds to.
done
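A minimal sketch of the NamedTuple suggestion, for illustration only. The type name ForkedStates, its field names, and the helper fork_sketch are hypothetical, and State is replaced by a plain dict alias so the snippet runs without Flax; the return type actually adopted in the PR may look different.

```python
from typing import Any, NamedTuple

# Stand-in for nnx.State so the sketch is self-contained (assumption).
State = dict[str, Any]


class ForkedStates(NamedTuple):
  # Hypothetical field names; the point is that each component is named.
  split_keys: State
  split_counts: State
  broadcast_keys: State
  broadcast_counts: State


def fork_sketch(keys: State, counts: State, split_names: set[str]) -> ForkedStates:
  # Partition keys and counts by whether their name is selected for splitting.
  return ForkedStates(
    split_keys={k: v for k, v in keys.items() if k in split_names},
    split_counts={k: v for k, v in counts.items() if k in split_names},
    broadcast_keys={k: v for k, v in keys.items() if k not in split_names},
    broadcast_counts={k: v for k, v in counts.items() if k not in split_names},
  )


forked = fork_sketch(
  keys={"params": 1, "dropout": 2},
  counts={"params": 0, "dropout": 0},
  split_names={"dropout"},
)
# Callers can use names instead of remembering tuple positions,
# while positional unpacking keeps working for existing call sites.
split_keys, split_counts, broadcast_keys, broadcast_counts = forked
```

Because a NamedTuple is still a tuple, switching the return type this way should not break callers that unpack positionally.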
argnums = options.argnums[0] if len(options.argnums) == 1 else options.argnums
# rebuild diff_state from substates in args
diff_state = State({})
Would this be the same as diff_state = State({i: _args[i] for i in diff_args})?
maybe like this:
diff_state = State({i: _args[i].raw_mapping for i in diff_args})
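The question in this thread can be sketched with plain dicts. This assumes the loop that follows diff_state = State({}) simply assigns one entry per index in diff_args; that loop is not shown in the snippet above, so the equivalence below is an illustration, not a statement about the PR's code. State and raw_mapping are not modeled here, and both function names are hypothetical.

```python
def build_with_loop(args, diff_args):
  # Start empty and fill one entry per differentiable argument index.
  diff_state = {}
  for i in diff_args:
    diff_state[i] = args[i]
  return diff_state


def build_with_comprehension(args, diff_args):
  # Same mapping, built in a single expression.
  return {i: args[i] for i in diff_args}


example_args = ("model_state", "x", "optimizer_state")
example_diff_args = [0, 2]

loop_result = build_with_loop(example_args, example_diff_args)
comprehension_result = build_with_comprehension(example_args, example_diff_args)
```

Whether the real code needs _args[i] or _args[i].raw_mapping depends on what State accepts as values, which is the point raised in the reply.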
reduce_axes=reduce_axes,
)(*_args, f, ctx, graphdef, non_diff_state, has_aux, diff_args)

updates: State
Unused?
used to force the type annotation on the update definitions below
I don't see any updates definitions below. Which lines are you referring to?
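For context on this exchange, here is a small sketch of how a bare annotation such as updates: State behaves in Python. It declares a type for later assignments but binds no value, so it only has an effect if an assignment to that name follows. The function names are hypothetical and dict stands in for State.

```python
def bare_annotation_only() -> str:
  updates: dict  # annotation only: the name becomes local but stays unbound
  try:
    _ = updates  # reading it before any assignment fails
  except UnboundLocalError:
    return "unbound"
  return "bound"


def annotation_then_assignments(flag: bool) -> dict:
  updates: dict  # declares the type once for both branches below
  if flag:
    updates = {"a": 1}
  else:
    updates = {}
  return updates


status = bare_annotation_only()
```

If no assignment to updates follows the annotation, the line has no runtime or type-checking effect and can be dropped, which is what the reviewer is asking about.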
cc198dd to 8d8d5ac
@@ -104,7 +105,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
     self.din = din
     self.dout = dout

-  @nnx.jit
+  # @nnx.jit
Leftover debugging code?
yes, thanks! removed
PiperOrigin-RevId: 636089212
What does this PR do?
- cond
- scan not caching properly
- scan having different keys for every step for non-split_rngs keys
- jit and scan (TODO: port the other transforms to the new simplified style).
- fork to make it easier to use.
- RngCount now has a tag attribute (same as RngKey); this enables filtering counts as well (needed for scan).
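To illustrate why a shared tag attribute makes counts filterable alongside keys, here is a rough sketch using plain dataclasses. RngKeySketch, RngCountSketch, and filter_by_tag are hypothetical stand-ins, not nnx classes or its filter API; the sketch only assumes that both variable kinds carry a tag naming their stream.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RngKeySketch:
  tag: str  # name of the RNG stream, e.g. "dropout" (assumption)
  value: int


@dataclass(frozen=True)
class RngCountSketch:
  tag: str  # same tag as the matching key, so one filter selects both
  value: int


def filter_by_tag(variables, tag):
  # Keep every variable, key or count, that belongs to the given stream.
  return [v for v in variables if v.tag == tag]


variables = [
  RngKeySketch("params", 0),
  RngCountSketch("params", 3),
  RngKeySketch("dropout", 1),
  RngCountSketch("dropout", 7),
]
dropout_vars = filter_by_tag(variables, "dropout")
```

With a tag on both kinds, a transform can treat the key and the count of one stream as a unit, for example splitting the dropout stream while broadcasting the rest.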