Skip to content

Commit

Permalink
Merge branch 'main' into gh-pages
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasBoTang committed Jul 14, 2023
2 parents a5c1670 + 3788197 commit 74cc028
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pkg/pyepo/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
Pytorch autograd function for end-to-end training
"""

from pyepo.func.blackbox import blackboxOpt
from pyepo.func.spoplus import SPOPlus
from pyepo.func.blackbox import blackboxOpt, negativeIdentity
from pyepo.func.perturbed import perturbedOpt, perturbedFenchelYoung
from pyepo.func.contrastive import NCE, contrastiveMAP
from pyepo.func.rank import listwiseLTR, pairwiseLTR, pointwiseLTR
91 changes: 90 additions & 1 deletion pkg/pyepo/func/blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def backward(ctx, grad_output):
rand_sigma = ctx.rand_sigma
if solve_ratio < 1:
module = ctx.module
ins_num = len(pred_cost)
# get device
device = pred_cost.device
# convert tenstor
Expand All @@ -148,3 +147,93 @@ def backward(ctx, grad_output):
grad = np.array(grad)
grad = torch.FloatTensor(grad).to(device)
return grad, None, None, None, None, None, None


class negativeIdentity(optModule):
"""
An autograd module for differentiable optimizer, which yield optimal a
solution and use negative identity as gradient on the backward pass.
For negative identity backpropagation, the objective function is linear and
constraints are known and fixed, but the cost vector need to be predicted
from contextual data.
If the interpolation hyperparameter λ aligns with an appropriate step size,
then the identity update is tantamount to DBB. However, the identity update
does not require an additional call to the solver during the backward pass
and tuning an additional hyperparameter λ.
Reference: <https://arxiv.org/abs/2205.15213>
"""

def __init__(self, optmodel, processes=1, solve_ratio=1, dataset=None):
"""
Args:
optmodel (optModel): an PyEPO optimization model
processes (int): number of processors, 1 for single-core, 0 for all of cores
solve_ratio (float): the ratio of new solutions computed during training
dataset (None/optDataset): the training data
"""
super().__init__(optmodel, processes, solve_ratio, dataset)
# build blackbox optimizer
self.nid = negativeIdentityFunc()

def forward(self, pred_cost):
"""
Forward pass
"""
sols = self.nid.apply(pred_cost, self.optmodel, self.processes,
self.pool, self.solve_ratio, self)
return sols


class negativeIdentityFunc(Function):
"""
A autograd function for differentiable black-box optimizer
"""

@staticmethod
def forward(ctx, pred_cost, optmodel, processes, pool, solve_ratio, module):
"""
Forward pass for NID
Args:
pred_cost (torch.tensor): a batch of predicted values of the cost
optmodel (optModel): an PyEPO optimization model
processes (int): number of processors, 1 for single-core, 0 for all of cores
pool (ProcessPool): process pool object
solve_ratio (float): the ratio of new solutions computed during training
module (optModule): blackboxOpt module
Returns:
torch.tensor: predicted solutions
"""
# get device
device = pred_cost.device
# convert tenstor
cp = pred_cost.detach().to("cpu").numpy()
# solve
rand_sigma = np.random.uniform()
if rand_sigma <= solve_ratio:
sol, _ = _solve_in_pass(cp, optmodel, processes, pool)
if solve_ratio < 1:
# add into solpool
module.solpool = np.concatenate((module.solpool, sol))
# remove duplicate
module.solpool = np.unique(module.solpool, axis=0)
else:
sol, _ = _cache_in_pass(cp, optmodel, module.solpool)
# convert to tensor
pred_sol = torch.FloatTensor(np.array(sol)).to(device)
return pred_sol

@staticmethod
def backward(ctx, grad_output):
"""
Backward pass for NID
"""
# get device
device = grad_output.device
# identity matrix
I = torch.eye(grad_output.shape[1]).to(device)
return grad_output @ (-I), None, None, None, None, None
2 changes: 1 addition & 1 deletion pkg/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# description
description = "PyTorch-based End-to-End Predict-then-Optimize Tool",
# version
version = "0.3.2",
version = "0.3.3",
# Github repo
url = "https://github.com/khalil-research/PyEPO",
# author name
Expand Down

0 comments on commit 74cc028

Please sign in to comment.