[nnx] jit accepts many Modules #3783
def _maybe_extract(x):
  if is_graph_node(x):
    if x in nodes:
FYI you can do this in a single line:
index = nodes.setdefault(x, len(nodes))
clever
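The one-line suggestion can be sketched in isolation. This is a minimal illustration, not the PR's code: it assumes `nodes` is a plain dict mapping objects to indices, and `Node`, `index_of_verbose`, and `index_of` are hypothetical names introduced here.

```python
class Node:
  """Hypothetical stand-in for a graph node (hashable by identity)."""


def index_of_verbose(nodes: dict, x) -> int:
  # Explicit membership check, then assign the next free index.
  if x in nodes:
    return nodes[x]
  nodes[x] = len(nodes)
  return nodes[x]


def index_of(nodes: dict, x) -> int:
  # len(nodes) is evaluated before the insert, so a new key receives the
  # next free index and an existing key keeps its stored index.
  return nodes.setdefault(x, len(nodes))


a, b = Node(), Node()

nodes_short: dict = {}
indices_short = [index_of(nodes_short, n) for n in (a, b, a)]

nodes_verbose: dict = {}
indices_verbose = [index_of_verbose(nodes_verbose, n) for n in (a, b, a)]
```

Both variants assign the same indices; the `setdefault` form does it with a single dict lookup.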
cls,
in_shardings: tp.Any,
out_shardings: tp.Any,
static_argnums: int | tp.Sequence[int] | None,
Is there a reason why static_argnums is a Sequence, but static_argnames can be an Iterable?
I copied this from jax.jit. Maybe it has to support repeated iteration, so they request a Sequence?
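The practical difference raised here can be shown with a small sketch. Whether `jax.jit` actually iterates over `static_argnums` more than once is not confirmed in this thread; `count_twice` is a hypothetical function used only to show why code that iterates repeatedly would ask for a `Sequence` rather than an `Iterable`.

```python
from collections.abc import Iterable, Sequence


def count_twice(values: Iterable[int]) -> tuple[int, int]:
  # Walks over `values` two times; this is only safe for re-iterable inputs.
  first = sum(1 for _ in values)
  second = sum(1 for _ in values)
  return first, second


# A list is a Sequence and can be traversed repeatedly.
from_list = count_twice([0, 2])

# A generator is an Iterable but not a Sequence; the second pass sees nothing.
gen = (i for i in (0, 2))
gen_is_sequence = isinstance(gen, Sequence)
from_gen = count_twice(gen)
```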
abstracted_axes: tp.Optional[tp.Any],
donate_object_state: bool,
):
_static_argnums: tuple[int, ...]
Maybe have a local function which packs T | tp.Iterable[T] | None into tuple[T, ...] and use it for all the processing below?
sounds like a good idea
@@ -302,7 +411,9 @@ def _submodule(self) -> M:
def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any:
  self.accessor = accessor
Not directly related to the PR, but it looks like this doesn't allow calling the module from multiple threads concurrently?
static_argnums: tp.Union[int, tp.Sequence[int], None] = None,
static_argnames: tp.Union[str, tp.Iterable[str], None] = None,
donate_argnums: tp.Union[int, tp.Sequence[int]] = (),
static_argnums: int | tp.Sequence[int] | None = None,
Please update this if you change the type annotations above.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3783      +/-   ##
==========================================
+ Coverage   59.62%   60.17%   +0.55%
==========================================
  Files         101      101
  Lines       12655    12840     +185
==========================================
+ Hits         7546     7727     +181
- Misses       5109     5113       +4
What does this PR do?
nnx.jit now accepts multiple Modules.
Adds a donate_object_state argument to donate the buffers of all input graph nodes.