Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix negative weights #698

Merged
merged 6 commits into from
Jul 15, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions coreax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ def pairwise_difference(x: ArrayLike, y: ArrayLike) -> Array:
return pairwise(difference)(x, y)


def solve_qp(kernel_mm: ArrayLike, gramian_row_mean: ArrayLike, **osqp_kwargs) -> Array:
def solve_qp(
kernel_mm: ArrayLike,
gramian_row_mean: ArrayLike,
precision_threshold: float = 1e-12,
**osqp_kwargs,
) -> Array:
r"""
Solve quadratic programs with the :class:`jaxopt.OSQP` solver.

Expand All @@ -218,6 +223,8 @@ def solve_qp(kernel_mm: ArrayLike, gramian_row_mean: ArrayLike, **osqp_kwargs) -

:param kernel_mm: :math:`m \times m` coreset Gram matrix
:param gramian_row_mean: :math:`m \times 1` array of Gram matrix means
:precision_threshold: Threshold below which values are rounded to zero (accommodates
precision loss)
:return: Optimised solution for the quadratic program
"""
# Setup optimisation problem - all variable names are consistent with the OSQP
Expand All @@ -239,7 +246,10 @@ def solve_qp(kernel_mm: ArrayLike, gramian_row_mean: ArrayLike, **osqp_kwargs) -
sol = qp.run(
params_obj=(q_array, c), params_eq=(a_array, b), params_ineq=(g_array, h)
).params
return sol.primal

# Ensure conditions of solution are met to chosen precision
solution = jnp.where(sol.primal < jnp.abs(precision_threshold), 0, sol.primal)
pc532627 marked this conversation as resolved.
Show resolved Hide resolved
return solution / jnp.sum(solution)


def sample_batch_indices(
Expand Down