Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tensorflow_backend.py and linear_solvers.py #17

Merged
merged 8 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions pygmtools/classic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,53 @@ def sm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum()
jt.Var([1.], dtype=float32)


.. dropdown:: Tensorflow Example

::

>>> import tensorflow as tf
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> _ = tf.random.set_seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = tf.Variable(tf.zeros([batch_size, 4, 4]))
>>> indices = tf.stack([tf.range(4),tf.random.shuffle(tf.range(4))], axis=1)
>>> updates = tf.ones([4])
>>> for i in range(batch_size):
... _ = X_gt[i].assign(tf.tensor_scatter_nd_update(X_gt[i], indices, updates))
>>> A1 = tf.random.uniform([batch_size, 4, 4])
>>> A2 = tf.matmul(tf.matmul(tf.transpose(X_gt, perm=[0, 2, 1]), A1), X_gt)
>>> n1 = n2 = tf.constant([4] * batch_size)

# Build affinity matrix
>>> conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
>>> conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
>>> import functools
>>> gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.) # set affinity function
>>> K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

# Solve by SM. Note that X is normalized with a squared sum of 1
>>> X = pygm.sm(K, n1, n2)
>>> tf.reduce_sum((X ** 2), axis=[1, 2])
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([1. , 1.0000001 , 1. , 0.9999999 , 1. ,
1. , 1.0000001 , 0.99999994, 1. , 0.9999998 ],
dtype=float32)>

# Accuracy
>>> tf.reduce_sum((pygm.hungarian(X) * X_gt))/ tf.reduce_sum(X_gt)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

This solver supports gradient back-propogation
>>> K = tf.Variable(K)
>>> with tf.GradientTape() as tape:
... y = tf.reduce_sum(pygm.sm(K, n1, n2))
... len(tf.where(tape.gradient(y, K)))
2560

.. note::
If you find this graph matching solver useful for your research, please cite:

Expand Down Expand Up @@ -456,6 +503,52 @@ def rrwm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum()
jt.Var([1.], dtype=float32)

.. dropdown:: Tensorflow Example

::

>>> import tensorflow as tf
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> _ = tf.random.set_seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = tf.Variable(tf.zeros([batch_size, 4, 4]))
>>> indices = tf.stack([tf.range(4),tf.random.shuffle(tf.range(4))], axis=1)
>>> updates = tf.ones([4])
>>> for i in range(batch_size):
... _ = X_gt[i].assign(tf.tensor_scatter_nd_update(X_gt[i], indices, updates))
>>> A1 = tf.random.uniform([batch_size, 4, 4])
>>> A2 = tf.matmul(tf.matmul(tf.transpose(X_gt, perm=[0, 2, 1]), A1), X_gt)
>>> n1 = n2 = tf.constant([4] * batch_size)

# Build affinity matrix
>>> conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
>>> conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
>>> import functools
>>> gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.) # set affinity function
>>> K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

# Solve by RRWM. Note that X is normalized with a sum of 1
>>> X = pygm.rrwm(K, n1, n2, beta=100)
>>> tf.reduce_sum(X, axis=[1, 2])
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([1. , 1. , 1. , 0.99999994, 1. ,
1. , 1. , 0.99999994, 0.99999994, 1.0000001 ],
dtype=float32)>

# Accuracy
>>> tf.reduce_sum((pygm.hungarian(X) * X_gt)) / tf.reduce_sum(X_gt)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

# This solver supports gradient back-propogation
>>> K = tf.Variable(K)
>>> with tf.GradientTape(persistent=True) as tape:
... y = tf.reduce_sum(pygm.rrwm(K, n1, n2, beta=100))
... len(tf.where(tape.gradient(y, K)))
768

.. note::
If you find this graph matching solver useful in your research, please cite:

Expand Down Expand Up @@ -687,6 +780,46 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum()
jt.Var([1.], dtype=float32)

.. dropdown:: Tensorflow Example

::

>>> import tensorflow as tf
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> _ = tf.random.set_seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = tf.Variable(tf.zeros([batch_size, 4, 4]))
>>> indices = tf.stack([tf.range(4),tf.random.shuffle(tf.range(4))], axis=1)
>>> updates = tf.ones([4])
>>> for i in range(batch_size):
... _ = X_gt[i].assign(tf.tensor_scatter_nd_update(X_gt[i], indices, updates))
>>> A1 = tf.random.uniform([batch_size, 4, 4])
>>> A2 = tf.matmul(tf.matmul(tf.transpose(X_gt, perm=[0, 2, 1]), A1), X_gt)
>>> n1 = n2 = tf.constant([4] * batch_size)

# Build affinity matrix
>>> conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
>>> conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
>>> import functools
>>> gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.) # set affinity function
>>> K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

# Solve by IPFP
>>> X = pygm.ipfp(K, n1, n2)
>>> X[0]
<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[0., 0., 1., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]], dtype=float32)>

