Skip to content

Commit

Permalink
Only include arg_constraints in pytree_data_fields if they are no…
Browse files Browse the repository at this point in the history
…t `lazy_property`s. (#1929)
  • Loading branch information
tillahoffmann authored Dec 5, 2024
1 parent 608a2c3 commit b74c0e9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
18 changes: 11 additions & 7 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,20 @@ def __init_subclass__(cls, **kwargs):
def gather_pytree_data_fields(cls):
bases = inspect.getmro(cls)

all_pytree_data_fields = ()
all_pytree_data_fields = set()
for base in bases:
if issubclass(base, Distribution):
all_pytree_data_fields += base.__dict__.get(
"pytree_data_fields",
tuple(base.__dict__.get("arg_constraints", {}).keys()),
all_pytree_data_fields.update(
base.__dict__.get(
"pytree_data_fields",
tuple(
arg
for arg in base.__dict__.get("arg_constraints", {})
if not isinstance(getattr(cls, arg, None), lazy_property)
),
)
)
# remove duplicates
all_pytree_data_fields = tuple(set(all_pytree_data_fields))
return all_pytree_data_fields
return tuple(all_pytree_data_fields)

@classmethod
def gather_pytree_aux_fields(cls) -> tuple:
Expand Down
13 changes: 13 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3521,3 +3521,16 @@ def test_gaussian_random_walk_state_space_equivalence():
assert jnp.allclose(x1, jnp.squeeze(x2, axis=-1))

assert jnp.allclose(d1.log_prob(x1), d2.log_prob(x2))


def test_consistent_pytree() -> None:
def make_dist():
return dist.MultivariateNormal(precision_matrix=jnp.eye(2))

init = make_dist()
# Access the covariance matrix to evaluate the lazy property.
init.covariance_matrix
assert "covariance_matrix" in init.__dict__

# Run scan which validates that pytree structures are consistent.
jax.lax.scan(lambda *_: (make_dist(), None), init, jnp.arange(7))

0 comments on commit b74c0e9

Please sign in to comment.