Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Rng Variable tags #3807

Merged
merged 1 commit into from
Apr 2, 2024
Merged

[nnx] Rng Variable tags #3807

merged 1 commit into from
Apr 2, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 31, 2024

What does this PR do?

This PR adds the necessary changes so we can filter RngState Variables based on the stream name using Filters, this allows you to manage RNG state a regular graph state but using the previous name-based filter convention.

Changes

  • Adds RngKey and RngCount Variable types that inherit from RngState. RngKey is expected to have atag: str metadata attribute.
  • RngStream.key is now a RngKey and RngStream.count is now a RngCount. Rngs sets RngKey.tag as the stream's name.
  • str filters now create a WithTag filter that selects Variables with a tag attribute that matches that string.
  • Removes Rngs._rngs, RngStreams are now stored as attributes of Rngs.
  • Refactors RngStream.make_rng to __call__.
  • Rngs.__getattr__ and Rngs.__getitem__ now returns a RngStream.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae changed the title [nnx] Rngs and RngStream inherit from GraphNode [nnx] Rng Variable tags Mar 31, 2024
@cgarciae cgarciae force-pushed the nnx-rng-tags branch 6 times, most recently from 3a66387 to 9869b79 Compare April 1, 2024 08:55
Base automatically changed from nnx-rngs-are-nodes to main April 1, 2024 13:02
@cgarciae cgarciae changed the base branch from main to nnx-graph-non-str-keys April 1, 2024 17:41
Base automatically changed from nnx-graph-non-str-keys to main April 1, 2024 22:45
@cgarciae cgarciae marked this pull request as ready for review April 1, 2024 22:58
@copybara-service copybara-service bot merged commit 1579586 into main Apr 2, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-rng-tags branch April 2, 2024 17:24

def __call__(self, path: PathParts, x: tp.Any):
return self.str_key in path
return isinstance(x, _HasTag) and x.tag == self.tag
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can simply use hasattr() here?

class AtPath:
str_key: str
class WithTag:
tag: str

def __call__(self, path: PathParts, x: tp.Any):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy will not type check this unless you add -> bool IIRC.


def __contains__(self, name: tp.Any) -> bool:
return name in self._rngs
return name in vars(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to special-case "_graph_node__state" here for consistency with other methods?


def __len__(self) -> int:
return len(self._rngs)
return len(vars(self)) - 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "_graph_node__state" always present? If not, you should probably do - ('_graph_node__state' in vars(self)).


def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs':
rngs: dict[str, tp.Any] = self._rngs.copy()
rngs: dict[str, tp.Any] = vars(self).copy()
del rngs['_graph_node__state']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unfortunate that the implementation detail of GraphNode leaks into the subclasses. Can we do better?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants