Skip to content

Commit

Permalink
feat: use top_k and partitioning in recombination
Browse files Browse the repository at this point in the history
  • Loading branch information
tc85324 committed May 30, 2024
1 parent 75d33ea commit 3a66572
Showing 1 changed file with 92 additions and 64 deletions.
156 changes: 92 additions & 64 deletions coreax/recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import math
import warnings
from collections import namedtuple
from collections.abc import Callable
from typing import cast, Literal

Expand Down Expand Up @@ -198,7 +199,7 @@ def _compute_null_space(
raise ValueError(f"Invalid mode specified; got {mode}, expected 'svd' or 'qr'")

left_null_space_rank = reveal_left_null_space_rank(nodes.shape, s, rcond)
largest_left_null_space_basis = q
largest_left_null_space_basis = q[:, ::-1]
return largest_left_null_space_basis, left_null_space_rank


Expand Down Expand Up @@ -247,6 +248,11 @@ def caratheodory_measure_reduction(
return output_weights[retained_indices], nodes[retained_indices]


# Loop-carried state for the elimination ``jax.lax.while_loop`` used in
# ``implicit_caratheodory_measure_reduction``. Fields, in positional order:
#   weights   - current weight vector (seeded with ``probability_weights``).
#   nodes     - NOTE(review): despite the name, the loop stores the left null
#               space basis matrix in this slot, not the nodes; consider
#               renaming the field.
#   indices   - index array seeded with ``jnp.arange(n)``; the loop body returns
#               it unchanged.
#   iteration - counter that is also the column index of the basis vector being
#               processed; the loop stops once it reaches the null space rank.
EliminationState = namedtuple(
    "EliminationState", ["weights", "nodes", "indices", "iteration"]
)


# Can we make use of eqxi.while_loop to reduce memory usage here, with
# checkpointing and buffers? Perhaps this can be replaced with a Pallas kernel?
def implicit_caratheodory_measure_reduction(
Expand Down Expand Up @@ -285,58 +291,69 @@ def implicit_caratheodory_measure_reduction(
largest_left_null_space_basis, left_null_space_rank = _compute_null_space(
augmented_nodes, rcond, mode=mode
)
zero_tol = 1e-12

def _cond(
state: tuple[Shaped[Array, " n"], Shaped[Array, "hat_n M"], int],
):
*_, i = state
return i < left_null_space_rank

def _body(
state: tuple[Shaped[Array, " n"], Shaped[Array, "hat_n M"], int],
):
# TODO: docstring explaining what is happening here.
weights, left_null_space_basis, i = state
basis_index = -(i + 1)
# Perform partial-pivoting (row-exchange); avoid divide-by-zero errors that
# arise in the subsequent Gaussian-elimination step.
def _cond(state: EliminationState):
# TODO: add skip condition if null space is all zeros?
*_, basis_index = state
return basis_index < left_null_space_rank

def _body(state: EliminationState):
weights, left_null_space_basis, indices, basis_index = state
basis_vector = left_null_space_basis[:, basis_index]
pivot_index = jnp.argmax(jnp.abs(basis_vector))
pivot = (basis_index, pivot_index)
pivoted_indices = indices.at[pivot[::-1], ...].set(indices[pivot, ...])
pivoted_weights = weights[pivoted_indices]
pivoted_left_null_space_basis = left_null_space_basis[pivoted_indices]
pivoted_basis_vector = pivoted_left_null_space_basis[:, basis_index]
# Perform Gaussian-elimination; the smallest `rescaled_weights` are eliminated
# (zeroed) and the remaining weights updated to maintain the constraint
# :math:`\sum_{i=1}^n \hat{w}_i = 1`.

# Do normal pivoting -> then pivot for the argmin?

# This handles both possible cases, rather than just positive/negative.
absolute_basis_vector = jnp.abs(basis_vector)
_rescaled_weights = weights / absolute_basis_vector
rescaled_weights = jnp.where(
pivoted_basis_vector > 0.0, pivoted_weights / pivoted_basis_vector, jnp.inf
absolute_basis_vector > zero_tol, _rescaled_weights, jnp.inf
)
elimination_index = jnp.argmin(rescaled_weights)
weight_update = rescaled_weights[elimination_index] * pivoted_basis_vector
updated_weights = pivoted_weights - weight_update
updated_weights = updated_weights.at[elimination_index].set(0.0)
updated_weights = weights.at[pivoted_indices].set(updated_weights)
# pivot = (basis_index, elimination_index)
# pivoted_indices = indices.at[pivot, ...].set(indices[pivot[::-1], ...])

weights_update = rescaled_weights[elimination_index] * basis_vector
update_sign = jnp.sign(basis_vector[elimination_index])
updated_weights = weights - update_sign * weights_update
# updated_weights = updated_weights.at[elimination_index].set(0.0)
# updated_weights = weights.at[pivoted_indices].set(updated_weights)

basis_update = jnp.tensordot(
pivoted_basis_vector / pivoted_basis_vector[elimination_index],
pivoted_left_null_space_basis[elimination_index],
basis_vector / basis_vector[elimination_index],
left_null_space_basis[elimination_index],
axes=0,
)
updated_left_null_space_basis = pivoted_left_null_space_basis - basis_update
updated_left_null_space_basis = left_null_space_basis.at[pivoted_indices].set(
updated_left_null_space_basis
updated_left_null_space_basis = left_null_space_basis - basis_update
# updated_left_null_space_basis = left_null_space_basis.at[pivoted_indices].set(
# updated_left_null_space_basis
# )
# jax.debug.print("W: {W}", W=updated_weights)
# jax.debug.breakpoint()
return EliminationState(
updated_weights,
updated_left_null_space_basis,
indices,
basis_index + 1,
)
return updated_weights, updated_left_null_space_basis, i + 1

n, m = augmented_nodes.shape
upper_bound_loop_count = n
indices = jnp.arange(upper_bound_loop_count)
in_state = (probability_weights, largest_left_null_space_basis, 0)
in_state = EliminationState(
probability_weights,
largest_left_null_space_basis,
jnp.arange(upper_bound_loop_count),
0,
)
out_state = jax.lax.while_loop(_cond, _body, in_state)
output_weights, *_ = out_state
retained_indices = jnp.argsort(output_weights)[-m:]
_, retained_indices = jax.lax.top_k(output_weights, m)
# jax.debug.breakpoint()
jnp.set_printoptions(precision=2, linewidth=120)
# q, s, vt = jsp.linalg.svd(augmented_nodes, full_matrices=True)
# p1, l1, u1 = jsp.linalg.lu(largest_left_null_space_basis)
# jax.debug.breakpoint()
return output_weights, retained_indices


Expand Down Expand Up @@ -406,27 +423,35 @@ def _tree_recombination(
rcond=rcond,
mode=mode,
)

target_com = _centroid(padded_weights[None, ...], padded_nodes[None, ...])
target_com = _centroid(weights[None, ...], nodes[None, ...])

def _tree_reduction_step(_, state):
weights, indices = state
reshaped_indices = indices.reshape(tree_count, -1)
reshaped_indices = indices.reshape(tree_count, -1, order="F")
centroid_weights, centroid_nodes = _centroid(
weights[reshaped_indices], padded_nodes[reshaped_indices]
)
# TODO: have centering that only needs computing once.
# Centre the centroids so that the dataset CoM is at zero.
centred_centroid_nodes = centroid_nodes.at[:, 1:].add(-target_com[-1][..., 1:])
updated_centroid_weights, _ = caratheodory_measure_reduction(
centroid_weights, centroid_nodes
centroid_weights, centred_centroid_nodes
)
updated_indices = jnp.where(
updated_centroid_weights[..., None] > 0, reshaped_indices, -1
)

# Updated weights
weight_update = updated_centroid_weights / centroid_weights
# If weight update is 1, then nothing has happened.
updated_weights = jnp.nan_to_num(
weights.at[reshaped_indices].multiply(weight_update[..., None])
)

# Update indices
# TODO: check degenerate possibilities for this.
_, eliminated_indices = jax.lax.top_k(-updated_centroid_weights, d)
_updated_indices = reshaped_indices.at[eliminated_indices].set(-1)
updated_indices = jnp.partition(
_updated_indices.reshape(-1), n // tree_reduction_factor
)

current_com = _centroid(updated_weights[None, ...], padded_nodes[None, ...])
indexed_com = _centroid(
updated_weights[updated_indices].reshape(1, -1),
Expand All @@ -441,32 +466,35 @@ def _tree_reduction_step(_, state):
(updated_centroid_weights.sum(), centroid_nodes),
),
)
jax.debug.print(
"\nCOM DIFF\n--------\nMASKED: {x};\nINDEXED: {y};\nCENTROID: {z}",
x=com_diff[0],
y=com_diff[1],
z=com_diff[2],
)
jax.debug.breakpoint()
return updated_weights, updated_indices.reshape(-1, order="F")
# jax.debug.print(
# "\nCOM DIFF\n--------\nMASKED: {x};\nINDEXED: {y};\nCENTROID: {z}",
# x=com_diff[0],
# y=com_diff[1],
# z=com_diff[2],
# )
# jax.debug.breakpoint()
return updated_weights, updated_indices

if n <= d:
return caratheodory_measure_reduction(weights, nodes)

in_state = (padded_weights, padded_indices)
root_weights, root_indices = jax.lax.fori_loop(
root_weights, _ = jax.lax.fori_loop(
0, max_tree_depth, _tree_reduction_step, in_state
)
retained_root_indices = jnp.sort(root_indices)[-tree_count:]
breakpoint()
leaf_weights, _ = caratheodory_measure_reduction(...)
retained_leaf_indices = jnp.argsort(leaf_weights)[-d:]
retained_indices = retained_root_indices[retained_leaf_indices]
return leaf_weights[retained_leaf_indices], retained_indices
output_weights, retained_root_indices = jax.lax.top_k(root_weights, d)
# jax.debug.breakpoint()
return output_weights, retained_root_indices
# leaf_weights, retained_leaf_indices = caratheodory_measure_reduction(
# root_weights[retained_root_indices], padded_nodes[retained_root_indices]
# )
# retained_indices = retained_root_indices[retained_leaf_indices]
# return leaf_weights[retained_leaf_indices], retained_indices


@jax.vmap
def _centroid(weights, nodes):
    """
    Compute the centroid mass and node centre (centre of mass).

    The function is wrapped in :func:`jax.vmap`, so both arguments carry a leading
    batch axis that is mapped over; the description below is per batch element.

    :param weights: Weights used for the average taken over axis 0 of ``nodes``
    :param nodes: Values to average; axis 0 is reduced
    :return: Tuple ``(centroid_weights, centroid_nodes)`` where ``centroid_weights``
        is the sum of ``weights`` and ``centroid_nodes`` is the weighted average of
        ``nodes`` with non-finite entries replaced by :func:`jax.numpy.nan_to_num`
        and the entry at index 0 of the last axis set to 1
    """
    centroid_weights = jnp.sum(weights)
    # The weighted average is NaN when the weights sum to zero (0 / 0);
    # `nan_to_num` maps such entries to finite values.
    weighted_mean = jnp.nan_to_num(jnp.average(nodes, 0, weights))
    # Index 0 of the last axis is overwritten with the constant 1.
    centroid_nodes = weighted_mean.at[..., 0].set(1)
    return centroid_weights, centroid_nodes

0 comments on commit 3a66572

Please sign in to comment.