-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] Rng Variable tags #3807
[nnx] Rng Variable tags #3807
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
3a66387
to
9869b79
Compare
f08c285
to
25c0e68
Compare
25c0e68
to
1d5f718
Compare
|
||
def __call__(self, path: PathParts, x: tp.Any): | ||
return self.str_key in path | ||
return isinstance(x, _HasTag) and x.tag == self.tag |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can simply use hasattr()
here?
class AtPath: | ||
str_key: str | ||
class WithTag: | ||
tag: str | ||
|
||
def __call__(self, path: PathParts, x: tp.Any): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mypy will not type check this unless you add -> bool
IIRC.
|
||
def __contains__(self, name: tp.Any) -> bool: | ||
return name in self._rngs | ||
return name in vars(self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to special-case "_graph_node__state"
here for consistency with other methods?
|
||
def __len__(self) -> int: | ||
return len(self._rngs) | ||
return len(vars(self)) - 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "_graph_node__state"
always present? If not, you should probably do - ('_graph_node__state' in vars(self))
.
|
||
def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs': | ||
rngs: dict[str, tp.Any] = self._rngs.copy() | ||
rngs: dict[str, tp.Any] = vars(self).copy() | ||
del rngs['_graph_node__state'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unfortunate that the implementation detail of GraphNode leaks into the subclasses. Can we do better?
What does this PR do?
This PR adds the necessary changes so we can filter
RngState
Variables based on the stream name usingFilter
s, this allows you to manage RNG state a regular graph state but using the previous name-based filter convention.Changes
RngKey
andRngCount
Variable types that inherit fromRngState
.RngKey
is expected to have atag: str
metadata attribute.RngStream.key
is now aRngKey
andRngStream.count
is now aRngCount
.Rngs
setsRngKey.tag
as the stream's name.str
filters now create aWithTag
filter that selects Variables with atag
attribute that matches that string.Rngs._rngs
,RngStreams
are now stored as attributes ofRngs
.RngStream.make_rng
to__call__
.Rngs.__getattr__
andRngs.__getitem__
now returns aRngStream
.