diff --git a/vizier/_src/jax/models/gaussian_process_model.py b/vizier/_src/jax/models/gaussian_process_model.py index bce9a6afc..6bdbd47c6 100644 --- a/vizier/_src/jax/models/gaussian_process_model.py +++ b/vizier/_src/jax/models/gaussian_process_model.py @@ -124,7 +124,7 @@ def __call__( index_points=inputs, observation_noise_variance=observation_noise_variance, validate_args=self._use_tfp_runtime_validation, - always_yield_multivariate_normal=True) + ) class GaussianProcessARDWithCategorical(sp_model.ModelCoroutine): @@ -257,5 +257,4 @@ def __call__( index_points=inputs, observation_noise_variance=observation_noise_variance, validate_args=self._use_tfp_runtime_validation, - always_yield_multivariate_normal=True, ) diff --git a/vizier/_src/jax/models/hebo_gp_model.py b/vizier/_src/jax/models/hebo_gp_model.py index cd38a1984..c8a28604c 100644 --- a/vizier/_src/jax/models/hebo_gp_model.py +++ b/vizier/_src/jax/models/hebo_gp_model.py @@ -143,4 +143,4 @@ def _inverse_constraint_fn(f): index_points=inputs, observation_noise_variance=observation_noise_variance, cholesky_fn=None, - always_yield_multivariate_normal=True) + ) diff --git a/vizier/_src/jax/models/tuned_gp_models.py b/vizier/_src/jax/models/tuned_gp_models.py index 674d05c36..f1a45f0fa 100644 --- a/vizier/_src/jax/models/tuned_gp_models.py +++ b/vizier/_src/jax/models/tuned_gp_models.py @@ -174,5 +174,4 @@ def __call__( index_points=inputs, observation_noise_variance=observation_noise_variance, cholesky_fn=cholesky_fn, - always_yield_multivariate_normal=True, )