Skip to content

Commit

Permalink
Remove default independent sampler jitter but ensure positive variance (
Browse files Browse the repository at this point in the history
#888)

* Cap reparam sampling jitter

* minimum not maximum

* Add a few simple tests

* Oops, misunderstood point. Try again

* Float, don't int

* Add tests

* Leave other samplers alone

* Really leave them alone

* Assert model variance is zero in test

* Adress review comments

* Remove outdated test

---------

Co-authored-by: Uri Granta <uri.granta@secondmind.ai>
  • Loading branch information
uri-granta and Uri Granta authored Jan 21, 2025
1 parent 9645d8c commit 2365c50
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 13 deletions.
23 changes: 16 additions & 7 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,6 @@ def test_reparametrization_sampler_reprs(
)


@pytest.mark.parametrize("qmc", [True, False])
def test_independent_reparametrization_sampler_sample_raises_for_negative_jitter(qmc: bool) -> None:
sampler = IndependentReparametrizationSampler(100, QuadraticMeanAndRBFKernel(), qmc=qmc)
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
sampler.sample(tf.constant([[0.0]]), jitter=-1e-6)


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("sample_size", [0, -2])
def test_independent_reparametrization_sampler_raises_for_invalid_sample_size(
Expand Down Expand Up @@ -285,6 +278,22 @@ def test_independent_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip
npt.assert_array_less(1e-9, tf.abs(samples2 - samples1))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
def test_independent_reparametrization_sampler_sample_ensures_positive_variance(
qmc: bool, dtype: tf.DType
) -> None:
model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype))
sampler = IndependentReparametrizationSampler(100, model, qmc=qmc)
x = tf.constant([[1.0]], dtype=dtype)
_, model_var = model.predict(x)
npt.assert_array_equal(model_var, tf.constant([[0]]))
variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter
assert variance > 0
variance = tf.math.reduce_variance(sampler.sample(x, jitter=-15)) # explicit negative jitter
assert variance > 0


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("sample_size", [0, -2])
def test_batch_reparametrization_sampler_raises_for_invalid_sample_size(
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LocalizedTag,
Ok,
Timer,
ensure_positive,
flatten_leading_dims,
get_value_for_tag,
jit,
Expand Down Expand Up @@ -222,3 +223,28 @@ def test_flatten_leading_dims_invalid_output_dims(output_dims: int) -> None:
x_old = tf.random.uniform([2, 3, 4, 5]) # [2, 3, 4, 5]
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
flatten_leading_dims(x_old, output_dims=output_dims)


@pytest.mark.parametrize(
"t, expected",
[
(
tf.constant(0, dtype=tf.float32),
tf.constant(1e-15, dtype=tf.float32),
),
(
tf.constant(0, dtype=tf.float64),
tf.constant(1e-30, dtype=tf.float64),
),
(
tf.constant([[-1.0, 0.0], [1e-35, 1.0]], dtype=tf.float32),
tf.constant([[1e-15, 1e-15], [1e-15, 1.0]], dtype=tf.float32),
),
(
tf.constant([[-1.0, 0.0], [1e-35, 1.0]], dtype=tf.float64),
tf.constant([[1e-30, 1e-30], [1e-30, 1.0]], dtype=tf.float64),
),
],
)
def test_ensure_positive(t: TensorType, expected: TensorType) -> None:
npt.assert_array_equal(ensure_positive(t), expected)
2 changes: 1 addition & 1 deletion tests/util/acquisition/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
"""
:param at: Batches of query points at which to sample the predictive distribution, with
shape `[..., B, D]`, for batches of size `B` of points of dimension `D`.
:param jitter: placeholder
:param jitter: unused
:return: The samples, of shape `[..., S, B, L]`, where `S` is the `sample_size`, `B` the
number of points per batch, and `L` the dimension of the model's predictive
distribution.
Expand Down
10 changes: 5 additions & 5 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ...space import EncoderFunction
from ...types import TensorType
from ...utils import DEFAULTS, flatten_leading_dims
from ...utils.misc import ensure_positive
from ..interfaces import (
ProbabilisticModel,
ReparametrizationSampler,
Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
"at: [N..., 1, D] # IndependentReparametrizationSampler only supports batch sizes of one",
"return: [N..., S, 1, L]",
)
def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorType:
def sample(self, at: TensorType, *, jitter: float = -1.0) -> TensorType:
"""
Return approximate samples from the `model` specified at :meth:`__init__`. Multiple calls to
:meth:`sample`, for any given :class:`IndependentReparametrizationSampler` and ``at``, will
Expand All @@ -124,16 +125,15 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
:param at: Where to sample the predictive distribution, with shape `[..., 1, D]`, for points
of dimension `D`.
:param jitter: The size of the jitter to use when stabilising the Cholesky decomposition of
the covariance matrix.
the covariance matrix. If a negative value is passed then no jitter is applied but
all values are capped to a hardcoded minimum.
:return: The samples, of shape `[..., S, 1, L]`, where `S` is the `sample_size` and `L` is
the number of latent model dimensions.
:raise ValueError (or InvalidArgumentError): If ``at`` has an invalid shape or ``jitter``
is negative.
"""
tf.debugging.assert_greater_equal(jitter, 0.0)

mean, var = self._model.predict(at[..., None, :, :]) # [..., 1, 1, L], [..., 1, 1, L]
var = var + jitter
var = ensure_positive(var) if jitter < 0 else (var + jitter)

def sample_eps() -> tf.Tensor:
self._initialized.assign(True)
Expand Down
6 changes: 6 additions & 0 deletions trieste/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,9 @@ def _flatten_module( # type: ignore[no-untyped-def]
for subvalue in subvalues:
# Predicate is already tested for these values.
yield subvalue


def ensure_positive(x: TensorType) -> TensorType:
"""Ensure that all the elements in `x` are strictly positive (using a dtype-dependent
capping threshold)."""
return tf.math.maximum(x, 1e-15 if x.dtype == tf.float32 else 1e-30)

0 comments on commit 2365c50

Please sign in to comment.