Skip to content

Commit

Permalink
Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

PiperOrigin-RevId: 559395727
  • Loading branch information
hawkinsp authored and copybara-github committed Aug 23, 2023
1 parent 36a79cb commit 52786e2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions haiku/_src/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,15 +661,15 @@ def __init__(self, name: Optional[str] = None):
# Support @dataclass annotated modules.
__post_init__ = __init__

def params_dict(self) -> Mapping[str, jnp.array]:
def params_dict(self) -> Mapping[str, jnp.ndarray]:
"""Returns parameters keyed by name for this module and submodules."""
if not base.frame_stack:
raise ValueError(
"`module.params_dict()` must be used as part of an `hk.transform`.")

return params_or_state_dict(self.module_name, self._submodules, "params")

def state_dict(self) -> Mapping[str, jnp.array]:
def state_dict(self) -> Mapping[str, jnp.ndarray]:
"""Returns state keyed by name for this module and submodules."""
if not base.frame_stack:
raise ValueError(
Expand All @@ -682,7 +682,7 @@ def params_or_state_dict(
module_name: str,
submodules: Set[str],
which: str,
) -> Mapping[str, jnp.array]:
) -> Mapping[str, jnp.ndarray]:
"""Returns module parameters or state for the given module or submodules."""
assert which in ("params", "state")
out = {}
Expand Down

0 comments on commit 52786e2

Please sign in to comment.