
Map reduce index bug #790

Merged · 18 commits · Oct 18, 2024

Commits
272da26
#779 At coreax.solvers.composite, changed the return statement of the…
qh681248 Sep 24, 2024
d9045a4
#779 At coreax.solvers.composite, changed the return statement of the…
qh681248 Sep 24, 2024
d87afde
Added a test in unit/test_solvers.py that checks if MapReduce's reduc…
qh681248 Sep 25, 2024
08f666a
Added an if statement on MapReduce.reduce method, it now only assigns…
qh681248 Sep 25, 2024
7d41a4b
replaced mapreduce by map_reduce in test_map_reduce_diverse_selection…
qh681248 Sep 26, 2024
48a1037
In coreax/solvers/composite.py, the reduce method updates indices onl…
qh681248 Sep 26, 2024
1bd8501
Removed the line plt.show() in examples/pounce.py
qh681248 Sep 27, 2024
082ada7
Made requested changes in the PR #790 (fixed type hints, comments etc.)
qh681248 Oct 3, 2024
e5f5122
Removed changed NoneType to None (NoneType is not compatible with pyt…
qh681248 Oct 3, 2024
da827a7
removed a folder that wasn't supposed to be added
qh681248 Oct 3, 2024
c187818
On composite.py changed return statement of _jit_tree from dataset[in…
qh681248 Oct 4, 2024
a4f762c
Added MapReduce bugfix to `CHANGELOG.md`
qh681248 Oct 8, 2024
bbae67d
Added double backticks on comments in `composite.py` when referring t…
qh681248 Oct 8, 2024
cf537ed
Removed a redundant comment on a test on `TestMapReduce` class
qh681248 Oct 8, 2024
c894e07
Added analytic test for `MapReduce`
qh681248 Oct 9, 2024
a5fb1be
docs: make suggested changes in the docstring
qh681248 Oct 14, 2024
b48bda4
docs: make suggested changes in the docstring
qh681248 Oct 17, 2024
2a9a1c4
Merge remote-tracking branch 'origin/main' into bugfix/MapReduce-inde…
qh681248 Oct 17, 2024
CHANGELOG.md (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@ disable tqdm progress bar terminal output. Defaults to disabled (`False`).


 ### Fixed
-
+- `MapReduce` in `coreax.solvers.composite.py` now keeps track of the indices.
 
 
 ### Changed
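
To illustrate what the fixed behaviour means in practice, here is a minimal sketch (not part of the PR): with a subset-based solver under `MapReduce`, the returned indices should now point into the original dataset. The solver choice, the sizes, and the `coreax.kernel` import path are assumptions for illustration and may differ between coreax versions.

import jax.numpy as jnp
import jax.random as jr

from coreax.data import Data
from coreax.kernel import SquaredExponentialKernel  # import path assumed
from coreax.solvers import KernelHerding, MapReduce

dataset = jr.normal(jr.PRNGKey(0), shape=(40, 2))
solver = MapReduce(
    base_solver=KernelHerding(coreset_size=4, kernel=SquaredExponentialKernel()),
    leaf_size=12,
)
coreset, _ = solver.reduce(Data(dataset))
# After the fix, `coreset.nodes.data` indexes rows of the original dataset,
# so gathering those rows reproduces the coreset points.
assert jnp.array_equal(dataset[coreset.nodes.data], coreset.coreset.data)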
coreax/solvers/composite.py (60 changes: 48 additions & 12 deletions)

@@ -23,18 +23,20 @@
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
+from jax import Array
 from sklearn.neighbors import BallTree, KDTree
 from typing_extensions import TypeAlias, override
 
 from coreax.coreset import Coreset, Coresubset
 from coreax.data import Data
 from coreax.solvers.base import ExplicitSizeSolver, PaddingInvariantSolver, Solver
-from coreax.util import tree_zero_pad_leading_axis
+from coreax.util import ArrayLike, tree_zero_pad_leading_axis
 
 BinaryTree: TypeAlias = Union[KDTree, BallTree]
 _Data = TypeVar("_Data", bound=Data)
 _Coreset = TypeVar("_Coreset", Coreset, Coresubset)
 _State = TypeVar("_State")
+_Indices = TypeVar("_Indices", ArrayLike, None)
 
 
 class CompositeSolver(
@@ -125,22 +127,56 @@ def reduce(
         # There is no obvious way to use state information here.
         del solver_state
 
-        def _reduce_coreset(data: _Data) -> tuple[_Coreset, _State]:
+        def _reduce_coreset(
+            data: _Data, _indices: Optional[_Indices] = None
+        ) -> tuple[_Coreset, _State, _Indices]:
             if len(data) <= self.leaf_size:
-                return self.base_solver.reduce(data)
-            partitioned_dataset = _jit_tree(data, self.leaf_size, self.tree_type)
-            coreset_ensemble, _ = jax.vmap(self.base_solver.reduce)(partitioned_dataset)
+                coreset, state = self.base_solver.reduce(data)
+                if _indices is not None:
+                    _indices = _indices[coreset.nodes.data]
+                return coreset, state, _indices
+
+            def wrapper(partition: _Data) -> tuple[_Data, Array]:
+                """
+                Apply the `reduce` method of the base solver on a partition.
+
+                This is a wrapper for `reduce()` for processing a single partition.
+                The data is partitioned with `_jit_tree()`.
+                The reduction is performed on each partition via ``vmap()``.
+                """
+                x, _ = self.base_solver.reduce(partition)
+                return x.coreset, x.nodes.data
+
+            partitioned_dataset, partitioned_indices = _jit_tree(
+                data, self.leaf_size, self.tree_type
+            )
+            # Reduce each partition and get indices from each
+            coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset)
+            # Calculate the indices with respect to the original data
+            concatenated_indices = jax.vmap(lambda x, index: x[index])(
+                partitioned_indices, ensemble_indices
+            )
+            concatenated_indices = jnp.ravel(concatenated_indices)
             _coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble)
-            return _reduce_coreset(_coreset.coreset)
-
-        coreset_wrong_pre_coreset_data, output_solver_state = _reduce_coreset(dataset)
-        coreset = eqx.tree_at(
-            lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset
-        )
+            if _indices is not None:
+                final_indices = _indices[concatenated_indices]
+            else:
+                final_indices = concatenated_indices
+            return _reduce_coreset(_coreset, final_indices)
+
+        (pre_coreset, output_solver_state, _indices) = _reduce_coreset(dataset)
+        # Correct the pre-coreset data and the indices
+        coreset = eqx.tree_at(lambda x: x.pre_coreset_data, pre_coreset, dataset)
+        if _indices is not None:
+            if isinstance(coreset, Coresubset):
+                coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices)
         return coreset, output_solver_state
 
 
-def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data:
+def _jit_tree(
+    dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]
+) -> tuple[_Data, _Indices]:
     """
     Return JIT compatible BinaryTree partitioning of 'dataset'.
 
@@ -183,4 +219,4 @@ def _binary_tree(_input_data: Data) -> np.ndarray:
         return node_indices.reshape(n_leaves, -1).astype(np.int32)
 
     indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset)
-    return dataset[indices]
+    return padded_dataset[indices], jnp.arange(len(dataset))[indices]
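
The heart of the fix is the vmapped gather that composes partition-local coreset indices with the partition indices returned by `_jit_tree`. A standalone sketch of that composition pattern, with a hypothetical `select_k` standing in for the base solver:

import jax
import jax.numpy as jnp

def select_k(partition: jnp.ndarray, k: int = 2) -> jnp.ndarray:
    # Toy stand-in for the base solver: return the *local* indices
    # of the k largest points within a partition.
    return jnp.argsort(partition)[-k:]

data = jnp.arange(100.0, 180.0, 10.0)                    # 8 points
partition_indices = jnp.arange(len(data)).reshape(2, 4)  # 2 partitions of 4
partitions = data[partition_indices]

local_indices = jax.vmap(select_k)(partitions)
# Compose local indices with partition indices, exactly as `_reduce_coreset`
# composes `ensemble_indices` with `partitioned_indices`.
global_indices = jnp.ravel(
    jax.vmap(lambda part_idx, local: part_idx[local])(partition_indices, local_indices)
)
selected = jnp.ravel(
    jax.vmap(lambda part, local: part[local])(partitions, local_indices)
)
# The composed indices recover the selected points from the original data.
assert jnp.array_equal(data[global_indices], selected)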
tests/unit/test_solvers.py (158 changes: 158 additions & 0 deletions)

@@ -1326,6 +1326,164 @@ def test_base_solver(
        solver_factory.keywords["base_solver"] = base_solver
        solver_factory()

    def test_map_reduce_diverse_selection(self):
        """Check if MapReduce returns indices from multiple partitions."""
        dataset_size = 40
        data_dim = 5
        coreset_size = 6
        leaf_size = 12

        key = jr.PRNGKey(0)
        dataset = jr.normal(key, shape=(dataset_size, data_dim))

        kernel = SquaredExponentialKernel()
        base_solver = KernelHerding(coreset_size=coreset_size, kernel=kernel)

        solver = MapReduce(base_solver=base_solver, leaf_size=leaf_size)
        coreset, _ = solver.reduce(Data(dataset))
        selected_indices = coreset.nodes.data

        assert jnp.any(
            selected_indices >= coreset_size
        ), "MapReduce should select points beyond the first few"

        # Check if there are indices from different partitions
        partitions_represented = jnp.unique(selected_indices // leaf_size)
        assert (
            len(partitions_represented) > 1
        ), "MapReduce should select points from multiple partitions"

    def test_map_reduce_analytic(self):
        r"""
        Test ``MapReduce`` on an analytical example, enforcing a unique coreset.

        In this example, we start with the original dataset
        :math:`[10, 20, 30, 210, 40, 60, 180, 90, 150, 70, 120,
        200, 50, 140, 80, 170, 100, 190, 110, 160, 130]`.

        Suppose we want a subset of size 3 and a maximum leaf size of 6.

        We have a dataset of size 21. The partitioning scheme only
        allows for :math:`n` partitions, where :math:`n` is a power of 2.
        Therefore, we can partition into:

        1. 1 partition of size 21
        2. 2 partitions of size :math:`\lceil 10.5 \rceil = 11` each (with one padded 0)
        3. 4 partitions of size :math:`\lceil 5.25 \rceil = 6` each (with 3 padded 0's)
        4. 8 partitions of size :math:`\lceil 2.625 \rceil = 3` each (with 3 padded 0's)

        Since we set the maximum leaf size :math:`m = 6`, we choose the largest
        partition size that is less than or equal to 6. Thus, we have 4 partitions,
        each of size 6.

        This results in the following 4 partitions (note that the
        data is in ascending order):

        1. :math:`[0, 0, 0, 10, 20, 30]`
        2. :math:`[40, 50, 60, 70, 80, 90]`
        3. :math:`[100, 110, 120, 130, 140, 150]`
        4. :math:`[160, 170, 180, 190, 200, 210]`

        Now we reduce each partition with our ``interleaved_base_solver``,
        which is designed to choose the first, last, second, second-last, third,
        third-last elements, and so on, until a coreset of the correct size is formed.
        Hence, we obtain:

        1. :math:`[0, 30, 0]`
        2. :math:`[40, 90, 50]`
        3. :math:`[100, 150, 110]`
        4. :math:`[160, 210, 170]`

        Concatenating, we obtain
        :math:`[0, 30, 0, 40, 90, 50, 100, 150, 110, 160, 210, 170]`.
        We repeat the process, checking how many partitions to divide this
        intermediate dataset (of size 12) into. Recall that the number of
        partitions must be a power of 2. Our options are:

        1. 1 partition of size 12
        2. 2 partitions of size 6
        3. 4 partitions of size 3
        4. 8 partitions of size 1.5 (rounded up to 2)

        Given our maximum leaf size :math:`m = 6`, we choose the largest partition size
        that is less than or equal to 6. Therefore, we select 2 partitions of size 6.
        This time no padding is necessary. The two partitions resulting from this step
        are (note that the data is again in ascending order):

        1. :math:`[0, 0, 30, 40, 50, 90]`
        2. :math:`[100, 110, 150, 160, 170, 210]`

        Applying our ``interleaved_base_solver`` with ``coreset_size`` 3 on
        each partition, we obtain:

        1. :math:`[0, 90, 0]`
        2. :math:`[100, 210, 110]`

        Now, we concatenate the two subsets and repeat the process to
        obtain only one partition:

        1. Concatenated subset: :math:`[0, 90, 0, 100, 210, 110]`

        Note that the size of the dataset is 6; therefore, no more
        partitioning is necessary.

        Applying ``interleaved_base_solver`` one last time, we obtain the final
        coreset: :math:`[0, 110, 90]`.
"""
interleaved_base_solver = MagicMock(_ExplicitPaddingInvariantSolver)
interleaved_base_solver.coreset_size = 3

def interleaved_mock_reduce(
dataset: Data, solver_state: None = None
) -> tuple[Coreset[Data], None]:
half_size = interleaved_base_solver.coreset_size // 2
indices = jnp.arange(interleaved_base_solver.coreset_size)
forward_indices = indices[:half_size]
backward_indices = -(indices[:half_size] + 1)
interleaved_indices = jnp.stack(
[forward_indices, backward_indices], axis=1
).ravel()

if interleaved_base_solver.coreset_size % 2 != 0:
interleaved_indices = jnp.append(interleaved_indices, half_size)
return Coreset(dataset[interleaved_indices], dataset), solver_state

interleaved_base_solver.reduce = interleaved_mock_reduce

original_data = Data(
            jnp.array(
                [
                    10,
                    20,
                    30,
                    210,
                    40,
                    60,
                    180,
                    90,
                    150,
                    70,
                    120,
                    200,
                    50,
                    140,
                    80,
                    170,
                    100,
                    190,
                    110,
                    160,
                    130,
                ]
            )
        )
        expected_coreset_data = Data(jnp.array([0, 110, 90]))

        coreset, _ = MapReduce(base_solver=interleaved_base_solver, leaf_size=6).reduce(
            original_data
        )
        assert eqx.tree_equal(coreset.coreset.data, expected_coreset_data.data)


class TestCaratheodoryRecombination(RecombinationSolverTest):
    """Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`."""
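
As a check on the arithmetic in the `test_map_reduce_analytic` docstring, the power-of-two partitioning it describes can be reproduced with a short sketch. This is a simplified model of `_jit_tree`, assuming the number of partitions doubles until each padded partition holds at most `leaf_size` points:

import math

def partition_shape(n: int, leaf_size: int) -> tuple[int, int]:
    # Double the number of partitions until each padded partition
    # holds at most `leaf_size` points.
    n_partitions = 1
    while math.ceil(n / n_partitions) > leaf_size:
        n_partitions *= 2
    return n_partitions, math.ceil(n / n_partitions)

# Original dataset of 21 points with maximum leaf size 6:
assert partition_shape(21, 6) == (4, 6)  # 4 partitions of 6: 3 padded zeros
# Intermediate dataset of 12 points after the first reduction:
assert partition_shape(12, 6) == (2, 6)  # 2 partitions of 6: no padding
# Final intermediate dataset of 6 points: a single partition, reduced once more.
assert partition_shape(6, 6) == (1, 6)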