tests: add analytic kernel herding tests. #794

Merged (3 commits) on Oct 8, 2024. Changes shown are from 1 commit.
88 changes: 85 additions & 3 deletions documentation/source/examples/analytical_kernel_herding.rst
@@ -17,9 +17,7 @@ and choose a ``length_scale`` of :math:`\frac{1}{\sqrt{2}}` to simplify computat
with the ``SquaredExponentialKernel``, in particular it becomes:

.. math::
-    k(x, y) = e^{-||x - y||^2}
-
-in this example.
+    k(x, y) = e^{-||x - y||^2}.

Kernel herding should do as follows:
- Compute the Gramian row mean, that is for each data-point :math:`x` and all other
@@ -142,3 +140,87 @@ This example would be run in coreax using:
print(coreset.unweighted_indices) # The coreset_indices
print(coreset.coreset.data) # The data-points in the coreset
print(solver_state.gramian_row_mean) # The stored gramian_row_mean

Coreax also supports weighted data. If we have the same data as described above, but
weights of:

.. math::
w = \begin{pmatrix}
0.8 \\
0.1 \\
0.1
\end{pmatrix}

we would expect a different resulting coreset. The computation of the Gramian
row mean, :math:`\mathbb{E}[k(x, x')]`, becomes:

.. math::
\mathbb{E}[k(x, x')] = \begin{pmatrix}
0.8 \cdot k([0.3, 0.25]', [0.3, 0.25]') + 0.1 \cdot k([0.3, 0.25]', [0.4, 0.2]') + 0.1 \cdot k([0.3, 0.25]', [0.5, 0.125]') \\
0.8 \cdot k([0.4, 0.2]', [0.3, 0.25]') + 0.1 \cdot k([0.4, 0.2]', [0.4, 0.2]') + 0.1 \cdot k([0.4, 0.2]', [0.5, 0.125]') \\
0.8 \cdot k([0.5, 0.125]', [0.3, 0.25]') + 0.1 \cdot k([0.5, 0.125]', [0.4, 0.2]') + 0.1 \cdot k([0.5, 0.125]', [0.5, 0.125]')
\end{pmatrix}

resulting in:

.. math::
\mathbb{E}[k(x, x')] = \begin{pmatrix}
0.9933471580051769 \\
0.988511884095646 \\
0.9551646673468503
\end{pmatrix}

The largest value in this array is 0.9933471580051769, so we expect the first coreset
point to be [0.3, 0.25], that is, the data-point at index 0 in the dataset. At this
point we have ``coreset_indices`` as [0, ?].
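
The Gramian row mean above can be checked numerically. The following is a minimal
sketch using only ``jax.numpy``; the array names are illustrative, not part of the
coreax API:

.. code-block::

    import jax.numpy as jnp

    x = jnp.array([[0.3, 0.25], [0.4, 0.2], [0.5, 0.125]])
    w = jnp.array([0.8, 0.1, 0.1])

    # Pairwise squared distances, then k(x, y) = exp(-||x - y||^2)
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    gram = jnp.exp(-sq_dists)

    # Weighted row mean: for each x, sum over x' of w_{x'} * k(x, x')
    print(gram @ w)  # approximately [0.99334716, 0.98851188, 0.95516467]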

We then compute the penalty update term
:math:`\frac{1}{T+1}\sum_{t=1}^T k(x, x_t)` with :math:`T = 1` and get:

.. math::
\frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \begin{pmatrix}
0.5 \\
0.4937889002469407 \\
0.4729468897789434
\end{pmatrix}
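
To see where these values come from, consider the second entry:

.. math::
    \frac{1}{2} k([0.4, 0.2]', [0.3, 0.25]') = \frac{1}{2} e^{-\left((0.4 - 0.3)^2 + (0.2 - 0.25)^2\right)} = \frac{1}{2} e^{-0.0125} \approx 0.4937889

Note that the penalty term itself is unweighted; the weights enter only through
the Gramian row mean.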

Finally, we select the next coreset point to maximise:

.. math::
\mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \begin{pmatrix}
0.4933471580051769 \\
0.49472298384870533 \\
0.48221777756790696
\end{pmatrix}

The largest value in this array is at index 1, so the second coreset point is
[0.4, 0.2] and our final ``coreset_indices`` should be [0, 1]. In coreax, this
example would be run as:

.. code-block::

from coreax import Data, SquaredExponentialKernel, KernelHerding
import equinox as eqx
import jax.numpy as jnp

# Define the data
coreset_size = 2
length_scale = 1.0 / jnp.sqrt(2)
x = jnp.array([
[0.3, 0.25],
[0.4, 0.2],
[0.5, 0.125],
])
weights = jnp.array([0.8, 0.1, 0.1])

# Define a kernel
kernel = SquaredExponentialKernel(length_scale=length_scale)

# Generate the coreset, using equinox to JIT compile the code and speed up
# generation for larger datasets
data = Data(x, weights=weights)
solver = KernelHerding(coreset_size=coreset_size, kernel=kernel, unique=True)
coreset, solver_state = eqx.filter_jit(solver.reduce)(data)

# Inspect results
print(coreset.unweighted_indices) # The coreset_indices
print(coreset.coreset.data) # The data-points in the coreset
print(solver_state.gramian_row_mean) # The stored gramian_row_mean
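
For comparison, the two selection steps can also be reproduced outside coreax.
The following is a from-scratch sketch of the herding iterations using only
``jax.numpy``; it mirrors the arithmetic in the walkthrough above rather than
coreax's internal implementation (in particular, it omits the ``unique``
handling):

.. code-block::

    import jax.numpy as jnp

    x = jnp.array([[0.3, 0.25], [0.4, 0.2], [0.5, 0.125]])
    w = jnp.array([0.8, 0.1, 0.1])

    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    gram = jnp.exp(-sq_dists)
    gramian_row_mean = gram @ w

    coreset_indices = []
    for _ in range(2):
        if coreset_indices:
            # Penalty: average kernel similarity to the points chosen so far
            chosen = jnp.array(coreset_indices)
            penalty = gram[:, chosen].sum(axis=1) / (len(coreset_indices) + 1)
        else:
            penalty = 0.0
        coreset_indices.append(int(jnp.argmax(gramian_row_mean - penalty)))

    print(coreset_indices)  # expect [0, 1]
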
167 changes: 161 additions & 6 deletions tests/unit/test_solvers.py
@@ -542,9 +542,7 @@ def test_kernel_herding_analytic_unique(self) -> None:
computations with the ``SquaredExponentialKernel``, in particular it becomes:

.. math::
-    k(x, y) = e^{-||x - y||^2}
-
-in this example.
+    k(x, y) = e^{-||x - y||^2}.

Kernel herding should do as follows:
- Compute the Gramian row mean, that is for each data-point :math:`x` and
@@ -698,9 +696,7 @@ def test_kernel_herding_analytic_not_unique(self) -> None:
computations with the ``SquaredExponentialKernel``, in particular it becomes:

.. math::
-    k(x, y) = e^{-||x - y||^2}
-
-in this example.
+    k(x, y) = e^{-||x - y||^2}.

Kernel herding should do as follows:
- Compute the Gramian row mean, that is for each data-point :math:`x` and
@@ -833,6 +829,165 @@ def test_kernel_herding_analytic_not_unique(self) -> None:
solver_state.gramian_row_mean, expected_gramian_row_mean
)

def test_kernel_herding_analytic_unique_weighted_data(self) -> None:
# pylint: disable=line-too-long
r"""
Test kernel herding on a weighted analytical example, with a unique coreset.

In this example, we have data of:

.. math::
x = \begin{pmatrix}
0.3 & 0.25 \\
0.4 & 0.2 \\
0.5 & 0.125
\end{pmatrix}

with weights:

.. math::
w = \begin{pmatrix}
0.8 \\
0.1 \\
0.1
\end{pmatrix}

and choose a ``length_scale`` of :math:`\frac{1}{\sqrt{2}}` to simplify
computations with the ``SquaredExponentialKernel``, in particular it becomes:

.. math::
k(x, y) = e^{-||x - y||^2}.

Kernel herding should do as follows:
- Compute the Gramian row mean, that is for each data-point :math:`x` and
all other data-points :math:`x'`, :math:`\sum_{x'} w_{x'} \cdot k(x, x')`
where we sum over all :math:`N` data-points.
- Select the first coreset point :math:`x_{1}` as the data-point where the
Gramian row mean is highest.
- Compute all future coreset points as
:math:`x_{T+1} = \arg\max_{x} \left( \mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t) \right)`
where we currently have :math:`T` points in the coreset.

We ask for a coreset of size 2 in this example. With an empty coreset, we first
compute :math:`\mathbb{E}[k(x, x')]` as:

.. math::
\mathbb{E}[k(x, x')] = \begin{pmatrix}
0.8 \cdot k([0.3, 0.25]', [0.3, 0.25]') + 0.1 \cdot k([0.3, 0.25]', [0.4, 0.2]') + 0.1 \cdot k([0.3, 0.25]', [0.5, 0.125]') \\
0.8 \cdot k([0.4, 0.2]', [0.3, 0.25]') + 0.1 \cdot k([0.4, 0.2]', [0.4, 0.2]') + 0.1 \cdot k([0.4, 0.2]', [0.5, 0.125]') \\
0.8 \cdot k([0.5, 0.125]', [0.3, 0.25]') + 0.1 \cdot k([0.5, 0.125]', [0.4, 0.2]') + 0.1 \cdot k([0.5, 0.125]', [0.5, 0.125]')
\end{pmatrix}

resulting in:

.. math::
\mathbb{E}[k(x, x')] = \begin{pmatrix}
0.9933471580051769 \\
0.988511884095646 \\
0.9551646673468503
\end{pmatrix}

The largest value in this array is 0.9933471580051769, so we expect the first
coreset point to be [0.3, 0.25], that is, the data-point at index 0 in the
dataset. At this point we have ``coreset_indices`` as [0, ?].

We then compute the penalty update term
:math:`\frac{1}{T+1}\sum_{t=1}^T k(x, x_t)` with :math:`T = 1`:

.. math::
\frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \frac{1}{2} \cdot \begin{pmatrix}
k([0.3, 0.25]', [0.3, 0.25]') \\
k([0.4, 0.2]', [0.3, 0.25]') \\
k([0.5, 0.125]', [0.3, 0.25]')
\end{pmatrix}

which evaluates to:

.. math::
\frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \begin{pmatrix}
0.5 \\
0.4937889002469407 \\
0.4729468897789434
\end{pmatrix}

We now select the data-point that maximises
:math:`\mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t)`,
which evaluates to:

.. math::
\mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \begin{pmatrix}
0.9933471580051769 - 0.5 \\
0.988511884095646 - 0.4937889002469407 \\
0.9551646673468503 - 0.4729468897789434
\end{pmatrix}

giving a final result of:

.. math::
\mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t) = \begin{pmatrix}
0.4933471580051769 \\
0.49472298384870533 \\
0.48221777756790696
\end{pmatrix}

The largest value in this array is at index 1, which means we choose
the point [0.4, 0.2] for the coreset. This means our final ``coreset_indices``
should be [0, 1].

Finally, the solver state tracks variables we need not compute repeatedly. In
the case of kernel herding, we don't need to recompute
:math:`\mathbb{E}[k(x, x')]` at every single step - so the solver state from the
coreset reduce method should be set to:

.. math::
\mathbb{E}[k(x, x')] = \begin{pmatrix}
0.9933471580051769 \\
0.988511884095646 \\
0.9551646673468503
\end{pmatrix}
""" # noqa: E501
# pylint: enable=line-too-long
# Setup example data - note we have deliberately chosen points that are close
# together, so the penalty applied to nearby points matters and the same data
# can also be used to check the handling of unique points.
coreset_size = 2
length_scale = 1.0 / jnp.sqrt(2)
x = jnp.array(
[
[0.3, 0.25],
[0.4, 0.2],
[0.5, 0.125],
]
)
weights = jnp.array([0.8, 0.1, 0.1])

# Define a kernel
kernel = SquaredExponentialKernel(length_scale=length_scale)

# Generate the coreset
data = Data(data=x, weights=weights)
solver = KernelHerding(coreset_size=coreset_size, kernel=kernel, unique=True)
coreset, solver_state = solver.reduce(data)

# Define the expected outputs, following the arguments in the docstring
expected_coreset_indices = jnp.array([0, 1])
expected_gramian_row_mean = jnp.array(
[0.9933471580051769, 0.988511884095646, 0.9551646673468503]
)

# Check output matches expected
np.testing.assert_array_equal(
coreset.unweighted_indices, expected_coreset_indices
)
np.testing.assert_array_equal(
coreset.coreset.data, data.data[expected_coreset_indices]
)
np.testing.assert_array_almost_equal(
solver_state.gramian_row_mean, expected_gramian_row_mean
)


class TestRandomSample(ExplicitSizeSolverTest):
"""Test cases for :class:`coreax.solvers.coresubset.RandomSample`."""