# Accuracy
>>> tf.reduce_sum((pygm.hungarian(X) * X_gt)) / tf.reduce_sum(X_gt)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

.. note::
If you find this graph matching solver useful in your research, please cite:

Expand Down
175 changes: 175 additions & 0 deletions pygmtools/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,106 @@ def sinkhorn(s, n1=None, n2=None, unmatch1=None, unmatch2=None,
>>> print('row_sum:', x.sum(1), 'col_sum:', x.sum(0))
row_sum: jt.Var([0.8776659 0.52249795 0.45705223 0.8203896 0.70309657], dtype=float32) col_sum: jt.Var([0.7831943 0.30627945 0.73395807 0.61827636 0.938994 ], dtype=float32)

.. dropdown:: Tensorflow Example

::

>>> import tensorflow as tf
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> np.random.seed(0)

# 2-dimensional (non-batched) input
>>> s_2d = tf.constant(np.random.rand(5, 5))
>>> s_2d
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ],
[0.64589411, 0.43758721, 0.891773 , 0.96366276, 0.38344152],
[0.79172504, 0.52889492, 0.56804456, 0.92559664, 0.07103606],
[0.0871293 , 0.0202184 , 0.83261985, 0.77815675, 0.87001215],
[0.97861834, 0.79915856, 0.46147936, 0.78052918, 0.11827443]])>
>>> x = pygm.sinkhorn(s_2d)
>>> x
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.18880224, 0.24990915, 0.19202217, 0.16034278, 0.20892366],
[0.18945066, 0.17240445, 0.23345011, 0.22194762, 0.18274716],
[0.23713583, 0.204348 , 0.18271243, 0.23114583, 0.1446579 ],
[0.11731039, 0.1229692 , 0.23823909, 0.19961588, 0.32186549],
[0.26730088, 0.2503692 , 0.15357619, 0.18694789, 0.1418058 ]])>
>>> print('row_sum:', tf.reduce_sum(x,axis=1), 'col_sum:', tf.reduce_sum(x, axis=0))
row_sum: tf.Tensor([1. 1.00000001 0.99999998 1.00000005 0.99999997], shape=(5,), dtype=float64) col_sum: tf.Tensor([1. 1. 1. 1. 1.], shape=(5,), dtype=float64)

# 3-dimensional (batched) input
>>> s_3d = tf.constant(np.random.rand(3, 5, 5))
>>> x = pygm.sinkhorn(s_3d)
>>> print('row_sum:', tf.reduce_sum(x, axis=2))
row_sum: tf.Tensor(
[[1. 1. 1. 1. 1. ]
[0.99999998 1.00000002 0.99999999 1.00000003 0.99999999]
[1. 1. 1. 1. 1. ]], shape=(3, 5), dtype=float64)
>>> print('col_sum:', tf.reduce_sum(x, axis=1))
col_sum: tf.Tensor(
[[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]], shape=(3, 5), dtype=float64)

# If the 3-d tensor are with different number of nodes
>>> n1 = tf.constant([3, 4, 5])
>>> n2 = tf.constant([3, 4, 5])
>>> x = pygm.sinkhorn(s_3d, n1, n2)
>>> x[0] # non-zero size: 3x3
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.36665934, 0.21498158, 0.41835906, 0. , 0. ],
[0.27603621, 0.44270207, 0.28126175, 0. , 0. ],
[0.35730445, 0.34231636, 0.3003792 , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ]])>
>>> x[1] # non-zero size: 4x4
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.28847831, 0.20583051, 0.34242091, 0.16327021, 0. ],
[0.22656752, 0.30153021, 0.19407969, 0.27782262, 0. ],
[0.25346378, 0.19649853, 0.32565049, 0.22438715, 0. ],
[0.23149039, 0.29614075, 0.13784891, 0.33452002, 0. ],
[0. , 0. , 0. , 0. , 0. ]])>
>>> x[2] # non-zero size: 5x5
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.20147352, 0.19541986, 0.24942798, 0.17346397, 0.18021467],
[0.21050732, 0.17620948, 0.18645469, 0.20384684, 0.22298167],
[0.18319623, 0.18024007, 0.17619871, 0.1664133 , 0.29395169],
[0.20754376, 0.2236443 , 0.19658101, 0.20570847, 0.16652246],
[0.19727917, 0.22448629, 0.19133762, 0.25056742, 0.13632951]])>

# non-squared input
>>> s_non_square = tf.constant(np.random.rand(4, 5))
>>> x = pygm.sinkhorn(s_non_square, dummy_row=True) # set dummy_row=True for non-squared cases
>>> print('row_sum:', tf.reduce_sum(x,axis=1), 'col_sum:', tf.reduce_sum(x,axis=0))
row_sum: tf.Tensor([1. 1. 1. 1.], shape=(4,), dtype=float64) col_sum: tf.Tensor([0.78239609 0.80485526 0.80165627 0.80004254 0.81104984], shape=(5,), dtype=float64)

