Skip to content

Commit 5b8ffa6

Browse files
authored
Merge branch 'master' into stochastic_OT
2 parents cd193f7 + 5cd6c0a commit 5b8ffa6

25 files changed

+3269
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t
228228
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
229229

230230
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
231+
20.1 KB
Loading
45 KB
Loading
28.7 KB
Loading
52.7 KB
Loading
578 KB
Loading
13.2 KB
Loading
20.9 KB
Loading
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"%matplotlib inline"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"\n# 1D Wasserstein barycenter comparison between exact LP and entropic regularization\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3] and exact LP barycenters using standard LP solver.\n\nIt reproduces approximately Figure 3.1 and 3.2 from the following paper:\nCuturi, M., & Peyr\u00e9, G. (2016). A smoothed dual approach for variational\nWasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection # noqa\n\n#import ot.lp.cvx as cvx"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"Gaussian Data\n-------------\n\n"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {
43+
"collapsed": false
44+
},
45+
"outputs": [],
46+
"source": [
47+
"#%% parameters\n\nproblems = []\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"Dirac Data\n----------\n\n"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {
61+
"collapsed": false
62+
},
63+
"outputs": [],
64+
"source": [
65+
"#%% parameters\n\na1 = 1.0 * (x > 10) * (x < 50)\na2 = 1.0 * (x > 60) * (x < 80)\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\n#%% parameters\n\na1 = np.zeros(n)\na2 = np.zeros(n)\n\na1[10] = .25\na1[20] = .5\na1[30] = .25\na2[80] = 1\n\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"Final figure\n------------\n\n\n"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {
79+
"collapsed": false
80+
},
81+
"outputs": [],
82+
"source": [
83+
"#%% plot\n\nnbm = len(problems)\nnbm2 = (nbm // 2)\n\n\npl.figure(2, (20, 6))\npl.clf()\n\nfor i in range(nbm):\n\n A = problems[i][0]\n bary_l2 = problems[i][1][0]\n bary_wass = problems[i][1][1]\n bary_wass2 = problems[i][1][2]\n\n pl.subplot(2, nbm, 1 + i)\n for j in range(n_distributions):\n pl.plot(x, A[:, j])\n if i == nbm2:\n pl.title('Distributions')\n pl.xticks(())\n pl.yticks(())\n\n pl.subplot(2, nbm, 1 + i + nbm)\n\n pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')\n pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\n pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\n if i == nbm - 1:\n pl.legend()\n if i == nbm2:\n pl.title('Barycenters')\n\n pl.xticks(())\n pl.yticks(())"
84+
]
85+
}
86+
],
87+
"metadata": {
88+
"kernelspec": {
89+
"display_name": "Python 3",
90+
"language": "python",
91+
"name": "python3"
92+
},
93+
"language_info": {
94+
"codemirror_mode": {
95+
"name": "ipython",
96+
"version": 3
97+
},
98+
"file_extension": ".py",
99+
"mimetype": "text/x-python",
100+
"name": "python",
101+
"nbconvert_exporter": "python",
102+
"pygments_lexer": "ipython3",
103+
"version": "3.6.5"
104+
}
105+
},
106+
"nbformat": 4,
107+
"nbformat_minor": 0
108+
}
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=================================================================================
4+
1D Wasserstein barycenter comparison between exact LP and entropic regularization
5+
=================================================================================
6+
7+
This example illustrates the computation of regularized Wasserstein Barycenter
8+
as proposed in [3] and exact LP barycenters using standard LP solver.
9+
10+
It reproduces approximately Figure 3.1 and 3.2 from the following paper:
11+
Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
12+
Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
13+
14+
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
15+
Iterative Bregman projections for regularized transportation problems
16+
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
17+
18+
"""
19+
20+
# Author: Remi Flamary <remi.flamary@unice.fr>
21+
#
22+
# License: MIT License
23+
24+
import numpy as np
25+
import matplotlib.pylab as pl
26+
import ot
27+
# necessary for 3d plot even if not used
28+
from mpl_toolkits.mplot3d import Axes3D # noqa
29+
from matplotlib.collections import PolyCollection # noqa
30+
31+
#import ot.lp.cvx as cvx
32+
33+
##############################################################################
34+
# Gaussian Data
35+
# -------------
36+
37+
#%% parameters
38+
39+
problems = []
40+
41+
n = 100 # nb bins
42+
43+
# bin positions
44+
x = np.arange(n, dtype=np.float64)
45+
46+
# Gaussian distributions
47+
# Gaussian distributions
48+
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
49+
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
50+
51+
# creating matrix A containing all distributions
52+
A = np.vstack((a1, a2)).T
53+
n_distributions = A.shape[1]
54+
55+
# loss matrix + normalization
56+
M = ot.utils.dist0(n)
57+
M /= M.max()
58+
59+
60+
#%% plot the distributions
61+
62+
pl.figure(1, figsize=(6.4, 3))
63+
for i in range(n_distributions):
64+
pl.plot(x, A[:, i])
65+
pl.title('Distributions')
66+
pl.tight_layout()
67+
68+
#%% barycenter computation
69+
70+
alpha = 0.5 # 0<=alpha<=1
71+
weights = np.array([1 - alpha, alpha])
72+
73+
# l2bary
74+
bary_l2 = A.dot(weights)
75+
76+
# wasserstein
77+
reg = 1e-3
78+
ot.tic()
79+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
80+
ot.toc()
81+
82+
83+
ot.tic()
84+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
85+
ot.toc()
86+
87+
pl.figure(2)
88+
pl.clf()
89+
pl.subplot(2, 1, 1)
90+
for i in range(n_distributions):
91+
pl.plot(x, A[:, i])
92+
pl.title('Distributions')
93+
94+
pl.subplot(2, 1, 2)
95+
pl.plot(x, bary_l2, 'r', label='l2')
96+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
97+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
98+
pl.legend()
99+
pl.title('Barycenters')
100+
pl.tight_layout()
101+
102+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
103+
104+
##############################################################################
105+
# Dirac Data
106+
# ----------
107+
108+
#%% parameters
109+
110+
a1 = 1.0 * (x > 10) * (x < 50)
111+
a2 = 1.0 * (x > 60) * (x < 80)
112+
113+
a1 /= a1.sum()
114+
a2 /= a2.sum()
115+
116+
# creating matrix A containing all distributions
117+
A = np.vstack((a1, a2)).T
118+
n_distributions = A.shape[1]
119+
120+
# loss matrix + normalization
121+
M = ot.utils.dist0(n)
122+
M /= M.max()
123+
124+
125+
#%% plot the distributions
126+
127+
pl.figure(1, figsize=(6.4, 3))
128+
for i in range(n_distributions):
129+
pl.plot(x, A[:, i])
130+
pl.title('Distributions')
131+
pl.tight_layout()
132+
133+
134+
#%% barycenter computation
135+
136+
alpha = 0.5 # 0<=alpha<=1
137+
weights = np.array([1 - alpha, alpha])
138+
139+
# l2bary
140+
bary_l2 = A.dot(weights)
141+
142+
# wasserstein
143+
reg = 1e-3
144+
ot.tic()
145+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
146+
ot.toc()
147+
148+
149+
ot.tic()
150+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
151+
ot.toc()
152+
153+
154+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
155+
156+
pl.figure(2)
157+
pl.clf()
158+
pl.subplot(2, 1, 1)
159+
for i in range(n_distributions):
160+
pl.plot(x, A[:, i])
161+
pl.title('Distributions')
162+
163+
pl.subplot(2, 1, 2)
164+
pl.plot(x, bary_l2, 'r', label='l2')
165+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
166+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
167+
pl.legend()
168+
pl.title('Barycenters')
169+
pl.tight_layout()
170+
171+
#%% parameters
172+
173+
a1 = np.zeros(n)
174+
a2 = np.zeros(n)
175+
176+
a1[10] = .25
177+
a1[20] = .5
178+
a1[30] = .25
179+
a2[80] = 1
180+
181+
182+
a1 /= a1.sum()
183+
a2 /= a2.sum()
184+
185+
# creating matrix A containing all distributions
186+
A = np.vstack((a1, a2)).T
187+
n_distributions = A.shape[1]
188+
189+
# loss matrix + normalization
190+
M = ot.utils.dist0(n)
191+
M /= M.max()
192+
193+
194+
#%% plot the distributions
195+
196+
pl.figure(1, figsize=(6.4, 3))
197+
for i in range(n_distributions):
198+
pl.plot(x, A[:, i])
199+
pl.title('Distributions')
200+
pl.tight_layout()
201+
202+
203+
#%% barycenter computation
204+
205+
alpha = 0.5 # 0<=alpha<=1
206+
weights = np.array([1 - alpha, alpha])
207+
208+
# l2bary
209+
bary_l2 = A.dot(weights)
210+
211+
# wasserstein
212+
reg = 1e-3
213+
ot.tic()
214+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
215+
ot.toc()
216+
217+
218+
ot.tic()
219+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
220+
ot.toc()
221+
222+
223+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
224+
225+
pl.figure(2)
226+
pl.clf()
227+
pl.subplot(2, 1, 1)
228+
for i in range(n_distributions):
229+
pl.plot(x, A[:, i])
230+
pl.title('Distributions')
231+
232+
pl.subplot(2, 1, 2)
233+
pl.plot(x, bary_l2, 'r', label='l2')
234+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
235+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
236+
pl.legend()
237+
pl.title('Barycenters')
238+
pl.tight_layout()
239+
240+
241+
##############################################################################
242+
# Final figure
243+
# ------------
244+
#
245+
246+
#%% plot
247+
248+
nbm = len(problems)
249+
nbm2 = (nbm // 2)
250+
251+
252+
pl.figure(2, (20, 6))
253+
pl.clf()
254+
255+
for i in range(nbm):
256+
257+
A = problems[i][0]
258+
bary_l2 = problems[i][1][0]
259+
bary_wass = problems[i][1][1]
260+
bary_wass2 = problems[i][1][2]
261+
262+
pl.subplot(2, nbm, 1 + i)
263+
for j in range(n_distributions):
264+
pl.plot(x, A[:, j])
265+
if i == nbm2:
266+
pl.title('Distributions')
267+
pl.xticks(())
268+
pl.yticks(())
269+
270+
pl.subplot(2, nbm, 1 + i + nbm)
271+
272+
pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
273+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
274+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
275+
if i == nbm - 1:
276+
pl.legend()
277+
if i == nbm2:
278+
pl.title('Barycenters')
279+
280+
pl.xticks(())
281+
pl.yticks(())

0 commit comments

Comments
 (0)