Skip to content

Commit

Permalink
feat: remove instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tc85324 committed May 30, 2024
1 parent 3a66572 commit 752cae6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 54 deletions.
55 changes: 3 additions & 52 deletions coreax/recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,36 +301,24 @@ def _cond(state: EliminationState):
def _body(state: EliminationState):
weights, left_null_space_basis, indices, basis_index = state
basis_vector = left_null_space_basis[:, basis_index]

# Do normal pivoting -> then pivot for the argmin?

# This handles both possible cases, rather than just positive/negative.
absolute_basis_vector = jnp.abs(basis_vector)
_rescaled_weights = weights / absolute_basis_vector
rescaled_weights = jnp.where(
absolute_basis_vector > zero_tol, _rescaled_weights, jnp.inf
)
elimination_index = jnp.argmin(rescaled_weights)
# pivot = (basis_index, elimination_index)
# pivoted_indices = indices.at[pivot, ...].set(indices[pivot[::-1], ...])

weights_update = rescaled_weights[elimination_index] * basis_vector
update_sign = jnp.sign(basis_vector[elimination_index])
updated_weights = weights - update_sign * weights_update
# updated_weights = updated_weights.at[elimination_index].set(0.0)
# updated_weights = weights.at[pivoted_indices].set(updated_weights)

basis_update = jnp.tensordot(
basis_vector / basis_vector[elimination_index],
left_null_space_basis[elimination_index],
axes=0,
)
updated_left_null_space_basis = left_null_space_basis - basis_update
# updated_left_null_space_basis = left_null_space_basis.at[pivoted_indices].set(
# updated_left_null_space_basis
# )
# jax.debug.print("W: {W}", W=updated_weights)
# jax.debug.breakpoint()
return EliminationState(
updated_weights,
updated_left_null_space_basis,
Expand All @@ -349,11 +337,6 @@ def _body(state: EliminationState):
out_state = jax.lax.while_loop(_cond, _body, in_state)
output_weights, *_ = out_state
_, retained_indices = jax.lax.top_k(output_weights, m)
# jax.debug.breakpoint()
jnp.set_printoptions(precision=2, linewidth=120)
# q, s, vt = jsp.linalg.svd(augmented_nodes, full_matrices=True)
# p1, l1, u1 = jsp.linalg.lu(largest_left_null_space_basis)
# jax.debug.breakpoint()
return output_weights, retained_indices


Expand All @@ -366,7 +349,6 @@ def recombination(
tree_reduction_factor: int = 2,
mode: Literal["svd", "qr"] = "svd",
rcond: InexactScalarLike | None = None,
assume_non_degenerate: bool = False,
):
# Pre-process the weights
abs_weights = jnp.abs(weights)
Expand Down Expand Up @@ -431,70 +413,39 @@ def _tree_reduction_step(_, state):
centroid_weights, centroid_nodes = _centroid(
weights[reshaped_indices], padded_nodes[reshaped_indices]
)
# TODO: have centering that only needs computing once.
# Centre the centroids to have zero the dataset CoM at zero.
# TODO: This seems to have positive performance benefits.
centred_centroid_nodes = centroid_nodes.at[:, 1:].add(-target_com[-1][..., 1:])
updated_centroid_weights, _ = caratheodory_measure_reduction(
centroid_weights, centred_centroid_nodes
)
# Updated weights
weight_update = updated_centroid_weights / centroid_weights
# If weight update is 1, then nothing has happened.
updated_weights = jnp.nan_to_num(
weights.at[reshaped_indices].multiply(weight_update[..., None])
)

# Update indices
# TODO: check degenerate possibilities for this.
_, eliminated_indices = jax.lax.top_k(-updated_centroid_weights, d)
_updated_indices = reshaped_indices.at[eliminated_indices].set(-1)
updated_indices = jnp.partition(
_updated_indices.reshape(-1), n // tree_reduction_factor
)

current_com = _centroid(updated_weights[None, ...], padded_nodes[None, ...])
indexed_com = _centroid(
updated_weights[updated_indices].reshape(1, -1),
padded_nodes[updated_indices].reshape(1, -1, d),
)
com_diff = jtu.tree_map(
lambda x, y: jnp.linalg.norm(x - y),
(target_com, target_com, (centroid_weights.sum(), centroid_nodes)),
(
current_com,
indexed_com,
(updated_centroid_weights.sum(), centroid_nodes),
),
)
# jax.debug.print(
# "\nCOM DIFF\n--------\nMASKED: {x};\nINDEXED: {y};\nCENTROID: {z}",
# x=com_diff[0],
# y=com_diff[1],
# z=com_diff[2],
# )
# jax.debug.breakpoint()
return updated_weights, updated_indices

if n <= d:
return caratheodory_measure_reduction(weights, nodes)

in_state = (padded_weights, padded_indices)
root_weights, _ = jax.lax.fori_loop(
0, max_tree_depth, _tree_reduction_step, in_state
0, max_tree_depth, _tree_reduction_step, in_state,
)
output_weights, retained_root_indices = jax.lax.top_k(root_weights, d)
# jax.debug.breakpoint()
return output_weights, retained_root_indices
# leaf_weights, retained_leaf_indices = caratheodory_measure_reduction(
# root_weights[retained_root_indices], padded_nodes[retained_root_indices]
# )
# retained_indices = retained_root_indices[retained_leaf_indices]
# return leaf_weights[retained_leaf_indices], retained_indices


@jax.vmap
def _centroid(weights, nodes):
"""Compute the centroid mass and node centre (centre of mass)."""
centroid_weights = jnp.sum(weights)
centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights)).at[..., 0].set(1)
centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights))
return centroid_weights, centroid_nodes
6 changes: 4 additions & 2 deletions tests/unit/test_recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def random_atomic_measure_generator(
elif degeneracy == "null":
# These need randomly inserting instead, rather than being all at the end.
nodes = nodes.at[(d - 1) :].set(0.0)
weights = weights.at[(d - 1) :].set(0.0)
weights = weights / jnp.sum(weights)
else:
raise ValueError(
f"Degeneracy must be one of {get_args(DEGENERACIES)}; "
Expand Down Expand Up @@ -171,8 +173,8 @@ def test_tree_recombination(
)
for weights, nodes in case_iterator:
measure_reduction = recombination
# if degeneracy is None and mode != "qr":
# measure_reduction = eqx.filter_jit(recombination)
if degeneracy is None and mode != "qr":
measure_reduction = eqx.filter_jit(recombination)
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
result_weights, result_nodes = measure_reduction(
Expand Down

0 comments on commit 752cae6

Please sign in to comment.