# allow matching to void nodes by setting unmatch1 and unmatch2
>>> s_2d = tf.constant(np.random.randn(5, 5))
>>> s_2d
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[ 0.01050002, 1.78587049, 0.12691209, 0.40198936, 1.8831507 ],
[-1.34775906, -1.270485 , 0.96939671, -1.17312341, 1.94362119],
[-0.41361898, -0.74745481, 1.92294203, 1.48051479, 1.86755896],
[ 0.90604466, -0.86122569, 1.91006495, -0.26800337, 0.8024564 ],
[ 0.94725197, -0.15501009, 0.61407937, 0.92220667, 0.37642553]])>
>>> unmatch1 = tf.constant(np.random.randn(5))
>>> unmatch1
<tf.Tensor: shape=(5,), dtype=float64, numpy=array([-1.09940079, 0.29823817, 1.3263859 , -0.69456786, -0.14963454])>
>>> unmatch2 = tf.constant(np.random.randn(5))
>>> unmatch2
<tf.Tensor: shape=(5,), dtype=float64, numpy=array([-0.43515355, 1.84926373, 0.67229476, 0.40746184, -0.76991607])>
>>> x = pygm.sinkhorn(s_2d, unmatch1=unmatch1, unmatch2=unmatch2, max_iter=40)
>>> x
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.12434101, 0.23913991, 0.05663597, 0.13943479, 0.31811425],
[0.03084473, 0.01085787, 0.12689067, 0.02784578, 0.3260589 ],
[0.03192548, 0.00745004, 0.13391025, 0.16087345, 0.12289304],
[0.29820536, 0.01659601, 0.32997174, 0.06988242, 0.10573396],
[0.29787774, 0.0322356 , 0.08654936, 0.22023996, 0.06619393]])>
>>> print('row_sum:', tf.reduce_sum(x, axis=1), 'col_sum:', tf.reduce_sum(x, axis=0))
row_sum: tf.Tensor([0.87766593 0.52249794 0.45705226 0.82038949 0.70309659], shape=(5,), dtype=float64) col_sum: tf.Tensor([0.78319431 0.30627943 0.733958 0.61827641 0.93899407], shape=(5,), dtype=float64)

.. note::

Expand Down Expand Up @@ -914,6 +1014,81 @@ def hungarian(s, n1=None, n2=None, unmatch1=None, unmatch2=None,
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]], dtype=float32)

.. dropdown:: Tensorflow Example

::

>>> import tensorflow as tf
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> np.random.seed(0)

# 2-dimensional (non-batched) input
>>> s_2d = tf.constant(np.random.rand(5, 5))
>>> s_2d
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ],
[0.64589411, 0.43758721, 0.891773 , 0.96366276, 0.38344152],
[0.79172504, 0.52889492, 0.56804456, 0.92559664, 0.07103606],
[0.0871293 , 0.0202184 , 0.83261985, 0.77815675, 0.87001215],
[0.97861834, 0.79915856, 0.46147936, 0.78052918, 0.11827443]])>
>>> x = pygm.hungarian(s_2d)
>>> x
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.]])>

# 3-dimensional (batched) input
>>> s_3d = tf.constant(np.random.rand(3, 5, 5))
>>> n1 = n2 = tf.constant([3, 4, 5])
>>> x = pygm.hungarian(s_3d, n1, n2)
>>> x
<tf.Tensor: shape=(3, 5, 5), dtype=float64, numpy=
array([[[0., 0., 1., 0., 0.],
[0., 1., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0.]],
<BLANKLINE>
[[0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0.]]])>

# allow matching to void nodes by setting unmatch1 and unmatch2
>>> s_2d = tf.constant(np.random.randn(5, 5))
>>> s_2d
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[-1.16514984, 0.90082649, 0.46566244, -1.53624369, 1.48825219],
[ 1.89588918, 1.17877957, -0.17992484, -1.07075262, 1.05445173],
[-0.40317695, 1.22244507, 0.20827498, 0.97663904, 0.3563664 ],
[ 0.70657317, 0.01050002, 1.78587049, 0.12691209, 0.40198936],
[ 1.8831507 , -1.34775906, -1.270485 , 0.96939671, -1.17312341]])>
>>> unmatch1 = tf.constant(np.random.randn(5))
>>> unmatch1
<tf.Tensor: shape=(5,), dtype=float64, numpy=array([ 1.94362119, -0.41361898, -0.74745481, 1.92294203, 1.48051479])>
>>> unmatch2 = tf.constant(np.random.randn(5))
>>> unmatch2
<tf.Tensor: shape=(5,), dtype=float64, numpy=array([ 1.86755896, 0.90604466, -0.86122569, 1.91006495, -0.26800337])>
>>> x = pygm.hungarian(s_2d, unmatch1=unmatch1, unmatch2=unmatch2)
>>> x
<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])>

.. note::

If you find this graph matching solver useful for your research, please cite:
Expand Down
Loading