This repository provides a collection of PyTorch modules for combinatorial solvers based on Differentiation of Blackbox Combinatorial Solvers.
By Marin Vlastelica*, Anselm Paulus*, Vít Musil, Georg Martius and Michal Rolínek.
Autonomous Learning Group, Max Planck Institute for Intelligent Systems.
This repository contains PyTorch modules that wrap blackbox combinatorial solvers via the method proposed in Differentiation of Blackbox Combinatorial Solvers. Besides the solvers employed in the original paper, it includes wrapped solvers for ranking (as used in Blackbox Optimization of Rank-Based Metrics) and for graph matching/multigraph matching (as used in Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers).
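For orientation, here is the gradient rule as we understand it from the paper, written in our own notation. The solver returns a minimizer of a linear cost $\hat w \cdot y$ over its feasible set $Y$. The forward pass is left unchanged, and the backward pass calls the solver once more on weights shifted along the incoming loss gradient $\mathrm{d}L/\mathrm{d}y$. The hyperparameter $\lambda > 0$ (presumably what the `lambda_val` argument sets in the modules) controls how far the weights are shifted:

$$
\begin{aligned}
\hat y &= \operatorname{Solver}(\hat w) = \arg\min_{y \in Y} \hat w \cdot y, \\
y_\lambda &= \operatorname{Solver}\left(\hat w + \lambda \, \frac{\mathrm{d}L}{\mathrm{d}y}(\hat y)\right), \\
\nabla_w f_\lambda(\hat w) &= -\frac{1}{\lambda}\left(\hat y - y_\lambda\right).
\end{aligned}
$$

The returned quantity is the gradient of a piecewise-affine interpolation $f_\lambda$ of the loss, not of the solver itself, which is piecewise constant.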
Disclaimer: This code is a PROTOTYPE. It should work fine, but use it at your own risk.
For the exact usage of the combinatorial modules, see our public implementations of:
- Differentiation of Blackbox Combinatorial Solvers
- Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers
Simply install with pip:

```shell
python3 -m pip install git+https://github.com/martius-lab/blackbox-backprop
```
Running the TSP module additionally requires a manual GurobiPy installation and a Gurobi license.
Currently, the following solver modules are available (the list will grow over time):
Combinatorial Problem | Solver | Paper |
---|---|---|
Travelling Salesman | Cutting plane algorithm implemented in Gurobi | Differentiation of Blackbox Combinatorial Solvers |
Shortest Path (on a grid) | Dijkstra algorithm (vertex version) | Differentiation of Blackbox Combinatorial Solvers |
Min-cost Perfect matching on general graphs | Blossom V (Kolmogorov, 2009) | Differentiation of Blackbox Combinatorial Solvers |
Ranking (+ induced Recall & mAP loss functions) | torch.argsort | Blackbox Optimization of Rank-Based Metrics |
Graph Matching | Swoboda, 2017 | Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers |
Multigraph Matching | Swoboda, 2019 | Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers |
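The table describes the shortest-path solver as the vertex version of Dijkstra's algorithm on a grid. To make that concrete, here is a minimal pure-Python sketch under our own assumptions: each grid cell carries a non-negative cost, a path's cost is the sum of the costs of the cells it visits (both endpoints included), paths run from the top-left to the bottom-right cell over the 8-neighbourhood, and the output is a 0/1 indicator grid of one optimal path. The function name `grid_shortest_path` is hypothetical; the library's own module works on tensors and may differ in these details.

```python
import heapq
from typing import Dict, List, Tuple

Cell = Tuple[int, int]


def grid_shortest_path(costs: List[List[float]]) -> List[List[int]]:
    """Hypothetical sketch: vertex-weighted Dijkstra on a grid (8-neighbourhood).

    Returns a 0/1 indicator grid of a minimum-cost path from the top-left to
    the bottom-right cell, where a path pays the cost of every cell it visits.
    Costs must be non-negative for Dijkstra to be correct.
    """
    rows, cols = len(costs), len(costs[0])
    start, goal = (0, 0), (rows - 1, cols - 1)
    # dist[v] = cheapest known cost of reaching v, including v's own cost
    dist: Dict[Cell, float] = {start: costs[0][0]}
    parent: Dict[Cell, Cell] = {}
    heap = [(costs[0][0], start)]
    done = set()
    while heap:
        d, (r, c) = heapq.heappop(heap)
        if (r, c) in done:
            continue  # stale heap entry
        done.add((r, c))
        if (r, c) == goal:
            break
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r + dr, c + dc
                if (dr, dc) == (0, 0) or not (0 <= nr < rows and 0 <= nc < cols):
                    continue
                nd = d + costs[nr][nc]  # entering a cell pays that cell's cost
                if nd < dist.get((nr, nc), float("inf")):
                    dist[(nr, nc)] = nd
                    parent[(nr, nc)] = (r, c)
                    heapq.heappush(heap, (nd, (nr, nc)))
    # Walk back from the goal along parent pointers to mark the path cells
    path = [[0] * cols for _ in range(rows)]
    node = goal
    path[node[0]][node[1]] = 1
    while node != start:
        node = parent[node]
        path[node[0]][node[1]] = 1
    return path


# Usage: the expensive middle column is avoided
print(grid_shortest_path([[1, 9, 1], [1, 9, 1], [1, 1, 1]]))
# [[1, 0, 0], [1, 0, 0], [0, 1, 1]]
```

Because the cost of a path is the sum of the cell costs over the cells marked 1 in the returned grid, this kind of solver has the linear-cost form $w \cdot y$ that the blackbox differentiation method assumes.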
The graph matching and multigraph matching solvers, together with their differentiable PyTorch modules, are hosted in the LPMP repository.
Usage is exactly what you would expect of a PyTorch module (with minor details differing from solver to solver):
```python
import blackbox_backprop as bb
...
suggested_weights = ResNet18(raw_inputs)  # any network that outputs the solver's input weights
suggested_shortest_paths = bb.ShortestPath(suggested_weights, lambda_val=5.0)  # Set the lambda hyperparameter
loss = HammingLoss(suggested_shortest_paths, true_shortest_paths)  # Use e.g. Hamming distance as the loss function
loss.backward()  # The backward pass is handled automatically
...
```
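The usage example leaves the backward pass to the library. To illustrate what such a backward pass computes, here is a framework-free sketch of the gradient rule from Differentiation of Blackbox Combinatorial Solvers: call the solver once more on weights perturbed by `lambda_val` times the incoming loss gradient, then return the negative scaled difference between the original and the perturbed solutions. Everything here is our own stand-in, not the library's code: `select_k_cheapest` is a hypothetical toy solver (choose the k smallest weights), and `hamming_loss`, `hamming_loss_grad` and `blackbox_backward` are hypothetical helpers written with plain Python lists instead of tensors.

```python
# Hypothetical, framework-free sketch; not the blackbox_backprop implementation.
from typing import Callable, List, Sequence

Solver = Callable[[Sequence[float]], List[float]]


def select_k_cheapest(weights: Sequence[float], k: int = 2) -> List[float]:
    """Toy solver (hypothetical): minimise sum(w_i * y_i) over 0/1 vectors y
    with exactly k ones, i.e. pick the k smallest weights (ties by index)."""
    order = sorted(range(len(weights)), key=lambda i: weights[i])
    chosen = set(order[:k])
    return [1.0 if i in chosen else 0.0 for i in range(len(weights))]


def hamming_loss(y: Sequence[float], y_true: Sequence[float]) -> float:
    """Fraction of positions where two 0/1 vectors disagree."""
    return sum(a * (1 - b) + (1 - a) * b for a, b in zip(y, y_true)) / len(y)


def hamming_loss_grad(y_true: Sequence[float]) -> List[float]:
    """dL/dy of hamming_loss, treating y as continuous: (1 - 2 * y_true_i) / n."""
    n = len(y_true)
    return [(1 - 2 * b) / n for b in y_true]


def blackbox_backward(solver: Solver, weights: Sequence[float], y: Sequence[float],
                      grad_y: Sequence[float], lambda_val: float) -> List[float]:
    """Gradient w.r.t. the solver input:
    -(y - solver(weights + lambda_val * grad_y)) / lambda_val."""
    perturbed = [w + lambda_val * g for w, g in zip(weights, grad_y)]
    y_lambda = solver(perturbed)  # one extra solver call on the perturbed weights
    return [-(a - b) / lambda_val for a, b in zip(y, y_lambda)]


# Usage: forward pass, loss, then the blackbox backward rule
weights = [1.0, 2.0, 3.0, 4.0]
y_true = [0.0, 0.0, 1.0, 1.0]
y = select_k_cheapest(weights)  # [1.0, 1.0, 0.0, 0.0]
loss = hamming_loss(y, y_true)  # 1.0
grad_w = blackbox_backward(select_k_cheapest, weights, y,
                           hamming_loss_grad(y_true), lambda_val=10.0)
print(grad_w)  # [-0.1, -0.1, 0.1, 0.1]
```

Two design points show up even in this toy. First, the backward pass costs exactly one extra solver call. Second, `lambda_val` matters: with `lambda_val=1.0` here, the perturbed weights still select the same items, so the returned gradient is all zeros. Larger values let the loss gradient actually change the solver's output, at the price (as we understand the paper) of a coarser approximation of the original piecewise-constant function.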
Visualizations that appear in the papers can be generated with the attached Jupyter notebook. This requires the Python packages ipyvolume and ipywidgets. Also, make sure to enable all Jupyter nbextensions as listed here.
Example visualizations: Ranking, Shortest path, Graph Matching.
Contribute: If you spot a bug or some incompatibility, raise an issue or contribute via a pull request! Thank you!