diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a6f6f94ba..c7916f50a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,6 +42,8 @@ The contributors to this library are: * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) * [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) +* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) +* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) ## Acknowledgments diff --git a/README.md b/README.md index 0df45d970..a5660a988 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples): * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]). +* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -319,3 +320,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35. [54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). + +[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). + +[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. diff --git a/RELEASES.md b/RELEASES.md index bd5a618e8..9bb36d55f 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,8 @@ - Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455) - Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486) +- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454) + #### Closed issues - Fix change in scipy API for `cdist` (PR #487) diff --git a/examples/others/plot_dmmot.py b/examples/others/plot_dmmot.py new file mode 100644 index 000000000..1548ba470 --- /dev/null +++ b/examples/others/plot_dmmot.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +r""" +=============================================================================== +Computing d-dimensional Barycenters via d-MMOT +=============================================================================== + +When the cost is discretized (Monge), the d-MMOT solver can more quickly +compute and minimize the distance between many distributions without the need +for intermediate barycenter computations. This example compares the time to +identify, and the quality of, solutions for the d-MMOT problem using a +primal/dual algorithm and classical LP barycenter approaches. +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +# %% +# Generating 2 distributions +# ----- +import numpy as np +import matplotlib.pyplot as pl +import ot + +np.random.seed(0) + +n = 100 +d = 2 +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) +A = np.vstack((a1, a2)).T +x = np.arange(n, dtype=np.float64) +M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a1, 'b', label='Source distribution') +pl.plot(x, a2, 'r', label='Target distribution') +pl.legend() + +# %% +# Minimize the distances among distributions, identify the Barycenter +# ----- +# The objective being minimized is different for both methods, so the objective +# values cannot be compared. + +# L2 Iteration +weights = np.ones(d) / d +l2_bary = A.dot(weights) + +print('LP Iterations:') +weights = np.ones(d) / d +lp_bary, lp_log = ot.lp.barycenter( + A, M, weights, solver='interior-point', verbose=False, log=True) +print('Time\t: ', ot.toc('')) +print('Obj\t: ', lp_log['fun']) + +print('') +print('Discrete MMOT Algorithm:') +ot.tic() +barys, log = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True) +dmmot_obj = log['primal objective'] +print('Time\t: ', ot.toc('')) +print('Obj\t: ', dmmot_obj) + +# %% +# Compare Barycenters in both methods +# ----- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') + else: + continue + # pl.plot(x, barys[i], 'g-*') +pl.plot(x, lp_bary, label='LP Barycenter') +pl.plot(x, l2_bary, label='L2 Barycenter') +pl.plot(x, a1, 'b', label='Source distribution') +pl.plot(x, a2, 'r', label='Target distribution') +pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver') +pl.legend() + + +# %% +# More than 2 distributions +# -------------------------------------------------- +# Generate 7 pseudorandom gaussian distributions with 50 bins. +n = 50 # nb bins +d = 7 +vecsize = n * d + +data = [] +for i in range(d): + m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1) + a = ot.datasets.make_1D_gauss(n, m=m, s=5) + data.append(a) + +x = np.arange(n, dtype=np.float64) +M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') +A = np.vstack(data).T + +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(data)): + pl.plot(x, data[i]) + +pl.title('Distributions') +pl.legend() + +# %% +# Minimizing Distances Among Many Distributions +# --------------- +# The objective being minimized is different for both methods, so the objective +# values cannot be compared. + +# Perform gradient descent optimization using the d-MMOT method. +barys = ot.lp.dmmot_monge_1dgrid_optimize( + A, niters=3000, lr_init=1e-4, lr_decay=0.997) + +# after minimization, any distribution can be used as a estimate of barycenter. +bary = barys[0] + +# Compute 1D Wasserstein barycenter using the L2/LP method +weights = ot.unif(d) +l2_bary = A.dot(weights) +lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', + verbose=False, log=True) + +# %% +# Compare Barycenters in both methods +# --------- +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, bary, 'g-*', label='Discrete MMOT') +pl.plot(x, l2_bary, 'k', label='L2 Barycenter') +pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') +pl.title('Barycenters') +pl.legend() + +# %% +# Compare with original distributions +# --------- +pl.figure(1, figsize=(6.4, 3)) +for i in range(len(data)): + pl.plot(x, data[i]) +for i in range(len(barys)): + if i == 0: + pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') + else: + continue + # pl.plot(x, barys[i], 'g') +pl.plot(x, l2_bary, 'k^', label='L2') +pl.plot(x, lp_bary, 'o', color='grey', label='LP') +pl.title('Barycenters') +pl.legend() +pl.show() + +# %% diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 3641d7817..2badb28f3 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -17,6 +17,7 @@ from . import cvx from .cvx import barycenter +from .dmmot import * # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted @@ -30,7 +31,8 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', + 'discrete_mmot', 'discrete_mmot_converge'] def check_number_threads(numThreads): diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py new file mode 100644 index 000000000..8576c3c61 --- /dev/null +++ b/ot/lp/dmmot.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- +""" +d-MMOT solvers for optimal transport +""" + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +from ..backend import get_backend + + +def dist_monge_max_min(i): + r""" + A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and + :math:j_1, \ldots, j_d, + + .. math:: + c(s_1, \ldots, s_d) + c(t_1, \ldots t_d) \leq c(i_1, \ldots i_d) + + c(j_1, \ldots, j_d) + + where :math:s_k = \min(i_k, j_k) and :math:t_k = \max(i_k, j_k). + + Our focus is on a specific cost, which is known to be Monge: + + .. math:: + c(i_1,i_2,\ldots,i_d) = \max{i_k:k\in[d]} - \min{i_k:k\in[d]}. + + When :math:d=2, this cost reduces to :math:c(i_1,i_2)=|i_1-i_2|, + which agrees with the classical EMD cost. This choice of :math:c is called + the generalized EMD cost. + + Parameters + ---------- + i : list + The list of integer indexes. + + Returns + ------- + cost : numeric value + The ground cost (generalized EMD cost) of the tensor. + + References + ---------- + .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's + problem. Discrete Applied Mathematics, 265: 128-141, 2019. + .. [57] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K. + Pathak. A monge property for the d-dimensional transportation problem. + Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi: + https://doi.org/10.1016/0166-218X(93)E0121-E. URL + https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E. + Workshop on Discrete Algoritms. + """ + + return max(i) - min(i) + + +def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): + r""" + Compute the discrete multi-marginal optimal transport of distributions A. + + This function operates on distributions whose supports are real numbers on + the real line. + + The algorithm solves both primal and dual d-MMOT programs concurrently to + produce the optimal transport plan as well as the total (minimal) cost. + The cost is a ground cost, and the solution is independent of + which Monge cost is desired. + + The algorithm accepts :math:`d` distributions (i.e., histograms) + :math:`a_{1}, \ldots, a_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime} + a_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all + histograms have the same number of bins, the algorithm can be easily + adapted to accept as inputs :math:`a_{i} \in \mathbb{R}_{+}^{n_{i}}` + with :math:`n_{i} \neq n_{j}` [50]. + + The function solves the following optimization problem[51]: + + .. math:: + \begin{align}\begin{aligned} + \underset{\gamma\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, \gamma(i_1,\ldots,i_d) + \quad \textrm{s.t.} + \sum_{i_2,\ldots,i_d} \gamma(i_1,\ldots,i_d) &= a_1(i_i), + (\forall i_1\in[n])\\ + \qquad\vdots\\ + \sum_{i_1,\ldots,i_{d-1}} \gamma(i_1,\ldots,i_d) &= a_{d}(i_{d}), + (\forall i_d\in[n]). + \end{aligned} + \end{align} + + + Parameters + ---------- + A : nx.ndarray, shape (dim, n_hists) + The input ndarray containing distributions of n bins in d dimensions. + verbose : bool, optional + If True, print debugging information during execution. Default=False. + log : bool, optional + If True, record log. Default is False. + + Returns + ------- + obj : float + the value of the primal objective function evaluated at the solution. + log : dict + A dictionary containing the log of the discrete mmot problem: + - 'A': a dictionary that maps tuples of indices to the corresponding + primal variables. The tuples are the indices of the entries that are + set to their minimum value during the algorithm. + - 'primal objective': a float, the value of the objective function + evaluated at the solution. + - 'dual': a list of arrays, the dual variables corresponding to + the input arrays. The i-th element of the list is the dual variable + corresponding to the i-th dimension of the input arrays. + - 'dual objective': a float, the value of the dual objective function + evaluated at the solution. + + + References + ---------- + .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal + Transport Regularization. In The Eleventh International + Conference on Learning Representations. + .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's + problem. Discrete Applied Mathematics, 265: 128-141, 2019. + .. [58] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad. + Nauk SSSR, 37:227-229, 1942. + + See Also + -------- + ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth + Mover's Distance (d-MMOT) + """ + + nx = get_backend(A) + A_copy = A + A = nx.to_numpy(A) + + AA = [np.copy(A[:, j]) for j in range(A.shape[1])] + + dims = tuple([len(_) for _ in AA]) + xx = {} + dual = [np.zeros(d) for d in dims] + + idx = [0, ] * len(AA) + obj = 0 + + if verbose: + print('i minval oldidx\t\tobj\t\tvals') + + while all([i < _ for _, i in zip(dims, idx)]): + vals = [v[i] for v, i in zip(AA, idx)] + minval = min(vals) + i = vals.index(minval) + xx[tuple(idx)] = minval + obj += (dist_monge_max_min(idx)) * minval + for v, j in zip(AA, idx): + v[j] -= minval + # oldidx = nx.copy(idx) + oldidx = idx.copy() + idx[i] += 1 + if idx[i] < dims[i]: + temp = (dist_monge_max_min(idx) - + dist_monge_max_min(oldidx) + + dual[i][idx[i] - 1]) + dual[i][idx[i]] += temp + if verbose: + print(i, minval, oldidx, obj, '\t', vals) + + # the above terminates when any entry in idx equals the corresponding + # value in dims this leaves other dimensions incomplete; the remaining + # terms of the dual solution must be filled-in + for _, i in enumerate(idx): + try: + dual[_][i:] = dual[_][i] + except Exception: + pass + + dualobj = sum([np.dot(A[:, i], arr) for i, arr in enumerate(dual)]) + obj = nx.from_numpy(obj) + + log_dict = {'A': xx, + 'primal objective': obj, + 'dual': dual, + 'dual objective': dualobj} + + # define forward/backward relations for pytorch + obj = nx.set_gradients(obj, (A_copy), (dual)) + + if log: + return obj, log_dict + else: + return obj + + +def dmmot_monge_1dgrid_optimize( + A, + niters=100, + lr_init=1e-5, + lr_decay=0.995, + print_rate=100, + verbose=False, + log=False): + r"""Minimize the d-dimensional EMD using gradient descent. + + Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots, + a_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here, + the d-MMOT is the LP, + + .. math:: + \begin{align}\begin{aligned} + \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}} + \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad + \textrm{s.t.} + \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= a_1(i_i), + (\forall i_1\in[n])\\ + \qquad\vdots\\ + \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= a_{d}(i_{d}), + (\forall i_d\in[n]). + \end{aligned} + \end{align} + + The dual linear program of the d-MMOT problem is: + + .. math:: + \underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j} + a_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d}) + \leq c(i_1,\ldots,i_{d}), + + + where the indices in the constraints include all :math:`i_j\in[n]`, :math: + `j\in[d]`. Denote by :math:`\phi(a_1,\ldots,a_d)`, the optimal objective + value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution + to the dual program. Then, + + .. math:: + \begin{align} + \nabla \phi(a_1,\ldots,a_{d}) &= z^*, + ~~\text{and for any $t\in \mathbb{R}$,}~~ + \phi(a_1,a_2,\ldots,a_{d}) = \sum_{j}a_j' + (z_j^* + t\, \eta), \nonumber \\ + \text{where } \eta &:= (z_1^{*}(n)\,e, z^*_1(n)\,e, \cdots, + z^*_{d}(n)\,e) + \end{align} + + Using these dual variables naturally provided by the algorithm in + ot.lp.dmmot_monge_1dgrid_loss, gradient steps move each input distribution + to minimize their d-mmot distance. + + Parameters + ---------- + A : nx.ndarray, shape (dim, n_hists) + The input ndarray containing distributions of n bins in d dimensions. + niters : int, optional (default=100) + The maximum number of iterations for the optimization algorithm. + lr_init : float, optional (default=1e-5) + The initial learning rate (step size) for the optimization algorithm. + lr_decay : float, optional (default=0.995) + The learning rate decay rate in each iteration. + print_rate : int, optional (default=100) + The rate at which to print the objective value and gradient norm + during the optimization algorithm. + verbose : bool, optional + If True, print debugging information during execution. Default=False. + log : bool, optional + If True, record log. Default is False. + + Returns + ------- + a : list of ndarrays, each of shape (n,) + The optimal solution as a list of n approximate barycenters, each of + length vecsize. + log : dict + log dictionary return only if log==True in parameters + + References + ---------- + .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & + Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal + Transport Regularization. In The Eleventh International + Conference on Learning Representations. + .. [60] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear + programs. SIAM Journal on Control and Optimization, 17(6):745-752, 1979 + .. [59] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of + convex programs. Applied Mathematics and Optimization, 23(1):263-273, + 1991. + + See Also + -------- + ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver + """ + + # function body here + nx = get_backend(A) + A = nx.to_numpy(A) + n, d = A.shape # n is dim, d is n_hists + + def dualIter(A, lr): + funcval, log_dict = dmmot_monge_1dgrid_loss( + A, verbose=verbose, log=True) + grad = np.column_stack(log_dict['dual']) + A_new = np.reshape(A, (n, d)) - grad * lr + return funcval, A_new, grad, log_dict + + def renormalize(A): + A = np.reshape(A, (n, d)) + for i in range(A.shape[1]): + if min(A[:, i]) < 0: + A[:, i] -= min(A[:, i]) + A[:, i] /= np.sum(A[:, i]) + return A + + def listify(A): + return [A[:, i] for i in range(A.shape[1])] + + lr = lr_init + + funcval, _, grad, log_dict = dualIter(A, lr) + gn = np.linalg.norm(grad) + + print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + for i in range(niters): + + A = renormalize(A) + funcval, A, grad, log_dict = dualIter(A, lr) + gn = np.linalg.norm(grad) + + if i % print_rate == 0: + print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + + lr *= lr_decay + + A = renormalize(A) + a = listify(A) + + if log: + return a, log_dict + else: + return a diff --git a/test/test_dmmot.py b/test/test_dmmot.py new file mode 100644 index 000000000..fa8dc6b89 --- /dev/null +++ b/test/test_dmmot.py @@ -0,0 +1,76 @@ +"""Tests for ot.lp.dmmot module """ + +# Author: Ronak Mehta +# Xizheng Yu +# +# License: MIT License + +import numpy as np +import ot + + +def create_test_data(nx): + np.random.seed(1234) + n = 4 + a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) + a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + A = np.vstack([a1, a2]).T + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + A, x = nx.from_numpy(A, x) + return A, x + + +def test_dmmot_monge_1dgrid_loss(nx): + A, x = create_test_data(nx) + + # Compute loss using dmmot_monge_1dgrid_loss + primal_obj = ot.lp.dmmot_monge_1dgrid_loss(A) + primal_obj = nx.to_numpy(primal_obj) + expected_primal_obj = 0.13667759626298503 + + np.testing.assert_allclose(primal_obj, + expected_primal_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value") + + # Compute loss using exact OT solver with absolute ground metric + A, x = nx.to_numpy(A, x) + M = ot.utils.dist(x, metric='cityblock') # absolute ground metric + bary, _ = ot.barycenter(A, M, 1e-2, weights=None, verbose=False, log=True) + ot_obj = 0.0 + for x in A.T: + # deal with C-contiguous error from tensorflow backend (not sure why) + x = np.ascontiguousarray(x) + # compute loss + _, log = ot.lp.emd(x, np.array(bary / np.sum(bary)), M, log=True) + ot_obj += log['cost'] + + np.testing.assert_allclose(primal_obj, + ot_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value") + + +def test_dmmot_monge_1dgrid_optimize(nx): + # test discrete_mmot_converge result + A, _ = create_test_data(nx) + d = 2 + niters = 10 + result = ot.lp.dmmot_monge_1dgrid_optimize(A, + niters, + lr_init=1e-3, + lr_decay=1) + + expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], + [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) + + assert len(result) == d, "Test failed: Expected a list of length n" + for i in range(d): + np.testing.assert_allclose(result[i], + expected_obj[i], + atol=1e-7, + rtol=1e-7, + err_msg="Test failed: \ + Expected vectors of all zeros")