Skip to content

Commit 37e3b29

Browse files
author
Kilian Fatras
committed
fixed argument functions
1 parent cd193f7 commit 37e3b29

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

ot/stochastic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
435435
##############################################################################
436436

437437

438-
def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
438+
def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
439439
batch_beta):
440440
'''
441441
Computes the partial gradient of F_\W_varepsilon
@@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
528528
return grad_alpha, grad_beta
529529

530530

531-
def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
531+
def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
532532
'''
533533
Compute the sgd algorithm to solve the regularized discrete measures
534534
optimal transport dual problem
@@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
612612
k = np.sqrt(cur_iter / 100 + 1)
613613
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
614614
batch_beta = np.random.choice(n_target, batch_size, replace=False)
615-
update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha,
615+
update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
616616
cur_beta, batch_size,
617617
batch_alpha, batch_beta)
618618
cur_alpha += (lr / k) * update_alpha
@@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
698698
arXiv preprint arxiv:1711.02283.
699699
'''
700700

701-
opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size,
701+
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
702702
numItermax, lr)
703703
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
704704
a[:, None] * b[None, :])

test/test_stochastic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ def test_dual_sgd_sinkhorn():
196196
reg = 1
197197
batch_size = 30
198198

199-
a = ot.datasets.get_1D_gauss(n, m=15, s=5) # m= mean, s= std
200-
b = ot.datasets.get_1D_gauss(n, m=15, s=5)
199+
a = ot.datasets.get_1D_gauss(n, 15, 5) # m= mean, s= std
200+
b = ot.datasets.get_1D_gauss(n, 15, 5)
201201
X_source = np.arange(n, dtype=np.float64)
202202
Y_target = np.arange(n, dtype=np.float64)
203203
M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))

0 commit comments

Comments
 (0)