Merge pull request #354 from gchq/fix/batch-size
Fix batch size
db091756 authored Aug 19, 2024
2 parents 4645fa0 + 6230628 commit 6e85e7e
Showing 4 changed files with 48 additions and 40 deletions.
65 changes: 31 additions & 34 deletions tests/integration/test_base_usage.py
@@ -17,55 +17,53 @@
"""

import unittest
from typing import Optional

import numpy as np
import pytest

from tests.cases import get_default_rng
from vanguard.kernels import ScaledRBFKernel
from vanguard.vanilla import GaussianGPController


class VanguardTestCase(unittest.TestCase):
"""
A subclass of TestCase designed to check end-to-end usage of base code.
"""

def setUp(self) -> None:
"""
Define data shared across tests.
"""
self.rng = get_default_rng()
self.num_train_points = 500
self.num_test_points = 500
self.n_sgd_iters = 100
self.small_noise = 0.1
self.confidence_interval_alpha = 0.9
# When generating confidence intervals, how far from the expected number of
# points outside the interval are we willing to tolerate before we consider a
# test a failure? As an example, with a 90% confidence interval we might expect
# 10% of points to lie outside of it, 5% above and 5% below if everything is
# symmetric. However, we expect some noise due to errors and finite datasets, so
# we only consider the test a failure if more than
# 5% + accepted_confidence_interval_error lie above the upper confidence
# interval.
self.accepted_confidence_interval_error = 3
self.expected_percent_outside_one_sided = (100.0 * (1 - self.confidence_interval_alpha)) / 2

def test_basic_gp(self) -> None:
class TestBaseUsage:
num_train_points = 500
num_test_points = 500
n_sgd_iters = 100
small_noise = 0.1
confidence_interval_alpha = 0.9
# When generating confidence intervals, how far from the expected number of
# points outside the interval are we willing to tolerate before we consider a
# test a failure? As an example, with a 90% confidence interval we might expect
# 10% of points to lie outside of it, 5% above and 5% below if everything is
# symmetric. However, we expect some noise due to errors and finite datasets, so
# we only consider the test a failure if more than
# 5% + accepted_confidence_interval_error lie above the upper confidence
# interval.
accepted_confidence_interval_error = 3
expected_percent_outside_one_sided = (100.0 * (1 - confidence_interval_alpha)) / 2

@pytest.mark.parametrize("batch_size", [None, 100])
def test_basic_gp(self, batch_size: Optional[int]) -> None:
"""
Verify Vanguard usage on a simple, single variable regression problem.
We generate a single feature `x` and a continuous target `y`, and verify that a
GP can be fit to this data. We check that the confidence intervals are ordered
correctly, and that they contain the expected number of points in both the training
and testing data.
We test this both in and out of batch mode.
"""
# Define some data for the test
rng = get_default_rng()

x = np.linspace(start=0, stop=10, num=self.num_train_points + self.num_test_points).reshape(-1, 1)
y = np.squeeze(x * np.sin(x))

# Split data into training and testing
train_indices = self.rng.choice(np.arange(y.shape[0]), size=self.num_train_points, replace=False)
train_indices = rng.choice(np.arange(y.shape[0]), size=self.num_train_points, replace=False)
test_indices = np.setdiff1d(np.arange(y.shape[0]), train_indices)

# Define the controller object, with an assumed small amount of noise
@@ -74,7 +72,8 @@ def test_basic_gp(self) -> None:
train_y=y[train_indices],
kernel_class=ScaledRBFKernel,
y_std=self.small_noise * np.ones_like(y[train_indices]),
rng=self.rng,
rng=rng,
batch_size=batch_size,
)

# Fit the GP
@@ -88,8 +87,8 @@
)

# Sense check the outputs
self.assertTrue(np.all(prediction_means <= prediction_ci_upper))
self.assertTrue(np.all(prediction_means >= prediction_ci_lower))
assert np.all(prediction_means <= prediction_ci_upper)
assert np.all(prediction_means >= prediction_ci_lower)

# Are the prediction intervals reasonable?
pct_above_ci_upper_train = (
@@ -110,9 +109,7 @@ def test_basic_gp(self) -> None:
pct_below_ci_lower_train,
pct_below_ci_lower_test,
]:
self.assertLessEqual(
pct_check, self.expected_percent_outside_one_sided + self.accepted_confidence_interval_error
)
assert pct_check <= self.expected_percent_outside_one_sided + self.accepted_confidence_interval_error


if __name__ == "__main__":
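
The integration test above is converted from a unittest.TestCase into a plain pytest class so that @pytest.mark.parametrize("batch_size", [None, 100]) runs it twice: once out of batch mode (batch_size=None) and once with mini-batches of 100 points. The failure threshold follows from the class attributes: with confidence_interval_alpha = 0.9, expected_percent_outside_one_sided = (100.0 * (1 - 0.9)) / 2 = 5, so each one-sided check fails only if more than 5% + 3% = 8% of points fall beyond the corresponding bound. A minimal, self-contained sketch of that coverage check (the helper name and standalone structure are illustrative, not Vanguard's API):

import numpy as np

# Thresholds matching the test's class attributes
confidence_interval_alpha = 0.9
accepted_confidence_interval_error = 3
# (100 * (1 - 0.9)) / 2 = 5: a symmetric 90% interval leaves ~5% of points
# above the upper bound and ~5% below the lower bound
expected_percent_outside_one_sided = (100.0 * (1 - confidence_interval_alpha)) / 2

def one_sided_coverage_ok(y, bound, above=True):
    """Return True if the empirical exceedance rate is within the accepted tolerance."""
    pct_outside = 100.0 * np.mean(y > bound if above else y < bound)
    return pct_outside <= expected_percent_outside_one_sided + accepted_confidence_interval_error
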
3 changes: 1 addition & 2 deletions tests/units/base/test_base.py
@@ -173,8 +173,6 @@ def test_error_handling_of_higher_rank_features(self) -> None:
with self.assertRaisesRegex(ValueError, expected_regex):
gp.fit()

@unittest.skip # TODO: fix test; underlying issues in batch mode
# https://github.com/gchq/Vanguard/issues/265
def test_error_handling_of_batch_size(self) -> None:
"""Test that a UserWarning is raised when both batch_size and gradient_every are not None."""
gp = GaussianGPController(
@@ -183,6 +181,7 @@ def test_error_handling_of_batch_size(self) -> None:
kernel_class=PeriodicRBFKernel,
y_std=self.DATASET.train_y_std,
batch_size=20,
rng=get_default_rng(),
)
gradient_every = 2
gp.fit()
18 changes: 16 additions & 2 deletions vanguard/base/basecontroller.py
@@ -201,9 +201,17 @@ class SafeGPModelClass(self.gp_model_class):
class SafeMarginalLogLikelihoodClass(marginal_log_likelihood_class):
pass

if self.batch_size is not None:
# then the training data will be updated at each training iteration
gp_train_x = None
gp_train_y = None
else:
gp_train_x = self.train_x
gp_train_y = self.train_y.squeeze(dim=-1)

self._gp = SafeGPModelClass(
self.train_x,
self.train_y.squeeze(dim=-1),
gp_train_x,
gp_train_y,
covar_module=self.kernel,
likelihood=self.likelihood,
mean_module=self.mean,
@@ -375,6 +383,12 @@ def _sgd_round(

for iter_num, (train_x, train_y, train_y_noise) in enumerate(islice(self.train_data_generator, n_iters)):
self.likelihood_noise = train_y_noise
if self.batch_size is not None:
# update the training data to the current train_x and train_y, to avoid "You must train on the
# training data!"
self._gp.set_train_data(train_x, train_y.squeeze(dim=-1), strict=False)
# TODO: consider using get_fantasy_model() instead if possible, when using ExactGP?
# https://github.com/gchq/Vanguard/issues/352
try:
loss = self._single_optimisation_step(train_x, train_y, retain_graph=iter_num < n_iters - 1)

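
The basecontroller.py change above is the core of the fix: when batch_size is set, the exact GP is constructed with no stored training data (gp_train_x and gp_train_y are None), and each SGD iteration calls set_train_data(..., strict=False) with the current mini-batch, so GPyTorch no longer raises the "You must train on the training data!" error referenced in the comment when a batch differs from the stored training set. A minimal, self-contained GPyTorch sketch of the same pattern, assuming an RBF kernel and an Adam optimiser (illustrative only, not Vanguard's actual classes):

import torch
import gpytorch

class MiniBatchExactGP(gpytorch.models.ExactGP):
    """Minimal exact GP whose training data is swapped in per mini-batch."""

    def __init__(self, likelihood):
        # No training data at construction; it is supplied batch by batch below.
        super().__init__(None, None, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

# Toy data split into mini-batches of 100 points
x = torch.linspace(0, 10, 500).unsqueeze(-1)
y = (x * torch.sin(x)).squeeze(-1)
batches = list(zip(x.split(100), y.split(100)))

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MiniBatchExactGP(likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
likelihood.train()
for batch_x, batch_y in batches:
    # Point the model at the current mini-batch so the exact GP's check that it
    # is being evaluated on its own training inputs passes for every batch.
    model.set_train_data(inputs=batch_x, targets=batch_y, strict=False)
    optimizer.zero_grad()
    loss = -mll(model(batch_x), batch_y)
    loss.backward()
    optimizer.step()
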
2 changes: 0 additions & 2 deletions vanguard/utils.py
@@ -139,8 +139,6 @@ def shuffle(array: numpy.typing.NDArray) -> None:  # pylint: disable=unused-argu

def shuffle(array: numpy.typing.NDArray) -> None:
"""Random shuffle function."""
# TODO: Shuffling when batch_size is not None raises RuntimeError("You must train on the training inputs!")
# https://github.com/gchq/Vanguard/issues/265
rng.shuffle(array)

index = 0
