assert_tree_shape_prefix requires tuple instead of Sequence #274

Closed
hylkedonker opened this issue May 30, 2023 · 2 comments

@hylkedonker

Hi,

Thanks for sharing this nice package! I especially like all the assertions for pytrees.
However, I came across the following inconsistency in the documentation:
Current situation
According to the docs, the shape_prefix argument of assert_tree_shape_prefix is of type Sequence[int].
However, when I pass a sequence such as a list (instead of a tuple):

import chex
import jax.numpy as jnp

mytree = {'a': jnp.array([[[1], [2]]])}
chex.assert_tree_shape_prefix(mytree, shape_prefix=[1, 2])  # AssertionError!

The assertion raises an exception:

AssertionError: [Chex] Assertion assert_tree_shape_prefix failed: Tree leaf 'a' has a shape prefix different from expected: (1, 2) != [1, 2].

The error can simply be fixed by using a tuple instead:

chex.assert_tree_shape_prefix(mytree, shape_prefix=(1, 2))  # OK!
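Judging from the error message, the failure presumably comes down to plain Python semantics: an array's shape is a tuple, and a tuple never compares equal to a list, even when the elements match. A minimal sketch without chex, using the shape (1, 2, 1) that mytree['a'] has (the variable names are only for illustration):

```python
# Shape of jnp.array([[[1], [2]]]), written out as a plain tuple.
leaf_shape = (1, 2, 1)

prefix_as_list = [1, 2]
prefix_as_tuple = (1, 2)

# Slicing a tuple yields a tuple, so comparing it against a list is always False.
list_matches = leaf_shape[:len(prefix_as_list)] == prefix_as_list

# The same comparison against a tuple succeeds.
tuple_matches = leaf_shape[:len(prefix_as_tuple)] == prefix_as_tuple

print(list_matches, tuple_matches)
```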

I think this is not the only inconsistent function, but I did not check for others.

Desired situation
Ideally, I would like the function to behave as documented, so that I can also pass a list. Why? I think being able to use square brackets (i.e., a list) inside the call's parentheses helps readability.

I would be interested to hear your opinion and I am happy to contribute a pull request.

Keep up the good work!

Hylke

@stompchicken
Collaborator

Hi Hylke, thanks for raising the issue! I agree with you that assert_tree_shape_prefix ought to work with sequences, as it claims to in the type annotation. If you would like to contribute a pull request, that would be great.

I think the best thing to do would be to convert the shape_prefix argument to a tuple inside assert_tree_shape_prefix before the assert_fn is defined. Ideally you would also add a unit test :)
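A rough sketch of that idea, written as a standalone stand-in rather than the actual chex code: `check_shape_prefix` and `Leaf` are hypothetical names, the tree is assumed to be a flat dict, and the real function's internals (including how assert_fn is built) may look different.

```python
from dataclasses import dataclass
from typing import Mapping, Sequence, Tuple


@dataclass
class Leaf:
    """Hypothetical stand-in for an array; only the shape matters here."""
    shape: Tuple[int, ...]


def check_shape_prefix(tree: Mapping[str, Leaf], shape_prefix: Sequence[int]) -> None:
    """Hypothetical simplified check, not the chex implementation."""
    # Normalise once, up front, so lists and tuples behave identically below.
    shape_prefix = tuple(shape_prefix)
    for name, leaf in tree.items():
        prefix = tuple(leaf.shape[:len(shape_prefix)])
        if prefix != shape_prefix:
            raise AssertionError(
                f"Tree leaf '{name}' has a shape prefix different from "
                f"expected: {prefix} != {shape_prefix}."
            )


tree = {"a": Leaf(shape=(1, 2, 1))}
check_shape_prefix(tree, [1, 2])  # a list is accepted
check_shape_prefix(tree, (1, 2))  # a tuple still works
```

A unit test for the real function could follow the same pattern: run the assertion once with a list and once with a tuple, and check that a mismatching prefix still raises.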

@hylkedonker
Author

Thanks for your prompt response.
I'll try to give it a go in the coming days. :-)

hylkedonker pushed a commit to hylkedonker/chex that referenced this issue May 30, 2023
hylkedonker pushed a commit to hylkedonker/chex that referenced this issue Jun 2, 2023