Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] jit accepts many Modules #3783

Merged
merged 1 commit into from
Mar 28, 2024
Merged

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 22, 2024

What does this PR do?

  • nnx.jit now accepts multiple Modules.
  • Added donate_object_state argument to donate the buffers of all input graph nodes.

@cgarciae cgarciae force-pushed the nnx-jit-accepts-many-Modules branch from b903c11 to 93bafcb Compare March 24, 2024 10:34
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae changed the base branch from nnx-jit-new-update to nnx-leaves-are-leaves March 24, 2024 10:36
@cgarciae cgarciae force-pushed the nnx-jit-accepts-many-Modules branch 10 times, most recently from 227e6d8 to 40b2979 Compare March 28, 2024 10:36
@cgarciae cgarciae changed the base branch from nnx-leaves-are-leaves to main March 28, 2024 10:47

def _maybe_extract(x):
if is_graph_node(x):
if x in nodes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI you can do this in a single line:

index = nodes.setdefault(x, len(nodes)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clever

cls,
in_shardings: tp.Any,
out_shardings: tp.Any,
static_argnums: int | tp.Sequence[int] | None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why static_argunums is a Sequence, but static_argnames can be an Iterable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this form jax.jit, maybe it has to support repeated iteration so they request sequence?

abstracted_axes: tp.Optional[tp.Any],
donate_object_state: bool,
):
_static_argnums: tuple[int, ...]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe have a local function which packs T | tp.Iterable[T] | None into tuple[T, ...] and use it for all the processing below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds like a good idea

@@ -302,7 +411,9 @@ def _submodule(self) -> M:
def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any:
self.accessor = accessor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to the PR, but it looks like this doesn't allow calling the module from multiple threads concurrently?

static_argnums: tp.Union[int, tp.Sequence[int], None] = None,
static_argnames: tp.Union[str, tp.Iterable[str], None] = None,
donate_argnums: tp.Union[int, tp.Sequence[int]] = (),
static_argnums: int | tp.Sequence[int] | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update this if you change the type annotations above.

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 97.40260% with 6 lines in your changes are missing coverage. Please review.

Project coverage is 60.17%. Comparing base (6f1f1ef) to head (e452f5a).

Files Patch % Lines
flax/experimental/nnx/nnx/transforms.py 91.42% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3783      +/-   ##
==========================================
+ Coverage   59.62%   60.17%   +0.55%     
==========================================
  Files         101      101              
  Lines       12655    12840     +185     
==========================================
+ Hits         7546     7727     +181     
- Misses       5109     5113       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@copybara-service copybara-service bot merged commit fe99574 into main Mar 28, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-jit-accepts-many-Modules branch March 28, 2024 12:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants