Skip to content

Commit

Permalink
Allow to use Delta on numpy arrays without moving them to jax devices (
Browse files Browse the repository at this point in the history
…#1777)

* allow to use Delta on numpy arrays without moving them to devices

* fix lint

* fix test
  • Loading branch information
fehiepsi authored Apr 12, 2024
1 parent b2cee89 commit e84f004
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,8 @@ def support(self):
return constraints.independent(constraints.real, self.event_dim)

def sample(self, key, sample_shape=()):
if not sample_shape:
return self.v
shape = sample_shape + self.batch_shape + self.event_shape
return jnp.broadcast_to(self.v, shape)

Expand Down
11 changes: 6 additions & 5 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,15 +1102,16 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):


@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
"jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
def test_dist_shape(jax_dist, sp_dist, params, prepend_shape):
jax_dist = jax_dist(*params)
def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
jax_dist = jax_dist_cls(*params)
rng_key = random.PRNGKey(0)
expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
assert isinstance(samples, jnp.ndarray)
if jax_dist_cls is not dist.Delta:
assert isinstance(samples, jnp.ndarray)
assert jnp.shape(samples) == expected_shape
if (
sp_dist
Expand Down Expand Up @@ -2620,7 +2621,7 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
rng_key = random.PRNGKey(0)
samples = expanded_dist.sample(rng_key, sample_shape)
assert expanded_dist.batch_shape == new_batch_shape
assert samples.shape == sample_shape + new_batch_shape + jax_dist.event_shape
assert jnp.shape(samples) == sample_shape + new_batch_shape + jax_dist.event_shape
assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape
# test expand of expand
assert (
Expand Down

0 comments on commit e84f004

Please sign in to comment.