Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] jit constrain object state #3817

Merged
merged 1 commit into from
Apr 12, 2024
Merged

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Apr 2, 2024

What does this PR do?

Adds a constrain_object_state argument to nnx.jit which can be either a Callable[[State], State] (a function creating the new constrained State, or a bool. When True the State will be constrained using nnx.get_partition_spec passed to jax.lax.with_sharding_constraint.

@codecov-commenter
Copy link

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 61.25%. Comparing base (bc93cb1) to head (1ee69d6).
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3817      +/-   ##
==========================================
+ Coverage   61.05%   61.25%   +0.20%     
==========================================
  Files         105      105              
  Lines       13100    13167      +67     
==========================================
+ Hits         7998     8066      +68     
+ Misses       5102     5101       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@cgarciae cgarciae marked this pull request as ready for review April 12, 2024 09:44
donate_object_state: bool
constrain_object_state: tp.Callable[[State], State] | None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document what this field does either here or elsewhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -138,6 +138,11 @@ class JitStaticOutputs:

jax.tree_util.register_static(JitStaticOutputs)

def _default_constrain_object_state(state: State) -> State:
state_spec = spmd.get_partition_spec(state)
state = jax.lax.with_sharding_constraint(state, state_spec)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is WSC a noop if there is no mesh?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, it will crash.

@copybara-service copybara-service bot merged commit 2438518 into main Apr 12, 2024
12 of 21 checks passed
@copybara-service copybara-service bot deleted the nnx-jit-apply-sharding branch April 12, 2024 17:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants