|
12 | 12 | # Author: Remi Flamary <remi.flamary@unice.fr>
|
13 | 13 | # Minhui Huang <mhhuang@ucdavis.edu>
|
14 | 14 | # Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
|
| 15 | +# Antoine Collas <antoine.collas@inria.fr> |
15 | 16 | #
|
16 | 17 | # License: MIT License
|
17 | 18 |
|
18 | 19 | from scipy import linalg
|
19 | 20 | import autograd.numpy as np
|
| 21 | +from sklearn.decomposition import PCA |
20 | 22 |
|
21 | 23 | import pymanopt
|
22 | 24 | import pymanopt.manifolds
|
23 | 25 | import pymanopt.optimizers
|
24 | 26 |
|
| 27 | +from .bregman import sinkhorn as sinkhorn_bregman |
| 28 | +from .utils import dist as dist_utils |
| 29 | + |
25 | 30 |
|
26 | 31 | def dist(x1, x2):
|
27 | 32 | r""" Compute squared euclidean distance between samples (autograd)
|
@@ -376,3 +381,153 @@ def Vpi(X, Y, a, b, pi):
|
376 | 381 | iter = iter + 1
|
377 | 382 |
|
378 | 383 | return pi, U
|
| 384 | + |
| 385 | + |
def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopThr=1e-6, maxiter=100, maxiter_sink=1000, maxiter_MM=10, verbose=0):
    r"""
    Entropic Wasserstein Component Analysis :ref:`[52] <references-entropic-wasserstein-component_analysis>`.

    The function solves the following optimization problem:

    .. math::
        \mathbf{U} = \mathop{\arg \min}_\mathbf{U} \quad
        W(\mathbf{X}, \mathbf{U}\mathbf{U}^T \mathbf{X})

    where :

    - :math:`\mathbf{U}` is a `d` x `k` matrix on the Stiefel manifold
      (orthonormal columns)
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`\mathbf{X}` are samples

    The samples are centered internally before the optimization. The solver
    alternates between a Sinkhorn step (transport plan for the current
    subspace) and a subspace update (BCD or MM), until the Grassmann distance
    between two consecutive subspaces falls below `stopThr` or `maxiter`
    iterations have been performed.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`.
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. When None, the starting point
        is given by a PCA of the centered samples.
    reg : float, optional
        Regularization term >0 (entropic regularization).
    k : int, optional
        Subspace dimension.
    method : str, optional
        Either 'BCD' or 'MM' (Block Coordinate Descent or Majorization-Minimization).
        Prefer MM when d is large.
    sinkhorn_method : str
        Method used for the Sinkhorn solver, see :ref:`ot.bregman.sinkhorn` for more details.
    stopThr : float, optional
        Stop threshold on error (>0).
    maxiter : int, optional
        Maximum number of iterations of the BCD/MM.
    maxiter_sink : int, optional
        Maximum number of iterations of the Sinkhorn solver.
    maxiter_MM : int, optional
        Maximum number of iterations of the MM (only used when method='MM').
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, n)
        Optimal transportation matrix for the given parameters.
    U : ndarray, shape (d, k)
        Matrix on the Stiefel manifold spanning the estimated subspace.

    Raises
    ------
    ValueError
        If `method` is neither 'BCD' nor 'MM'.


    .. _references-entropic-wasserstein-component_analysis:
    References
    ----------
    .. [52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).
            Entropic Wasserstein Component Analysis.
    """  # noqa
    # Fail fast: reject a bad method before any (expensive) Sinkhorn solve.
    if method not in ('BCD', 'MM'):
        raise ValueError(f"Unknown method '{method}', use 'BCD' or 'MM'.")

    n, _ = X.shape
    X = X - X.mean(0)

    # lambda_scm: largest eigenvalue of the sample covariance, needed by the
    # MM update only.
    lambda_scm = None
    if U0 is None:
        pca_fitted = PCA(n_components=k).fit(X)
        U = pca_fitted.components_.T
        if method == 'MM':
            lambda_scm = pca_fitted.explained_variance_[0]
    else:
        U = U0
        if method == 'MM':
            # Same quantity as PCA's explained_variance_[0] (largest squared
            # singular value of the centered data over n - 1). Previously this
            # was left undefined when U0 was supplied, which broke the MM step.
            top_singular_value = np.linalg.svd(X, compute_uv=False)[0]
            lambda_scm = top_singular_value ** 2 / (n - 1)

    # marginals
    u0 = (1. / n) * np.ones(n)

    # print iterations
    if verbose > 0:
        print('{:4s}|{:13s}|{:12s}|{:12s}'.format('It.', 'Loss', 'Crit.', 'Thres.') + '\n' + '-' * 40)

    def compute_loss(M, pi, reg):
        # transport cost plus the entropic term
        return np.sum(M * pi) + reg * np.sum(pi * (np.log(pi) - 1))

    def grassmann_distance(U1, U2):
        # norm of the principal angles between the two subspaces
        s = np.linalg.svd(U1.T @ U2, compute_uv=False)
        # round-off can push singular values slightly above 1 (arccos -> nan)
        s = np.minimum(s, 1)
        return np.linalg.norm(np.arccos(s))

    # loop
    it = 0
    crit = np.inf
    sinkhorn_warmstart = None

    while (it < maxiter) and (crit > stopThr):
        U_old = U

        # Solve transport between the samples and their current projection
        M = dist_utils(X, (X @ U) @ U.T)
        pi, log_sinkhorn = sinkhorn_bregman(
            u0, u0, M, reg,
            numItermax=maxiter_sink,
            method=sinkhorn_method, warmstart=sinkhorn_warmstart,
            warn=False, log=True
        )
        # reuse the solver state for the next outer iteration when available
        key_warmstart = 'warmstart'
        if key_warmstart in log_sinkhorn:
            sinkhorn_warmstart = log_sinkhorn[key_warmstart]

        # Solve PCA
        pi_sym = (pi + pi.T) / 2

        if method == 'BCD':
            # block coordinate descent: keep the k leading eigenvectors
            # (eigh returns eigenvalues in ascending order, hence the flip)
            S = X.T @ (2 * pi_sym - (1. / n) * np.eye(n)) @ X
            _, U = np.linalg.eigh(S)
            U = U[:, ::-1][:, :k]

        else:
            # majorization-minimization (method == 'MM')
            eig, _ = np.linalg.eigh(pi_sym)
            lambda_pi = eig[0]
            # alpha does not depend on U: computed once per outer iteration
            alpha = 1 - 2 * n * lambda_pi

            for _ in range(maxiter_MM):
                X_proj = X @ U
                X_T_X_proj = X.T @ X_proj

                R = (1 / n) * X_T_X_proj
                if alpha > 0:
                    R = alpha * (R - lambda_scm * U)
                else:
                    R = alpha * R

                R = R - (2 * X.T @ (pi_sym @ X_proj)) + (2 * lambda_pi * X_T_X_proj)
                # QR brings the iterate back to orthonormal columns
                U, _ = np.linalg.qr(R)

        # stop or not
        it += 1
        crit = grassmann_distance(U_old, U)

        # print
        if verbose > 0:
            # The loss is only reported, so it is evaluated on demand; the
            # guard avoids log(0) for (numerically) vanishing plan entries.
            if (pi >= 1e-300).all():
                loss = compute_loss(M, pi, reg)
            else:
                loss = np.inf
            print('{:4d}|{:8e}|{:8e}|{:8e}'.format(it, loss, crit, stopThr))

    return pi, U
0 commit comments