Skip to content

Commit fd7a9e6

Browse files
author
Kilian Fatras
committed
add tests and examples
1 parent db9c2bd commit fd7a9e6

File tree

3 files changed

+343
-80
lines changed

3 files changed

+343
-80
lines changed

examples/plot_stochastic.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
==========================
3+
Stochastic examples
4+
==========================
5+
6+
This example is designed to show how to use the stochatic optimization
7+
algorithms for descrete and semicontinous measures from the POT library.
8+
9+
"""
10+
11+
# Author: Kilian Fatras <kilian.fatras@gmail.com>
12+
#
13+
# License: MIT License
14+
15+
import matplotlib.pylab as pl
16+
import numpy as np
17+
import ot
18+
19+
20+
#############################################################################
21+
# COMPUTE TRANSPORTATION MATRIX
22+
#############################################################################
23+
24+
#############################################################################
25+
# DISCRETE CASE
26+
# Sample two discrete measures for the discrete case
27+
# ---------------------------------------------
28+
#
29+
# Define 2 discrete measures a and b, the points where are defined the source
30+
# and the target measures and finally the cost matrix c.
31+
32+
n_source = 7
33+
n_target = 4
34+
eps = 1
35+
nb_iter = 10000
36+
lr = 0.1
37+
38+
a = (1./n_source) * np.ones(n_source)
39+
b = (1./n_target) * np.ones(n_target)
40+
X_source = np.arange(n_source)
41+
Y_target = np.arange(0, 2 * n_target, 2)
42+
M = np.abs(X_source[:, None] - Y_target[None, :])
43+
44+
#############################################################################
45+
#
46+
# Call the "SAG" method to find the transportation matrix in the discrete case
47+
# ---------------------------------------------
48+
#
49+
# Define the method "SAG", call ot.transportation_matrix_entropic and plot the
50+
# results.
51+
52+
method = "SAG"
53+
sag_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M,
54+
n_source, n_target,
55+
nb_iter, lr)
56+
print(sag_pi)
57+
58+
#############################################################################
59+
# SEMICONTINOUS CASE
60+
# Sample one general measure a, one discrete measures b for the semicontinous
61+
# case
62+
# ---------------------------------------------
63+
#
64+
# Define one general measure a, one discrete measures b, the points where
65+
# are defined the source and the target measures and finally the cost matrix c.
66+
67+
n_source = 7
68+
n_target = 4
69+
eps = 1
70+
nb_iter = 10000
71+
lr = 0.1
72+
73+
a = (1./n_source) * np.ones(n_source)
74+
b = (1./n_target) * np.ones(n_target)
75+
X_source = np.arange(n_source)
76+
Y_target = np.arange(0, 2 * n_target, 2)
77+
M = np.abs(X_source[:, None] - Y_target[None, :])
78+
79+
#############################################################################
80+
#
81+
# Call the "ASGD" method to find the transportation matrix in the semicontinous
82+
# case
83+
# ---------------------------------------------
84+
#
85+
# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the
86+
# results.
87+
88+
method = "ASGD"
89+
asgd_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M,
90+
n_source, n_target,
91+
nb_iter, lr)
92+
print(asgd_pi)
93+
94+
#############################################################################
95+
#
96+
# Compare the results with the Sinkhorn algorithm
97+
# ---------------------------------------------
98+
#
99+
# Call the Sinkhorn algorithm from POT
100+
101+
sinkhorn_pi = ot.sinkhorn(a, b, M, 1)
102+
print(sinkhorn_pi)
103+
104+
105+
##############################################################################
106+
# PLOT TRANSPORTATION MATRIX
107+
##############################################################################
108+
109+
##############################################################################
110+
# Plot SAG results
111+
# ----------------
112+
113+
pl.figure(4, figsize=(5, 5))
114+
ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG')
115+
pl.show()
116+
117+
118+
##############################################################################
119+
# Plot ASGD results
120+
# -----------------
121+
122+
pl.figure(4, figsize=(5, 5))
123+
ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD')
124+
pl.show()
125+
126+
127+
##############################################################################
128+
# Plot Sinkhorn results
129+
# ---------------------
130+
131+
pl.figure(4, figsize=(5, 5))
132+
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
133+
pl.show()

0 commit comments

Comments
 (0)