[Feature Request] Cardinality constraint #1749

Closed
swierh opened this issue Mar 17, 2023 · 13 comments
Labels
enhancement New feature or request


@swierh

swierh commented Mar 17, 2023

🚀 Feature Request

We've got a problem where we have a constraint on the number of non-zero inputs. Out of 10 possible features, we can only select 5 at a time for any sample in the q-batch. This is because each sample can contain at most 5 different ingredients, chosen from 10 possible ingredients.
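Stated exactly, the constraint is "at most 5 non-zero entries per point, for every point in the q-batch". A minimal sketch of that check in plain torch, assuming "non-zero" means |x| > tol; `satisfies_cardinality` is a hypothetical helper, not part of BoTorch:

```python
import torch
from torch import Tensor


def satisfies_cardinality(X: Tensor, max_nonzero: int, tol: float = 0.0) -> Tensor:
    """Hypothetical helper. X has shape `batch_shape x q x d`; returns a bool
    tensor of shape `batch_shape` that is True when every point in the q-batch
    has at most `max_nonzero` entries with |x| > tol."""
    nonzero_per_point = (X.abs() > tol).sum(dim=-1)  # batch_shape x q
    return (nonzero_per_point <= max_nonzero).all(dim=-1)


# Two q-batches with q = 2 points each, over 10 possible ingredients.
X = torch.zeros(2, 2, 10)
X[0, 0, :5] = 0.2  # 5 ingredients: allowed
X[0, 1, 5:] = 0.2  # 5 ingredients: allowed
X[1, 0, :6] = 0.1  # 6 ingredients: violates the limit of 5
X[1, 1, :3] = 0.3  # 3 ingredients: allowed
print(satisfies_cardinality(X, max_nonzero=5))  # first q-batch feasible, second not
```

This count is piecewise constant, so its gradient is zero almost everywhere, which is why a gradient-based acquisition optimizer needs a penalty or a smooth relaxation instead.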

We tried implementing this with a penalized acquisition function that adds a loss on the cardinality. This works, but a constraint seems like a more natural place for this.

Also, when we add a normal inequality constraint (e.g. x1 + x2 > 0), the optimizer doesn't seem to find solutions with low cardinality.

We've experimented with adding the cardinality constraint as a non-linear constraint, but it appears that we have to write a custom sampler to get this working.

Is there a cleaner way of doing this? What would you suggest as the right approach?

All help will be much appreciated!

The code we've implemented for the penalty is as follows:

import torch
from torch import Tensor


class CardinalityPenalty(torch.nn.Module):
    r"""Cardinality penalty class to be added to any arbitrary acquisition function
    to construct a PenalizedAcquisitionFunction."""

    def __init__(self, init_point: Tensor, cardinality: float, epsilon: float = 0.01):
        r"""Initializing cardinality regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
            cardinality: The cardinality at or below which the penalty is 0.
            epsilon: Deviations from `init_point` up to this size are
                treated as zero.
        """
        super().__init__()
        self.init_point = init_point
        self.allowed_cardinality = cardinality
        self.epsilon = epsilon

    def forward(self, X: Tensor) -> Tensor:
        r"""
        Args:
            X: A "batch_shape x q x dim" tensor representing the points to be
                evaluated. The values in X must satisfy -1 ≤ x ≤ 1.

        Returns:
            A tensor of shape "batch_shape" containing the penalty for each q-batch.
        """

        # We allow for a small deviation from 0, and we don't care about the sign:
        X_rel = (X - self.init_point).abs().sub(self.epsilon).clip(0)

        # The "L0 norm" counts the non-zero entries of each point:
        cardinality = torch.norm(X_rel, p=0, dim=-1)

        # If the cardinality is too high, we want the point to move to the closest
        # allowed plane, so we penalize the smallest non-zero value:
        min_non_zero = (X_rel + (X_rel == 0)).min(dim=-1).values

        # We penalize the cardinality if it's higher than the allowed cardinality.
        # At or below that cardinality, we don't care about the non-zero values,
        # which is why we subtract 1.
        # When the cardinality is higher than the allowed value, we start penalizing
        # the smallest non-zero value.
        penalty = (cardinality - self.allowed_cardinality - 1 + min_non_zero).clip(0)

        # The penalty must be a single value for the entire batch:
        penalty = penalty.mean(dim=-1) ** 2
        return penalty

And we then wrap our acquisition function as follows:

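import numpy as np
from botorch.acquisition import qKnowledgeGradient
from botorch.acquisition.penalized import PenalizedAcquisitionFunction
from botorch.optim import optimize_acqf

# model, X and device are defined elsewhere in our setup.
acqf = qKnowledgeGradient(model, num_fantasies=10)
penalty = CardinalityPenalty(
    torch.tensor(np.zeros_like(X[0]), dtype=torch.float64, device=device),
    cardinality=5,
)
# The third argument is the regularization weight applied to the penalty.
acqf = PenalizedAcquisitionFunction(acqf, penalty, 1.0)

x_batch, _ = optimize_acqf(acq_function=acqf, ...)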
@swierh swierh added the enhancement New feature or request label Mar 17, 2023
@eytan
Contributor

eytan commented Mar 17, 2023

cc @dme65, @bletham, @qingfeng10.

Is the 5 at a time a hard constraint or some ideal number that you’d like to stick to? I wonder if some of the modeling+optimization approaches from https://arxiv.org/pdf/2203.01900.pdf could be useful.

@saitcakmak
Contributor

We've experimented with adding the cardinality constraint as a non-linear constraint, but it appears that we have to write a custom sampler to get this working.

The custom sampler here is just something to generate the initial conditions (ICs) for the optimizer to use for multi-start optimization. A lazy implementation could call gen_batch_initial_conditions in a loop, check whether the returned ICs satisfy the constraints, and stop once num_restarts ICs satisfying the constraints have been found. A better approach could be to modify the internals of gen_batch_initial_conditions (where it generates X_rnd) to do a similar thing, but for the raw samples this time (i.e., to ensure that all of X_rnd satisfies the constraints). This is generally hard to do for a generic nonlinear constraint, which is why we do not support it out of the box. Rejection sampling (which is pretty much what I described) can be very inefficient if the constraints restrict the feasible candidates to a very small subspace, which doesn't seem to be the case with your constraints.
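The "better approach" above, rejection sampling of the raw samples themselves, can be sketched in plain torch. Assumptions: the search space is the unit cube, the constraint follows the `callable(x) >= 0` convention and is evaluated per point, and a plain top-k selection stands in for the selection heuristic that gen_batch_initial_conditions actually uses. `rejection_sample_ics`, `inside_unit_ball` and `toy_acqf` are hypothetical names, not BoTorch API. Note that for an exact cardinality constraint, uniform draws contain exact zeros with probability zero, so in that case the zeros have to be generated directly rather than rejected for.

```python
from typing import Callable

import torch
from torch import Tensor


def rejection_sample_ics(
    acqf: Callable[[Tensor], Tensor],
    constraint: Callable[[Tensor], Tensor],
    num_restarts: int,
    raw_samples: int,
    q: int,
    d: int,
    max_rounds: int = 100,
) -> Tensor:
    """Hypothetical helper: draw q-batches uniformly from [0, 1]^d, keep only those
    in which every point satisfies constraint(x) >= 0, and return the
    num_restarts kept q-batches with the highest acqf value."""
    kept, n_kept = [], 0
    for _ in range(max_rounds):
        X = torch.rand(raw_samples, q, d, dtype=torch.double)
        feasible = (constraint(X) >= 0).all(dim=-1)  # shape: raw_samples
        kept.append(X[feasible])
        n_kept += int(feasible.sum())
        if n_kept >= raw_samples:
            break
    if n_kept < num_restarts:
        raise RuntimeError("too few feasible samples; rejection sampling is too inefficient here")
    X_feas = torch.cat(kept)[:raw_samples]
    # Simplification: plain top-k instead of BoTorch's own selection heuristic.
    with torch.no_grad():
        return X_feas[acqf(X_feas).topk(num_restarts).indices]


def inside_unit_ball(X: Tensor) -> Tensor:
    # Example intra-point constraint, evaluated for every point: >= 0 means feasible.
    return 1.0 - X.pow(2).sum(dim=-1)


def toy_acqf(X: Tensor) -> Tensor:
    # Stand-in acquisition function: "batch x q x d" -> "batch".
    return -(X - 0.5).pow(2).sum(dim=(-2, -1))


torch.manual_seed(0)
ics = rejection_sample_ics(toy_acqf, inside_unit_ball, num_restarts=4, raw_samples=64, q=2, d=3)
print(ics.shape)  # torch.Size([4, 2, 3])
```

The cost grows with the infeasible fraction: here roughly half of the cube lies inside the unit ball, so a few rounds suffice, but a tiny feasible region would exhaust `max_rounds`.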

@jduerholt
Contributor

Hi, I call this kind of constraint an NChooseK constraint, because out of N ingredients you are able to pick k. In a discussion in the Ax issues with @sgbaird, it was shown that you can express this via a continuous relaxation as a nonlinear inequality constraint, which is exactly the relaxation that is also used in the paper @eytan mentioned above. I'll try to dig up this discussion later.
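For reference, here is one way such a relaxation can be written. This is a sketch using the narrow-Gaussian form that appears in swierh's code later in this thread; it is not claimed to be identical to bofire's implementation or to the paper's. With N inputs, at most k of them allowed to be nonzero, and a small width ε:

$$
g(x) = \sum_{i=1}^{N} \exp\left(-\frac{x_i^2}{2\varepsilon^2}\right) - (N - k) \ge 0
$$

Each exponential is close to 1 when x_i ≈ 0 and close to 0 when |x_i| ≫ ε, so the sum approximately counts the zero entries, and g(x) ≥ 0 asks for at least N - k (near-)zeros, i.e. at most k nonzeros. For N = 10, k = 5 and ε = 0.01, a point with five zeros and five entries equal to 0.5 gives g ≈ 5 - 5 = 0 (feasible), while a point with six nonzeros gives g ≈ 4 - 5 = -1 (infeasible).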

Currently I am implementing this in our own botorch wrapper focused on materials discovery; hopefully it will be finished next week. In this method, the NChooseK constraints are transformed into botorch's definition of a nonlinear inequality:

https://github.com/experimental-design/bofire/blob/cc33f21f4d33c3a58e20e88c8694a14c12a2713e/bofire/utils/torch_tools.py#L88

You can take these definitions and pass them directly to optimize_acqf.

Then you just need to set up initial conditions that satisfy the constraints. This can be done with our PolytopeSampler, which is based on the polytope sampler of botorch and can also handle NChooseK constraints.

Here is an example in which the NChooseK constraint is combined with an additional linear equality (mixture) constraint:

from bofire.data_models.constraints.api import (
    LinearInequalityConstraint,
    LinearEqualityConstraint,
    NChooseKConstraint,
)
from bofire.data_models.domain.api import Domain
from bofire.data_models.features.api import CategoricalInput, ContinuousInput

import bofire.strategies.api as strategies
import bofire.data_models.strategies.api as data_models
from bofire.utils.torch_tools import get_nchoosek_constraints

features = [ContinuousInput(key=f"if{i+1}", lower_bound=0, upper_bound=1) for i in range(10)]

constraints = [
    # The ingredient fractions must sum to 1 (mixture constraint).
    LinearEqualityConstraint(features=[f"if{i+1}" for i in range(10)], coefficients=[1.0]*10, rhs=1),
    # At most 5 of the 10 ingredients may be nonzero.
    NChooseKConstraint(features=[f"if{i+1}" for i in range(10)], min_count=0, max_count=5, none_also_valid=True),
]

domain = Domain(
    input_features=features,
    constraints=constraints,
)

data_model = data_models.PolytopeSampler(domain=domain)
sampler = strategies.PolytopeSampler(data_model=data_model)

# this generates the samples; it is currently a bit slow as we generate the complete combinatorics, which could easily be sped up
samples = sampler.ask(5, return_all=False)

# this gives you callables for nonlinear inequality constraint to handle it in optimize_acqf
callables = get_nchoosek_constraints(domain=domain)

These are the generated samples; they satisfy both constraints:

[Image: the generated samples]

In the coming weeks, this should all be integrated automatically. Note that bofire is still in a kind of alpha stage ;)
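Without bofire, the special case from the example above (mixture equality plus an upper limit on nonzeros) can also be sampled directly. This is a simpler stand-in, not bofire's PolytopeSampler: it assumes no other linear constraints, and `sample_nchoosek_mixture` is a hypothetical helper. For each point it picks `max_count` active features at random and draws Dirichlet(1, ..., 1) weights on them, so every point sums to 1 and has at most `max_count` nonzero entries.

```python
import torch
from torch import Tensor


def sample_nchoosek_mixture(n: int, n_features: int, max_count: int, seed: int = 0) -> Tensor:
    """Hypothetical helper: draw n points with x >= 0, sum(x) == 1 and at most
    max_count nonzero entries per point."""
    gen = torch.Generator().manual_seed(seed)
    # Active set per point: the first max_count entries of a random permutation.
    active = torch.rand(n, n_features, generator=gen).argsort(dim=-1)[:, :max_count]
    # Normalized i.i.d. Exp(1) draws form a Dirichlet(1, ..., 1) sample,
    # i.e. a uniform point on the simplex spanned by the active features.
    w = torch.empty(n, max_count, dtype=torch.double).exponential_(generator=gen)
    w = w / w.sum(dim=-1, keepdim=True)
    # Scatter the weights into the active columns; all other features stay 0.
    X = torch.zeros(n, n_features, dtype=torch.double)
    X.scatter_(1, active, w)
    return X


X = sample_nchoosek_mixture(5, n_features=10, max_count=5)
print(X.sum(dim=-1))        # each row sums to 1 up to rounding
print((X > 0).sum(dim=-1))  # at most 5 nonzero entries per row
```

Unlike a polytope sampler, this cannot honor additional linear inequalities; it only covers the sum-to-one case.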

@swierh
Author

swierh commented Apr 13, 2023

Thanks for the feedback everyone!

Through a referenced issue, I came across this comment, which has a much better solution to our problem than what I originally came up with.

We based the following code on that solution, which works like a charm:

from botorch.acquisition import AcquisitionFunction
from botorch.optim import optimize_acqf
import torch
from torch import Tensor
from torch.quasirandom import SobolEngine


def narrow_gaussian(x: Tensor, epsilon: float) -> Tensor:
    # Close to 1 where x is (nearly) 0 and close to 0 where |x| >> epsilon.
    return torch.exp(-0.5 * (x / epsilon) ** 2)


def build_cardinality_constraint(cardinality: int, epsilon: float = 1e-2):
    """
    Builds a constraint function that checks (in a smooth, relaxed way) whether
    the vector has a cardinality no higher than specified.

    If the constraint is met, the result will be >= 0.
    """

    def cardinality_constraint(x: Tensor):
        """
        Checks whether the vector has a cardinality no higher than specified.

        The sum of narrow Gaussians approximately counts the entries that are
        zero, so the result is >= 0 when at least (dim - cardinality) entries
        are (nearly) zero.
        """
        return narrow_gaussian(x, epsilon).sum(dim=-1) - x.shape[-1] + cardinality

    return cardinality_constraint


def generate_cardinality_limited_points(
    n: int,
    x_dim: int,
    cardinality: int,
    q: int = 1,
):
    """
    Generate initial points that are of the specified cardinality.
    """
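    # Space-filling points in [0, 1]^x_dim, one row per point of every q-batch.
    X = SobolEngine(dimension=x_dim, scramble=True).draw(n * q).to(torch.double)
    # For each row, pick (x_dim - cardinality) random columns and set them to 0.
    x_idx = torch.arange(X.shape[0]).unsqueeze(-1)
    y_idx = torch.argsort(torch.rand(n * q, x_dim), dim=-1)[..., : x_dim - cardinality]
    X[x_idx, y_idx] = 0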
    X = X.reshape(n, q, x_dim)
    return X


def get_cardinality_limited_batch_initial_conditions(
    num_restarts: int,
    raw_samples: int,
    acqf: AcquisitionFunction,
    x_dim: int,
    cardinality: int,
    q: int = 1,
):
    """
    Get initial conditions that are of the specified cardinality.
    """
    X = generate_cardinality_limited_points(
        raw_samples, x_dim=x_dim, cardinality=cardinality, q=q
    )
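    # Keep the num_restarts raw samples with the highest acquisition value.
    with torch.no_grad():
        acq_values = acqf(X)
    return X[acq_values.topk(num_restarts).indices]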

cardinality = 5  # at most 5 of the 10 ingredients per sample
cardinality_constraint = build_cardinality_constraint(cardinality)
batch_initial_conditions = get_cardinality_limited_batch_initial_conditions(...)
candidate, _ = optimize_acqf(
    nonlinear_inequality_constraints=[cardinality_constraint],
    batch_initial_conditions=batch_initial_conditions,
    ...
)

We did run into the next problem, which is that nonlinear constraints have not been implemented in botorch for q>1 (or for knowledge gradient). These constraints are handled through the make_scipy_nonlinear_inequality_constraints function, which works differently from the linear one.

@Balandat
Contributor

Glad to hear you were able to make some progress on this.

We did run into the next problem, which is that nonlinear constraints have not been implemented in botorch for q>1 (or for knowledge gradient). These constraints are handled through the make_scipy_nonlinear_inequality_constraints function, which works differently from the linear one.

Doing this should hopefully not be too hard if the constraints are intra-point constraints that apply separately to each of the q candidates (rather than a constraint across the different candidates in a batch). Looks like that is your use case?

Out of 10 possible features, we can only select 5 at a time for any sample in the q-batch.

For KG it's going to be a bit trickier, since the locations of the fantasy observations also need to satisfy those constraints. That won't really be an overly challenging problem to hook up, but my concern is that you'll end up with a huge number of nonlinear constraints in the acquisition function optimization problem that way. By default that optimization will use SLSQP which is exceedingly slow with many constraints. Do you need to use KG here?

@swierh
Author

swierh commented Apr 13, 2023

Doing this should hopefully not be too hard if the constraints are intra-point constraints that apply separately to each of the q candidates (rather than a constraint across the different candidates in a batch). Looks like that is your use case?

Yes, the points in a batch are independent. So far I've given it one attempt, but got stuck on the shape of x0 causing a bunch of conflicts. I'll hopefully have some time soon to give it another go.

For KG it's going to be a bit trickier, since the locations of the fantasy observations also need to satisfy those constraints. That won't really be an overly challenging problem to hook up, but my concern is that you'll end up with a huge number of nonlinear constraints in the acquisition function optimization problem that way. By default that optimization will use SLSQP which is exceedingly slow with many constraints. Do you need to use KG here?

We can use other acquisition functions, though KG seems to perform well on our specific challenge. As for it being exceedingly slow, for our challenge it's not really a problem if the computation of a single batch takes a few hours, so that might be something we can test.
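The intra-point case discussed above can be sketched as follows, assuming a SciPy-style optimizer that sees the q candidates of one restart as a single flat vector of length q * d (one plausible source of x0 shape mismatches). A per-point constraint then turns into q scalar constraints, the j-th of which only reads entries j*d to (j+1)*d - 1. `expand_intra_point_constraint` and `narrow_gaussian_cardinality` are hypothetical names for illustration; this is not BoTorch's own code.

```python
from typing import Callable, List

import torch
from torch import Tensor


def expand_intra_point_constraint(
    constraint: Callable[[Tensor], Tensor], q: int, d: int
) -> List[Callable[[Tensor], Tensor]]:
    """Turn a per-point constraint c(x) >= 0 (x of shape d) into q constraints
    on the flattened vector of length q * d."""

    def make(j: int) -> Callable[[Tensor], Tensor]:
        # A factory binds j per constraint (avoids the late-binding closure bug).
        def c_j(x_flat: Tensor) -> Tensor:
            return constraint(x_flat.view(q, d)[j])

        return c_j

    return [make(j) for j in range(q)]


def narrow_gaussian_cardinality(x: Tensor, k: int = 2, eps: float = 1e-2) -> Tensor:
    # Relaxed "at most k nonzeros" constraint; >= 0 when satisfied.
    return torch.exp(-0.5 * (x / eps) ** 2).sum(dim=-1) - x.shape[-1] + k


q, d = 3, 4
constraints = expand_intra_point_constraint(narrow_gaussian_cardinality, q=q, d=d)
x_flat = torch.tensor(
    [0.5, 0.5, 0.0, 0.0,   # candidate 0: 2 nonzeros, feasible
     0.5, 0.5, 0.5, 0.0,   # candidate 1: 3 nonzeros, infeasible
     0.0, 0.0, 0.0, 0.7],  # candidate 2: 1 nonzero, feasible
    dtype=torch.double,
)
print([float(c(x_flat)) >= 0 for c in constraints])  # [True, False, True]
```

The same bookkeeping explains the KG concern: in the one-shot formulation, the fantasy points are extra rows of the optimization variable, so each of them adds another copy of the constraint.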

@jduerholt
Contributor

Yes, the points in a batch are independent. So far I've given it one attempt, but got stuck on the shape of x0 causing a bunch of conflicts. I'll hopefully have some time soon to give it another go.

@swierh: are you already working on making nonlinear constraints possible for q > 1? If not, I can give it a try, as we need it too. So far I have only tested with q=1. A quick fix could be to optimize with sequential=True.

@Balandat: concerning speed, for larger q-batches and higher-dimensional problems with a lot of constraints, one could also try how ipopt performs instead of SLSQP. It could be faster there, and the interface is the same as for the scipy optimizers.

@Balandat
Copy link
Contributor

one could also try how ipopt performs instead of SLSQP

Yeah, that's also possible. I used IPOPT quite a bit in grad school; at the time it wasn't exactly easy to build and run on different platforms with a Python interface, but I assume this has improved since? If you give this a try, I'd love to hear how it goes.

@jduerholt
Contributor

We are using it to generate D-optimal designs, with cyipopt as the Python wrapper. Installation is still not so easy, especially with pip. As soon as I have time, I will give it a try for acqf optimization ;)

@swierh
Author

swierh commented Apr 14, 2023

@jduerholt I'm not currently working on it, and won't have much time for at least the coming week, so feel free to give it a try.

@jduerholt
Contributor

@swierh: the PR addressing nonlinear constraints with q > 1 (#1793) was merged last week, and the feature is now available on the main branch.

@swierh
Author

swierh commented Nov 23, 2023

@jduerholt Awesome! Thanks so much for the effort you and @Balandat put into it!

Resolved by #1793

@swierh swierh closed this as completed Nov 23, 2023