
Commit 50c0f17

[MRG] GW dictionary learning (#319)
* add fgw dictionary learning feature
* add fgw dictionary learning feature
* plot gromov wasserstein dictionary learning
* Update __init__.py
* fix pep8 errors exact E501 line too long
* fix last pep8 issues
* add unitary tests for (F)GW dictionary learning without using autodifferentiable functions
* correct tests for (F)GW dictionary learning without using autodiff
* correct tests for (F)GW dictionary learning without using autodiff
* fix docs and notations
* answer to review: improve tests, docs, examples + make node weights optional
* fix pep8 and examples
* improve docs + tests + thumbnail
* make example faster
* improve ex
* update README.md
* make GDL tests faster

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent a5e0f0d commit 50c0f17

File tree: 6 files changed, +1954 -39 lines changed


README.md

Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
 formulations).
 * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

 POT provides the following Machine Learning related solvers:
@@ -198,6 +199,7 @@ The contributors to this library are
 * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
 * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)

 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
RELEASES.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
   of the regularization parameter (PR #336).
 - Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
 - Add weak OT solver + example (PR #341).
-
+- Add (F)GW linear dictionary learning solvers + example (PR #319)

 #### Closed issues
examples/gromov/plot_gromov_wasserstein_dictionary_learning.py

Lines changed: 357 additions & 0 deletions

@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-

r"""
=====================================================
(Fused) Gromov-Wasserstein Linear Dictionary Learning
=====================================================

In this example, we illustrate how to learn a Gromov-Wasserstein dictionary on
a dataset of structured data such as graphs, denoted
:math:`\{ \mathbf{C_s} \}_{s \in [S]}`, where all nodes have uniform weights.
Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed
size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})`
is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these
dictionary atoms, namely :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`.

First, we consider a dataset composed of graphs generated by Stochastic Block models
with variable sizes taken in :math:`\{30, ... , 50\}` and numbers of clusters
varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms by minimizing
the Gromov-Wasserstein distance from each sample to its model in the dictionary,
with respect to the dictionary atoms.
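
Formally, following [38], the atoms are estimated by minimizing, over both the
dictionary :math:`\mathbf{C_{dict}}` and the unmixings :math:`\mathbf{w_s}`, an
objective of the form

.. math::
    \sum_{s \in [S]} \left( GW_2^2\left(\mathbf{C_s}, \sum_{d} w_{s,d} \mathbf{C_{dict}[d]}\right) - reg \cdot \| \mathbf{w_s} \|_2^2 \right)

where the negative quadratic term, weighted by the coefficient `reg` used below,
promotes sparse unmixings (a sketch of the objective of [38], not its exact
stochastic formulation).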

Second, we illustrate the extension of this dictionary learning framework to
structured data endowed with node features by using the Fused Gromov-Wasserstein
distance. Starting from the aforementioned dataset of unattributed graphs, we
add discrete labels uniformly depending on the number of clusters. Then we learn
and visualize attributed graph atoms, where each sample is modeled as a joint convex
combination of atom structures and features.
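That is, each attributed graph :math:`(\mathbf{C_s}, \mathbf{Y_s}, \mathbf{p_s})`
is approximated by the pair
:math:`( \sum_d w_{s,d} \mathbf{C_{dict}[d]}, \sum_d w_{s,d} \mathbf{Y_{dict}[d]} )`,
with a single unmixing :math:`\mathbf{w_s}` shared between structures and features.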

[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
Dictionary Learning, International Conference on Machine Learning (ICML), 2021.

"""
# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 4

import numpy as np
import matplotlib.pylab as pl
from sklearn.manifold import MDS
from ot.gromov import (gromov_wasserstein_linear_unmixing,
                       gromov_wasserstein_dictionary_learning,
                       fused_gromov_wasserstein_linear_unmixing,
                       fused_gromov_wasserstein_dictionary_learning)
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm

# %%
# =============================================================================
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters
# =============================================================================

np.random.seed(42)

N = 60  # number of graphs in the dataset
# For every number of clusters, we generate SBMs with fixed inter/intra-cluster probabilities.
clusters = [1, 2, 3]
Nc = N // len(clusters)  # number of graphs per cluster
nlabels = len(clusters)
dataset = []
labels = []

p_inter = 0.1
p_intra = 0.9
for n_cluster in clusters:
    for i in range(Nc):
        n_nodes = int(np.random.uniform(low=30, high=50))

        if n_cluster > 1:
            P = p_inter * np.ones((n_cluster, n_cluster))
            np.fill_diagonal(P, p_intra)
        else:
            P = p_intra * np.eye(1)
        sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)
        G = sbm(sizes, P, seed=i, directed=False)
        C = networkx.to_numpy_array(G)
        dataset.append(C)
        labels.append(n_cluster)


# Visualize samples

def plot_graph(x, C, binary=True, color='C0', s=None):
    """Plot graph edges given 2D node positions x and (weighted) adjacency matrix C."""
    for j in range(C.shape[0]):
        for i in range(j):
            if binary:
                if C[i, j] > 0:
                    pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
            else:  # connection intensity proportional to C[i, j]
                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')

    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)


pl.figure(1, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    # get 2D positions for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color='C0', s=50.)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Estimate the Gromov-Wasserstein dictionary from the dataset
# =============================================================================


np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]

D = 3  # 3 atoms in the dictionary
nt = 6  # of 6 nodes each

q = ot.unif(nt)
reg = 0.  # regularization coefficient to promote sparsity of unmixings {w_s}

Cdict_GW, log = gromov_wasserstein_dictionary_learning(
    Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16,
    learning_rate=0.1, reg=reg, projection='nonnegative_symmetric',
    tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
    use_log=True, use_adam_optimizer=True, verbose=True
)
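# Cdict_GW stacks the D learned atoms, each one an (nt, nt) matrix of pairwise
# node relations; with use_log=True, log['loss_epochs'] records the objective
# value at each epoch (as used for the plot below).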
# visualize loss evolution over epochs
pl.figure(2, (4, 3))
pl.clf()
pl.title('loss evolution by epoch', fontsize=14)
pl.plot(log['loss_epochs'])
pl.xlabel('epochs', fontsize=12)
pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the estimated dictionary atoms
# =============================================================================


# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 0: white)

pl.figure(3, (12, 8))
pl.clf()
for idx_atom, atom in enumerate(Cdict_GW):
    scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
    pl.subplot(2, D, idx_atom + 1)
    pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14)
    plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.)
    pl.axis("off")
    pl.subplot(2, D, D + idx_atom + 1)
    pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
    pl.imshow(scaled_atom, interpolation='nearest')
    pl.colorbar()
    pl.axis("off")
pl.tight_layout()
pl.show()
# %%
# =============================================================================
# Visualization of the embedding space
# =============================================================================

unmixings = []
reconstruction_errors = []
for C in dataset:
    p = ot.unif(C.shape[0])
    unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(
        C, Cdict_GW, p=p, q=q, reg=reg,
        tol_outer=10**(-5), tol_inner=10**(-5),
        max_iter_outer=30, max_iter_inner=300
    )
    unmixings.append(unmixing)
    reconstruction_errors.append(reconstruction_error)
unmixings = np.array(unmixings)
print('cumulative reconstruction error:', np.array(reconstruction_errors).sum())

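# Each unmixing w lies on the 2-simplex: we plot it inside the triangle with
# vertices (0, 0), (1, 0) and (0.5, sqrt(3)/2) at the barycentric combination
# w[0] * (0, 0) + w[1] * (1, 0) + w[2] * (0.5, sqrt(3)/2), which yields the
# closed-form coordinates computed below.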
# Compute the 2D representation of the unmixings living in the 2-simplex of probability
unmixings2D = np.zeros(shape=(N, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
    unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

pl.figure(4, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
    else:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Endow the dataset with node features
# =============================================================================

# We assign features to all nodes of a graph depending on its label / number of clusters:
# 1 cluster  --> 0 as node feature
# 2 clusters --> 1 as node feature
# 3 clusters --> 2 as node feature
# features are one-hot encoded following these assignments
dataset_features = []
for i in range(len(dataset)):
    n = dataset[i].shape[0]
    F = np.zeros((n, 3))
    if i < Nc:  # graph with 1 cluster
        F[:, 0] = 1.
    elif i < 2 * Nc:  # graph with 2 clusters
        F[:, 1] = 1.
    else:  # graph with 3 clusters
        F[:, 2] = 1.
    dataset_features.append(F)
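# With this one-hot encoding, node features are identical within a label and
# equidistant across different labels, so feature information alone already
# separates the three classes in this illustrative setup.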

pl.figure(5, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    F = dataset_features[(c - 1) * Nc]
    colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])]
    # get 2D positions for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color=colors, s=50)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()
# %%
# =============================================================================
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
# =============================================================================
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3  # 3 atoms in the dictionary
nt = 6
q = ot.unif(nt)
reg = 0.001
alpha = 0.5  # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein

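# Here alpha in [0, 1] weights the structure term of the Fused Gromov-Wasserstein
# cost against the feature term: values near 0 let node features dominate the
# matching, while values near 1 recover a purely structural (GW) matching.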
Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(
    Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha,
    epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg,
    tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
    projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True
)
# visualize loss evolution
pl.figure(6, (4, 3))
pl.clf()
pl.title('loss evolution by epoch', fontsize=14)
pl.plot(log['loss_epochs'])
pl.xlabel('epochs', fontsize=12)
pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the estimated dictionary atoms
# =============================================================================

pl.figure(7, (12, 8))
pl.clf()
max_features = Ydict_FGW.max()
min_features = Ydict_FGW.min()

for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
    scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())
    # scaled_F = 2 * (Fatom - min_features) / (max_features - min_features)
    colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
    pl.subplot(2, D, idx_atom + 1)
    pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14)
    plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)
    pl.axis("off")
    pl.subplot(2, D, D + idx_atom + 1)
    pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
    pl.imshow(scaled_atom, interpolation='nearest')
    pl.colorbar()
    pl.axis("off")
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the embedding space
# =============================================================================

unmixings = []
reconstruction_errors = []
for i in range(len(dataset)):
    C = dataset[i]
    Y = dataset_features[i]
    p = ot.unif(C.shape[0])
    unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing(
        C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha,
        reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300
    )
    unmixings.append(unmixing)
    reconstruction_errors.append(reconstruction_error)
unmixings = np.array(unmixings)
print('cumulative reconstruction error:', np.array(reconstruction_errors).sum())

# Visualize unmixings in the 2-simplex of probability
unmixings2D = np.zeros(shape=(N, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
    unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

pl.figure(8, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
    else:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))

pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
