-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] jit constrain object state #3817
Conversation
0cd5f38
to
1ee69d6
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #3817 +/- ##
==========================================
+ Coverage 61.05% 61.25% +0.20%
==========================================
Files 105 105
Lines 13100 13167 +67
==========================================
+ Hits 7998 8066 +68
+ Misses 5102 5101 -1 ☔ View full report in Codecov by Sentry. |
1ee69d6
to
a88fbbb
Compare
donate_object_state: bool | ||
constrain_object_state: tp.Callable[[State], State] | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document what this field does either here or elsewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -138,6 +138,11 @@ class JitStaticOutputs: | |||
|
|||
jax.tree_util.register_static(JitStaticOutputs) | |||
|
|||
def _default_constrain_object_state(state: State) -> State: | |||
state_spec = spmd.get_partition_spec(state) | |||
state = jax.lax.with_sharding_constraint(state, state_spec) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is WSC a noop if there is no mesh?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, it will crash.
a88fbbb
to
91ee438
Compare
91ee438
to
8192b96
Compare
What does this PR do?
Adds a
constrain_object_state
argument tonnx.jit
which can be either aCallable[[State], State]
(a function creating the new constrained State, or abool
. WhenTrue
theState
will be constrained usingnnx.get_partition_spec
passed tojax.lax.with_sharding_constraint
.