Merge branch 'pyro-ppl:master' into master
deoxyribose authored Dec 4, 2023
2 parents ce3be2d + b16741c commit 783f9f3
Showing 4 changed files with 39 additions and 8 deletions.
examples/gp.py (23 changes: 17 additions & 6 deletions)
@@ -98,18 +98,28 @@ def run_inference(model, args, rng_key, X, Y):


 # do GP prediction for a given set of hyperparameters. this makes use of the well-known
-# formula for gaussian process predictions
-def predict(rng_key, X, Y, X_test, var, length, noise):
+# formula for Gaussian process predictions
+def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
     # compute kernels between train and test data, etc.
     k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
     k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
     k_XX = kernel(X, X, var, length, noise, include_noise=True)
-    K_xx_inv = jnp.linalg.inv(k_XX)
-    K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+
+    # since K_xx is symmetric positive-definite, we can use the more efficient and
+    # stable Cholesky decomposition instead of matrix inversion
+    if use_cholesky:
+        K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
+        K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
+        mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))
+    else:
+        K_xx_inv = jnp.linalg.inv(k_XX)
+        K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
+
     sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
         rng_key, X_test.shape[:1]
     )
-    mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
 
     # we return both the mean function and a sample from the posterior predictive for the
     # given set of hyperparameters
     return mean, mean + sigma_noise
@@ -148,7 +158,7 @@ def main(args):
     )
     means, predictions = vmap(
         lambda rng_key, var, length, noise: predict(
-            rng_key, X, Y, X_test, var, length, noise
+            rng_key, X, Y, X_test, var, length, noise, use_cholesky=args.use_cholesky
         )
     )(*vmap_args)
 
@@ -184,6 +194,7 @@ def main(args):
         type=str,
         choices=["median", "feasible", "value", "uniform", "sample"],
     )
+    parser.add_argument("--no-cholesky", dest="use_cholesky", action="store_false")
     args = parser.parse_args()
 
     numpyro.set_platform(args.device)
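For context: the predictive equations implemented above are mean = k_pX K_XX^{-1} Y and cov = k_pp - k_pX K_XX^{-1} k_pX^T, so both quantities only ever need K_XX^{-1} applied to a right-hand side. Because k_XX is symmetric positive-definite, cho_factor/cho_solve computes exactly that at roughly a third of the cost of forming the explicit inverse, and with better numerical behavior. Below is a minimal, self-contained sketch of the equivalence of the two routes (not part of this commit; the toy RBF kernel and the 1e-4 jitter are illustrative assumptions):

import jax
import jax.numpy as jnp

# toy 1-D inputs and an RBF kernel with diagonal jitter so k_XX is positive-definite
X = jnp.linspace(-1.0, 1.0, 50)[:, None]
k_XX = jnp.exp(-0.5 * (X - X.T) ** 2) + 1e-4 * jnp.eye(50)
y = jnp.sin(3.0 * X[:, 0])

# route 1: explicit inverse, as in the `use_cholesky=False` branch
alpha_inv = jnp.matmul(jnp.linalg.inv(k_XX), y)

# route 2: factor once, then solve against the factorization
cho = jax.scipy.linalg.cho_factor(k_XX)
alpha_cho = jax.scipy.linalg.cho_solve(cho, y)

print(jnp.max(jnp.abs(alpha_inv - alpha_cho)))  # small floating-point difference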
numpyro/distributions/distribution.py (8 changes: 6 additions & 2 deletions)
@@ -158,10 +158,10 @@ def gather_pytree_data_fields(cls):
         return all_pytree_data_fields
 
     @classmethod
-    def gather_pytree_aux_fields(cls):
+    def gather_pytree_aux_fields(cls) -> tuple:
         bases = inspect.getmro(cls)
 
-        all_pytree_aux_fields = ()
+        all_pytree_aux_fields = ("_validate_args",)
         for base in bases:
             if issubclass(base, Distribution):
                 all_pytree_aux_fields += base.__dict__.get("pytree_aux_fields", ())
@@ -203,11 +203,15 @@ def tree_unflatten(cls, aux_data, params):
         for k, v in pytree_aux_fields_dict.items():
             setattr(d, k, v)
 
+        # disable args validation during `tree_unflatten`, as it is called by jax
+        # with placeholder attributes that would make validation fail
+        d._validate_args = False
         Distribution.__init__(
             d,
             pytree_aux_fields_dict["_batch_shape"],
             pytree_aux_fields_dict["_event_shape"],
         )
+        d._validate_args = pytree_aux_fields_dict["_validate_args"]
         return d
 
     @staticmethod
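The idea behind these two hunks: `_validate_args` now travels as static aux data, so it survives a flatten/unflatten round trip, while validation itself is suspended inside `tree_unflatten` because jax may rebuild a pytree with placeholder leaves on which constraint checks would crash. A toy sketch of the same pattern (the `Scaled` class is a hypothetical stand-in, not numpyro's actual machinery):

import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Scaled:
    def __init__(self, scale, validate_args=False):
        if validate_args and not jnp.all(scale > 0):
            raise ValueError("scale must be positive")
        self.scale = scale
        self._validate_args = validate_args

    def tree_flatten(self):
        # array leaves vs. static aux data: the flag travels as aux data
        return (self.scale,), (self._validate_args,)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        # rebuild without validation (leaves may be placeholders), then restore the flag
        obj = cls(leaves[0], validate_args=False)
        obj._validate_args = aux[0]
        return obj

d = Scaled(jnp.ones(3), validate_args=True)
leaves, treedef = tree_util.tree_flatten(d)
# jax may call tree_unflatten with sentinel leaves; validation must not run here
rebuilt = tree_util.tree_unflatten(treedef, [object()])
assert rebuilt._validate_args  # the flag round-trips via aux data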
setup.py (1 change: 1 addition & 0 deletions)
@@ -86,6 +86,7 @@
"tpu": f"jax[tpu]{_jax_version_constraints}",
"cuda": f"jax[cuda]{_jax_version_constraints}",
},
python_requires=">=3.9",
long_description=long_description,
long_description_content_type="text/markdown",
keywords="probabilistic machine learning bayesian statistics",
test/test_distributions.py (15 changes: 15 additions & 0 deletions)
@@ -3092,6 +3092,21 @@ def sample(d: dist.Distribution):
     assert samples_batched_dist.shape == (1, *samples_dist.shape)
 
 
+def test_vmap_validate_args():
+    # Test for #1684: vmapping distributions should work when `validate_args=True`
+    v_dist = jax.vmap(
+        lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True),
+        in_axes=(0, 0),
+    )(jnp.zeros((2,)), jnp.zeros((2,)))
+
+    # non-regression test
+    v_dist = jax.vmap(
+        lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=False),
+        in_axes=(0, 0),
+    )(jnp.zeros((2,)), jnp.zeros((2,)))
+    assert not v_dist._validate_args
+
+
 def test_multinomial_abstract_total_count():
     probs = jnp.array([0.2, 0.5, 0.3])
     key = random.PRNGKey(0)
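At the user level, the fix means a distribution can be built under `jax.vmap` with validation enabled and then used normally, mirroring the new test. A hedged usage sketch (the shapes and values here are illustrative, not from the commit):

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# construct a batch of Normals under vmap with validation switched on
v_dist = jax.vmap(
    lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True)
)(jnp.zeros(2), jnp.ones(2))

print(v_dist.loc.shape, v_dist.scale.shape)  # (2,) (2,)
print(v_dist.log_prob(jnp.zeros(2)))  # per-element log-densities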
