diff --git a/coreax/recombination.py b/coreax/recombination.py index e713aecde..d0968dd76 100644 --- a/coreax/recombination.py +++ b/coreax/recombination.py @@ -301,9 +301,6 @@ def _cond(state: EliminationState): def _body(state: EliminationState): weights, left_null_space_basis, indices, basis_index = state basis_vector = left_null_space_basis[:, basis_index] - - # Do normal pivoting -> then pivot for the argmin? - # This handles both possible cases, rather than just positive/negative. absolute_basis_vector = jnp.abs(basis_vector) _rescaled_weights = weights / absolute_basis_vector @@ -311,14 +308,10 @@ def _body(state: EliminationState): absolute_basis_vector > zero_tol, _rescaled_weights, jnp.inf ) elimination_index = jnp.argmin(rescaled_weights) - # pivot = (basis_index, elimination_index) - # pivoted_indices = indices.at[pivot, ...].set(indices[pivot[::-1], ...]) weights_update = rescaled_weights[elimination_index] * basis_vector update_sign = jnp.sign(basis_vector[elimination_index]) updated_weights = weights - update_sign * weights_update - # updated_weights = updated_weights.at[elimination_index].set(0.0) - # updated_weights = weights.at[pivoted_indices].set(updated_weights) basis_update = jnp.tensordot( basis_vector / basis_vector[elimination_index], @@ -326,11 +319,6 @@ def _body(state: EliminationState): axes=0, ) updated_left_null_space_basis = left_null_space_basis - basis_update - # updated_left_null_space_basis = left_null_space_basis.at[pivoted_indices].set( - # updated_left_null_space_basis - # ) - # jax.debug.print("W: {W}", W=updated_weights) - # jax.debug.breakpoint() return EliminationState( updated_weights, updated_left_null_space_basis, @@ -349,11 +337,6 @@ def _body(state: EliminationState): out_state = jax.lax.while_loop(_cond, _body, in_state) output_weights, *_ = out_state _, retained_indices = jax.lax.top_k(output_weights, m) - # jax.debug.breakpoint() - jnp.set_printoptions(precision=2, linewidth=120) - # q, s, vt = jsp.linalg.svd(augmented_nodes, full_matrices=True) - # p1, l1, u1 = jsp.linalg.lu(largest_left_null_space_basis) - # jax.debug.breakpoint() return output_weights, retained_indices @@ -366,7 +349,6 @@ def recombination( tree_reduction_factor: int = 2, mode: Literal["svd", "qr"] = "svd", rcond: InexactScalarLike | None = None, - assume_non_degenerate: bool = False, ): # Pre-process the weights abs_weights = jnp.abs(weights) @@ -431,48 +413,23 @@ def _tree_reduction_step(_, state): centroid_weights, centroid_nodes = _centroid( weights[reshaped_indices], padded_nodes[reshaped_indices] ) - # TODO: have centering that only needs computing once. - # Centre the centroids to have zero the dataset CoM at zero. + # TODO: This seems to have positive performance benefits. centred_centroid_nodes = centroid_nodes.at[:, 1:].add(-target_com[-1][..., 1:]) updated_centroid_weights, _ = caratheodory_measure_reduction( centroid_weights, centred_centroid_nodes ) # Updated weights weight_update = updated_centroid_weights / centroid_weights - # If weight update is 1, then nothing has happened. updated_weights = jnp.nan_to_num( weights.at[reshaped_indices].multiply(weight_update[..., None]) ) # Update indices - # TODO: check degenerate possibilities for this. _, eliminated_indices = jax.lax.top_k(-updated_centroid_weights, d) _updated_indices = reshaped_indices.at[eliminated_indices].set(-1) updated_indices = jnp.partition( _updated_indices.reshape(-1), n // tree_reduction_factor ) - - current_com = _centroid(updated_weights[None, ...], padded_nodes[None, ...]) - indexed_com = _centroid( - updated_weights[updated_indices].reshape(1, -1), - padded_nodes[updated_indices].reshape(1, -1, d), - ) - com_diff = jtu.tree_map( - lambda x, y: jnp.linalg.norm(x - y), - (target_com, target_com, (centroid_weights.sum(), centroid_nodes)), - ( - current_com, - indexed_com, - (updated_centroid_weights.sum(), centroid_nodes), - ), - ) - # jax.debug.print( - # "\nCOM DIFF\n--------\nMASKED: {x};\nINDEXED: {y};\nCENTROID: {z}", - # x=com_diff[0], - # y=com_diff[1], - # z=com_diff[2], - # ) - # jax.debug.breakpoint() return updated_weights, updated_indices if n <= d: @@ -480,21 +437,15 @@ def _tree_reduction_step(_, state): in_state = (padded_weights, padded_indices) root_weights, _ = jax.lax.fori_loop( - 0, max_tree_depth, _tree_reduction_step, in_state + 0, max_tree_depth, _tree_reduction_step, in_state, ) output_weights, retained_root_indices = jax.lax.top_k(root_weights, d) - # jax.debug.breakpoint() return output_weights, retained_root_indices - # leaf_weights, retained_leaf_indices = caratheodory_measure_reduction( - # root_weights[retained_root_indices], padded_nodes[retained_root_indices] - # ) - # retained_indices = retained_root_indices[retained_leaf_indices] - # return leaf_weights[retained_leaf_indices], retained_indices @jax.vmap def _centroid(weights, nodes): """Compute the centroid mass and node centre (centre of mass).""" centroid_weights = jnp.sum(weights) - centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights)).at[..., 0].set(1) + centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights)) return centroid_weights, centroid_nodes diff --git a/tests/unit/test_recombination.py b/tests/unit/test_recombination.py index 7e9beb894..73f6b519f 100644 --- a/tests/unit/test_recombination.py +++ b/tests/unit/test_recombination.py @@ -85,6 +85,8 @@ def random_atomic_measure_generator( elif degeneracy == "null": # These need randomly inserting instead, rather than being all at the end. nodes = nodes.at[(d - 1) :].set(0.0) + weights = weights.at[(d - 1) :].set(0.0) + weights = weights / jnp.sum(weights) else: raise ValueError( f"Degeneracy must be one of {get_args(DEGENERACIES)}; " @@ -171,8 +173,8 @@ def test_tree_recombination( ) for weights, nodes in case_iterator: measure_reduction = recombination - # if degeneracy is None and mode != "qr": - # measure_reduction = eqx.filter_jit(recombination) + if degeneracy is None and mode != "qr": + measure_reduction = eqx.filter_jit(recombination) with warnings.catch_warnings(record=True) as record: warnings.simplefilter("always") result_weights, result_nodes = measure_reduction(