Added option for more efficient graph matching matrix operations #1046

Merged: 10 commits merged on Aug 7, 2023
68 changes: 52 additions & 16 deletions graspologic/match/solver.py
@@ -10,7 +10,6 @@
from beartype import beartype
from ot import sinkhorn
from scipy.optimize import linear_sum_assignment
from scipy.sparse import csr_array
from sklearn.utils import check_scalar

from graspologic.types import List, RngType, Tuple
@@ -82,6 +81,7 @@ def __init__(
transport_regularizer: Scalar = 100,
transport_tol: Scalar = 5e-2,
transport_max_iter: Int = 1000,
fast: bool = True,
):
# TODO check if init is doubly stochastic
self.init = init
@@ -113,6 +113,8 @@ def __init__(
)
self.transport_max_iter = transport_max_iter

self.fast = fast

if maximize:
self.obj_func_scalar = -1
else:
@@ -406,6 +408,7 @@ def compute_step_size(self, P: np.ndarray, Q: np.ndarray) -> float:
self.BA_ns,
self.BA_sn,
self.S_nn,
fast=self.fast,
)
if a * self.obj_func_scalar > 0 and 0 <= -b / (2 * a) <= 1:
alpha = -b / (2 * a)
@@ -538,6 +541,14 @@ def _compute_gradient(
return grad


def _fast_trace(X: np.ndarray, Y: np.ndarray) -> float:
return (X * Y.T).sum()


def _fast_traceT(X: np.ndarray, Y: np.ndarray) -> float:
return (X * Y).sum()
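
These helpers rely on the identity trace(X @ Y) = (X * Y.T).sum(): the trace only needs the diagonal of the product, so it can be computed with an elementwise multiply and sum in O(n^2) instead of a full O(n^3) matrix multiplication. A minimal sketch (not part of the diff) verifying both identities on random matrices:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 50))
Y = rng.normal(size=(50, 50))

# trace(X @ Y) only reads the diagonal of X @ Y, so an elementwise
# product and sum gives the same value without forming the product.
assert np.isclose(np.trace(X @ Y), (X * Y.T).sum())    # _fast_trace
assert np.isclose(np.trace(X @ Y.T), (X * Y).sum())    # _fast_traceT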


def _compute_coefficients(
P: np.ndarray,
Q: np.ndarray,
@@ -554,25 +565,50 @@ def _compute_coefficients(
BA_ns: MultilayerAdjacency,
BA_sn: MultilayerAdjacency,
S: AdjacencyMatrix,
fast: bool,
) -> Tuple[float, float]:
R = P - Q
# TODO make these "smart" traces like in the scipy code, couldn't hurt
# TODO can also refactor to not repeat multiplications like the old code but I was
# finding it harder to follow that way.

n_layers = len(A)
a_cross = 0
b_cross = 0
a_intra = 0
b_intra = 0
a_cross = 0.0
b_cross = 0.0
a_intra = 0.0
b_intra = 0.0

for i in range(n_layers):
a_cross += np.trace(AB[i].T @ R @ BA[i] @ R)
b_cross += np.trace(AB[i].T @ R @ BA[i] @ Q) + np.trace(AB[i].T @ Q @ BA[i] @ R)
b_cross += np.trace(AB_ns[i].T @ R @ BA_ns[i]) + np.trace(
AB_sn[i].T @ BA_sn[i] @ R
)
a_intra += np.trace(A[i] @ R @ B[i].T @ R.T)
b_intra += np.trace(A[i] @ Q @ B[i].T @ R.T) + np.trace(A[i] @ R @ B[i].T @ Q.T)
b_intra += np.trace(A_ns[i].T @ R @ B_ns[i]) + np.trace(A_sn[i] @ R @ B_sn[i].T)
if fast:
# could maybe be even faster if we do `opt_einsum` or something
ABiTR = AB[i].T @ R
BAiR = BA[i] @ R
AiR = A[i] @ R
RBi = R @ B[i]

a_cross += _fast_trace(ABiTR, BAiR)
b_cross += _fast_trace(ABiTR, BA[i] @ Q)
b_cross += _fast_trace(AB[i].T @ Q, BAiR)
b_cross += _fast_trace(AB_ns[i].T @ R, BA_ns[i])
b_cross += _fast_trace(AB_sn[i].T @ BA_sn[i], R)

a_intra += _fast_traceT(AiR, RBi)
b_intra += _fast_traceT(A[i] @ Q, RBi)
b_intra += _fast_traceT(AiR, Q @ B[i])
b_intra += _fast_trace(A_ns[i].T @ R, B_ns[i])
b_intra += _fast_traceT(A_sn[i] @ R, B_sn[i])
else:
a_cross += np.trace(AB[i].T @ R @ BA[i] @ R)
b_cross += np.trace(AB[i].T @ R @ BA[i] @ Q) + np.trace(
AB[i].T @ Q @ BA[i] @ R
)
b_cross += np.trace(AB_ns[i].T @ R @ BA_ns[i]) + np.trace(
AB_sn[i].T @ BA_sn[i] @ R
)
a_intra += np.trace(A[i] @ R @ B[i].T @ R.T)
b_intra += np.trace(A[i] @ Q @ B[i].T @ R.T) + np.trace(
A[i] @ R @ B[i].T @ Q.T
)
b_intra += np.trace(A_ns[i].T @ R @ B_ns[i]) + np.trace(
A_sn[i] @ R @ B_sn[i].T
)

a = a_cross + a_intra
b = b_cross + b_intra
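The fast branch in _compute_coefficients precomputes AB[i].T @ R, BA[i] @ R, A[i] @ R, and R @ B[i], then replaces each np.trace of a four-matrix chain with one of the elementwise-trace helpers. A rough single-layer consistency sketch (the names A, B, AB, BA, R below are stand-ins for one layer of the inputs, not part of the diff) comparing the two paths on the quadratic "a" terms:

import numpy as np

def _fast_trace(X, Y):
    return (X * Y.T).sum()

def _fast_traceT(X, Y):
    return (X * Y).sum()

rng = np.random.default_rng(1)
n = 30
A, B, AB, BA, R = (rng.normal(size=(n, n)) for _ in range(5))

# Slow path: evaluate the full chained products and take their traces.
a_cross_slow = np.trace(AB.T @ R @ BA @ R)
a_intra_slow = np.trace(A @ R @ B.T @ R.T)

# Fast path: factor each trace through two precomputed n x n products.
a_cross_fast = _fast_trace(AB.T @ R, BA @ R)
a_intra_fast = _fast_traceT(A @ R, R @ B)

assert np.isclose(a_cross_slow, a_cross_fast)
assert np.isclose(a_intra_slow, a_intra_fast)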
9 changes: 8 additions & 1 deletion graspologic/match/wrappers.py
@@ -69,6 +69,7 @@ def graph_match(
transport_regularizer: Scalar = 100,
transport_tol: Scalar = 5e-2,
transport_max_iter: Int = 1000,
fast: bool = True,
) -> MatchResult:
"""
Attempts to solve the Graph Matching Problem or the Quadratic Assignment Problem
@@ -192,6 +193,12 @@ def graph_match(
Setting this value higher may provide more precise solutions at the cost of
longer computation time.

fast: bool, default=True
Whether to use numerical shortcuts to speed up the computation. This is
typically faster, but it requires storing intermediate computations in memory,
which may be undesirable for very large inputs or when memory is a
bottleneck.

Returns
-------
res: MatchResult
@@ -281,7 +288,6 @@ def graph_match(
partial_match=partial_match,
init=init,
init_perturbation=init_perturbation,
verbose=solver_verbose,
shuffle_input=shuffle_input,
padding=padding,
maximize=maximize,
@@ -291,6 +297,7 @@
transport_regularizer=transport_regularizer,
transport_tol=transport_tol,
transport_max_iter=transport_max_iter,
fast=fast,
)

def run_single_graph_matching(seed: RngType) -> MatchResult:
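For context, a minimal usage sketch of the new flag through the public wrapper. The inputs here are hypothetical random graphs; only the fast keyword comes from this diff, and the rest of the call relies on the existing graph_match defaults:

import numpy as np
from graspologic.match import graph_match

rng = np.random.default_rng(2)
n = 100
A = rng.binomial(1, 0.2, size=(n, n))
B = A.copy()  # match a graph against a copy of itself as a sanity check

# Default behavior after this PR: trade some memory for faster trace computations.
res_fast = graph_match(A, B, fast=True)

# Opt out when memory is the bottleneck; the objective being optimized is identical.
res_slow = graph_match(A, B, fast=False)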