@@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
435
435
##############################################################################
436
436
437
437
438
- def batch_grad_dual (M , reg , a , b , alpha , beta , batch_size , batch_alpha ,
438
+ def batch_grad_dual (a , b , M , reg , alpha , beta , batch_size , batch_alpha ,
439
439
batch_beta ):
440
440
'''
441
441
Computes the partial gradient of F_\W_varepsilon
@@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
528
528
return grad_alpha , grad_beta
529
529
530
530
531
- def sgd_entropic_regularization (M , reg , a , b , batch_size , numItermax , lr ):
531
+ def sgd_entropic_regularization (a , b , M , reg , batch_size , numItermax , lr ):
532
532
'''
533
533
Compute the sgd algorithm to solve the regularized discrete measures
534
534
optimal transport dual problem
@@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
612
612
k = np .sqrt (cur_iter / 100 + 1 )
613
613
batch_alpha = np .random .choice (n_source , batch_size , replace = False )
614
614
batch_beta = np .random .choice (n_target , batch_size , replace = False )
615
- update_alpha , update_beta = batch_grad_dual (M , reg , a , b , cur_alpha ,
615
+ update_alpha , update_beta = batch_grad_dual (a , b , M , reg , cur_alpha ,
616
616
cur_beta , batch_size ,
617
617
batch_alpha , batch_beta )
618
618
cur_alpha += (lr / k ) * update_alpha
@@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
698
698
arXiv preprint arxiv:1711.02283.
699
699
'''
700
700
701
- opt_alpha , opt_beta = sgd_entropic_regularization (M , reg , a , b , batch_size ,
701
+ opt_alpha , opt_beta = sgd_entropic_regularization (a , b , M , reg , batch_size ,
702
702
numItermax , lr )
703
703
pi = (np .exp ((opt_alpha [:, None ] + opt_beta [None , :] - M [:, :]) / reg ) *
704
704
a [:, None ] * b [None , :])
0 commit comments