# -*- coding: utf-8 -*-

r"""
=====================================================
(Fused) Gromov-Wasserstein Linear Dictionary Learning
=====================================================

In this example, we illustrate how to learn a Gromov-Wasserstein dictionary on
a dataset of structured data such as graphs, denoted
:math:`\{ \mathbf{C_s} \}_{s \in [S]}`, where every node has a uniform weight.
Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed
size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})`
is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these
dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`.


First, we consider a dataset composed of graphs generated by Stochastic Block Models
with variable sizes taken in :math:`\{30, ... , 50\}` and numbers of clusters
varying in :math:`\{1, 2, 3\}`. We learn a dictionary of 3 atoms by minimizing
the Gromov-Wasserstein distance from each sample to its model in the dictionary,
with respect to the dictionary atoms.

Second, we illustrate the extension of this dictionary learning framework to
structured data endowed with node features by using the Fused Gromov-Wasserstein
distance. Starting from the aforementioned dataset of unattributed graphs, we
add discrete labels uniformly depending on the number of clusters. Then we learn
and visualize attributed graph atoms where each sample is modeled as a joint convex
combination between atom structures and features.


[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty.
Online Graph Dictionary Learning. International Conference on Machine
Learning (ICML), 2021.

"""
# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 4

import numpy as np
import matplotlib.pylab as pl
from sklearn.manifold import MDS
from ot.gromov import (
    gromov_wasserstein_linear_unmixing,
    gromov_wasserstein_dictionary_learning,
    fused_gromov_wasserstein_linear_unmixing,
    fused_gromov_wasserstein_dictionary_learning
)
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm
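
# %%
# A minimal sketch of the linear model described above, on hypothetical toy
# sizes (``D_toy``, ``nt_toy`` and ``w_toy`` are illustrative names, not part
# of the original example): a structure is approximated by the convex
# combination of the dictionary atoms weighted by a point on the simplex.

rng = np.random.RandomState(0)
D_toy, nt_toy = 3, 6
Cdict_toy = rng.rand(D_toy, nt_toy, nt_toy)  # toy atoms, for shape illustration only
w_toy = ot.unif(D_toy)  # uniform weights on the simplex Sigma_D
C_model_toy = np.sum(w_toy[:, None, None] * Cdict_toy, axis=0)  # sum_d w_d Cdict[d]
print('toy model shape:', C_model_toy.shape)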

# %%
# =============================================================================
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
# =============================================================================

np.random.seed(42)

N = 60  # number of graphs in the dataset
# For every number of clusters, we generate SBM with fixed inter/intra-cluster probabilities.
clusters = [1, 2, 3]
Nc = N // len(clusters)  # number of graphs per cluster
nlabels = len(clusters)
dataset = []
labels = []

p_inter = 0.1
p_intra = 0.9
for n_cluster in clusters:
    for i in range(Nc):
        n_nodes = int(np.random.uniform(low=30, high=50))

        if n_cluster > 1:
            P = p_inter * np.ones((n_cluster, n_cluster))
            np.fill_diagonal(P, p_intra)
        else:
            P = p_intra * np.eye(1)
        sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)
        G = sbm(sizes, P, seed=i, directed=False)
        C = networkx.to_numpy_array(G)
        dataset.append(C)
        labels.append(n_cluster)

# Visualize samples

def plot_graph(x, C, binary=True, color='C0', s=None):
    # draw one edge per connected pair (i, j), then overlay the nodes;
    # in non-binary mode the edge intensity is proportional to C[i, j]
    for j in range(C.shape[0]):
        for i in range(j):
            if binary:
                if C[i, j] > 0:
                    pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
            else:  # connection intensity proportional to C[i,j]
                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')

    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)


pl.figure(1, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color='C0', s=50.)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Estimate the Gromov-Wasserstein dictionary from the dataset
# =============================================================================


np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]

D = 3  # 3 atoms in the dictionary
nt = 6  # of 6 nodes each

q = ot.unif(nt)
reg = 0.  # regularization coefficient to promote sparsity of unmixings {w_s}

Cdict_GW, log = gromov_wasserstein_dictionary_learning(
    Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16,
    learning_rate=0.1, reg=reg, projection='nonnegative_symmetric',
    tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
    use_log=True, use_adam_optimizer=True, verbose=True
)
# visualize loss evolution over epochs
pl.figure(2, (4, 3))
pl.clf()
pl.title('loss evolution by epoch', fontsize=14)
pl.plot(log['loss_epochs'])
pl.xlabel('epochs', fontsize=12)
pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the estimated dictionary atoms
# =============================================================================


# Continuous connections between nodes of the atoms are colored in shades of grey
# (1: dark / 0: white)

pl.figure(3, (12, 8))
pl.clf()
for idx_atom, atom in enumerate(Cdict_GW):
    scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
    pl.subplot(2, D, idx_atom + 1)
    pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14)
    plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.)
    pl.axis("off")
    pl.subplot(2, D, D + idx_atom + 1)
    pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
    pl.imshow(scaled_atom, interpolation='nearest')
    pl.colorbar()
    pl.axis("off")
pl.tight_layout()
pl.show()
# %%
# =============================================================================
# Visualization of the embedding space
# =============================================================================

unmixings = []
reconstruction_errors = []
for C in dataset:
    p = ot.unif(C.shape[0])
    unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(
        C, Cdict_GW, p=p, q=q, reg=reg,
        tol_outer=10**(-5), tol_inner=10**(-5),
        max_iter_outer=30, max_iter_inner=300
    )
    unmixings.append(unmixing)
    reconstruction_errors.append(reconstruction_error)
unmixings = np.array(unmixings)
print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
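
# %%
# Sanity check (a sketch added for illustration, not in the original example):
# rebuild the model of sample 0 explicitly as the convex combination of atoms
# given by its unmixing, then measure the GW loss to the sample with
# ot.gromov.gromov_wasserstein2.

w = unmixings[0]
C_model = np.sum(w[:, None, None] * Cdict_GW, axis=0)  # sum_d w_d Cdict_GW[d]
gw_sq = ot.gromov.gromov_wasserstein2(dataset[0], C_model, ps[0], q)
print('GW loss between sample 0 and its model:', gw_sq)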

# Compute the 2D representation of the unmixings living in the 2-simplex of probabilities
unmixings2D = np.zeros(shape=(N, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
    unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])
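
# Quick check (added for illustration): the barycentric map above sends the
# one-hot weights e_1, e_2, e_3 exactly onto the triangle corners x, y, z.
for e, corner in zip(np.eye(3), extremities):
    proj = np.array([(2. * e[1] + e[2]) / 2., (np.sqrt(3.) * e[2]) / 2.])
    assert np.allclose(proj, corner)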

pl.figure(4, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
    else:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
# %%
# =============================================================================
# Endow the dataset with node features
# =============================================================================

# We assign the same discrete feature to all nodes of a graph, depending on its
# label / number of clusters:
# 1 cluster  --> feature 0
# 2 clusters --> feature 1
# 3 clusters --> feature 2
# features are one-hot encoded following these assignments
dataset_features = []
for i in range(len(dataset)):
    n = dataset[i].shape[0]
    F = np.zeros((n, 3))
    if i < Nc:  # graph with 1 cluster
        F[:, 0] = 1.
    elif i < 2 * Nc:  # graph with 2 clusters
        F[:, 1] = 1.
    else:  # graph with 3 clusters
        F[:, 2] = 1.
    dataset_features.append(F)

pl.figure(5, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    F = dataset_features[(c - 1) * Nc]
    colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])]
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color=colors, s=50)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()
# %%
# =============================================================================
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
# =============================================================================
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3  # 3 atoms in the dictionary, as for GW
nt = 6  # of 6 nodes each
q = ot.unif(nt)
reg = 0.001
alpha = 0.5  # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein


Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(
    Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha,
    epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg,
    tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
    projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True
)
# visualize loss evolution over epochs
pl.figure(6, (4, 3))
pl.clf()
pl.title('loss evolution by epoch', fontsize=14)
pl.plot(log['loss_epochs'])
pl.xlabel('epochs', fontsize=12)
pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the estimated dictionary atoms
# =============================================================================

pl.figure(7, (12, 8))
pl.clf()

for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
    scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())
    # color each node by its dominant learned feature
    colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
    pl.subplot(2, D, idx_atom + 1)
    pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14)
    plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)
    pl.axis("off")
    pl.subplot(2, D, D + idx_atom + 1)
    pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
    pl.imshow(scaled_atom, interpolation='nearest')
    pl.colorbar()
    pl.axis("off")
pl.tight_layout()
pl.show()

# %%
# =============================================================================
# Visualization of the embedding space
# =============================================================================

unmixings = []
reconstruction_errors = []
for i in range(len(dataset)):
    C = dataset[i]
    Y = dataset_features[i]
    p = ot.unif(C.shape[0])
    unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing(
        C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha,
        reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300
    )
    unmixings.append(unmixing)
    reconstruction_errors.append(reconstruction_error)
unmixings = np.array(unmixings)
print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
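
# %%
# Sanity check (a sketch added for illustration, not in the original example):
# rebuild both the structure and the features of sample 0 from its unmixing and
# measure the FGW loss explicitly with ot.gromov.fused_gromov_wasserstein2,
# where M holds the pairwise squared Euclidean costs between the sample
# features and the model features.

w = unmixings[0]
C_model = np.sum(w[:, None, None] * Cdict_FGW, axis=0)
Y_model = np.sum(w[:, None, None] * Ydict_FGW, axis=0)
M = ot.dist(dataset_features[0], Y_model)  # feature cost matrix
fgw_sq = ot.gromov.fused_gromov_wasserstein2(M, dataset[0], C_model, ps[0], q, alpha=alpha)
print('FGW loss between sample 0 and its model:', fgw_sq)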

# Visualize unmixings in the 2-simplex of probabilities
unmixings2D = np.zeros(shape=(N, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
    unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

pl.figure(8, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
    else:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))

pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
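
# %%
# A possible downstream use (a sketch under this example's setting, not part of
# the original): the FGW unmixings are 3-dimensional embeddings of the graphs,
# so a standard classifier can be trained on them. KNeighborsClassifier and the
# split below are illustrative choices, not the method of the paper.

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X_train, X_test, y_train, y_test = train_test_split(
    unmixings, np.array(labels), test_size=0.3, random_state=0, stratify=labels)
knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)
print('kNN accuracy on FGW unmixings:', knn.score(X_test, y_test))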