Skip to content

Commit 5bfa317

Browse files
author
Kilian Fatras
committed
adapted code to POT
1 parent fd7a9e6 commit 5bfa317

File tree

4 files changed

+219
-162
lines changed

4 files changed

+219
-162
lines changed

examples/plot_stochastic.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@
3131

3232
n_source = 7
3333
n_target = 4
34-
eps = 1
35-
nb_iter = 10000
34+
reg = 1
35+
numItermax = 10000
3636
lr = 0.1
3737

38-
a = (1./n_source) * np.ones(n_source)
39-
b = (1./n_target) * np.ones(n_target)
40-
X_source = np.arange(n_source)
41-
Y_target = np.arange(0, 2 * n_target, 2)
42-
M = np.abs(X_source[:, None] - Y_target[None, :])
38+
a = ot.utils.unif(n_source)
39+
b = ot.utils.unif(n_target)
40+
41+
rng = np.random.RandomState(0)
42+
X_source = rng.randn(n_source, 2)
43+
Y_target = rng.randn(n_target, 2)
44+
M = ot.dist(X_source, Y_target)
4345

4446
#############################################################################
4547
#
@@ -50,9 +52,8 @@
5052
# results.
5153

5254
method = "SAG"
53-
sag_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M,
54-
n_source, n_target,
55-
nb_iter, lr)
55+
sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
56+
numItermax, lr)
5657
print(sag_pi)
5758

5859
#############################################################################
@@ -66,15 +67,17 @@
6667

6768
n_source = 7
6869
n_target = 4
69-
eps = 1
70-
nb_iter = 10000
71-
lr = 0.1
70+
reg = 1
71+
numItermax = 300000
72+
lr = 1
73+
74+
a = ot.utils.unif(n_source)
75+
b = ot.utils.unif(n_target)
7276

73-
a = (1./n_source) * np.ones(n_source)
74-
b = (1./n_target) * np.ones(n_target)
75-
X_source = np.arange(n_source)
76-
Y_target = np.arange(0, 2 * n_target, 2)
77-
M = np.abs(X_source[:, None] - Y_target[None, :])
77+
rng = np.random.RandomState(0)
78+
X_source = rng.randn(n_source, 2)
79+
Y_target = rng.randn(n_target, 2)
80+
M = ot.dist(X_source, Y_target)
7881

7982
#############################################################################
8083
#
@@ -86,9 +89,8 @@
8689
# results.
8790

8891
method = "ASGD"
89-
asgd_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M,
90-
n_source, n_target,
91-
nb_iter, lr)
92+
asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
93+
numItermax, lr)
9294
print(asgd_pi)
9395

9496
#############################################################################

ot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from . import datasets
1919
from . import da
2020
from . import gromov
21+
from . import stochastic
2122

2223
# OT functions
2324
from .lp import emd, emd2

0 commit comments

Comments
 (0)