Skip to content

Commit

Permalink
#779 At coreax.solvers.composite, changed the return statement of the…
Browse files Browse the repository at this point in the history
… _jit_tree function to return indices as well, changed the reduce method to keep track of indices, added a line plt.show() to pounce.py
  • Loading branch information
qh681248 committed Sep 24, 2024
1 parent df8f4fc commit 272da26
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions coreax/solvers/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,57 @@ def __check_init__(self):

@override
def reduce(
self, dataset: _Data, solver_state: Optional[_State] = None
self, dataset: _Data, solver_state: Optional[_State] = None
) -> tuple[_Coreset, _State]:
# There is no obvious way to use state information here.
del solver_state

def _reduce_coreset(data: _Data) -> tuple[_Coreset, _State]:
def _reduce_coreset(data: _Data, _indices=None) -> tuple[_Coreset, _State, _Data]:
if len(data) <= self.leaf_size:
return self.base_solver.reduce(data)
partitioned_dataset = _jit_tree(data, self.leaf_size, self.tree_type)
coreset_ensemble, _ = jax.vmap(self.base_solver.reduce)(partitioned_dataset)
coreset, state = self.base_solver.reduce(data)
if _indices is not None:
_indices = _indices[coreset.unweighted_indices]
return coreset, state, _indices

def wrapper(row: _Data) -> tuple[_Data, _Data]:
"""
Apply the reduce method of the base solver on a dataset and
return the data and unweighted indices of the coreset.
It is a wrapper to process a single partition (row) of the result of _jit_tree
that works with the vmap
"""
x, _ = self.base_solver.reduce(row)
return x.coreset, x.unweighted_indices

def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
"""
Perform advanced indexing on array 'a' using indices 'b'.
Returns a new array with elements of 'a' at positions specified by 'b'.
"""
return a[b]

# First partition the data
partitioned_dataset, partitioned_indices = _jit_tree(data, self.leaf_size, self.tree_type)
# Then apply base solver to each partition and keep track of indices with respect to partitions
coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset)
# Calculate the indices with respect to the data (one passed to _reduce_coreset)
concatenated_indices = jax.vmap(get_indices)(partitioned_indices, ensemble_indices)
# flatten the indices
concatenated_indices = jnp.ravel(concatenated_indices)
_coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble)
return _reduce_coreset(_coreset.coreset)

coreset_wrong_pre_coreset_data, output_solver_state = _reduce_coreset(dataset)
coreset = eqx.tree_at(
lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset
)
return coreset, output_solver_state
if _indices is not None:
final_indices = _indices[concatenated_indices]
else:
final_indices = concatenated_indices
return _reduce_coreset(_coreset, final_indices)

coreset, output_solver_state, _indices = _reduce_coreset(dataset)
del coreset
final_coreset = Coresubset(_indices, dataset)

return final_coreset, output_solver_state


def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data:
Expand Down Expand Up @@ -183,4 +216,5 @@ def _binary_tree(_input_data: Data) -> np.ndarray:
return node_indices.reshape(n_leaves, -1).astype(np.int32)

indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset)
return dataset[indices]
return dataset[indices], indices # (Now it returns both data and the indices)

0 comments on commit 272da26

Please sign in to comment.