diff --git a/coreax/util.py b/coreax/util.py index 920cb09a1..364a69583 100644 --- a/coreax/util.py +++ b/coreax/util.py @@ -239,7 +239,10 @@ def solve_qp(kernel_mm: ArrayLike, gramian_row_mean: ArrayLike, **osqp_kwargs) - sol = qp.run( params_obj=(q_array, c), params_eq=(a_array, b), params_ineq=(g_array, h) ).params - return sol.primal + + # Ensure conditions of solution are met + solution = apply_negative_precision_threshold(sol.primal, jnp.inf) + return solution / jnp.sum(solution) def sample_batch_indices(