Skip to content

Commit 632bc9a

Browse files
author
Hicham Janati
committed
update docstrings + init
1 parent adf9d04 commit 632bc9a

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@
3636
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
3737
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
3838
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
39-
'sinkhorn_unbalanced']
39+
'sinkhorn_unbalanced', "barycenter_unbalanced"]

ot/unbalanced.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
1919
The function solves the following optimization problem:
2020
2121
.. math::
22-
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b)
22+
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
2323
2424
s.t.
2525
\gamma\geq 0
@@ -43,9 +43,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
4343
M : np.ndarray (ns, nt)
4444
loss matrix
4545
reg : float
46-
Regularization term > 0
46+
Entropy regularization term > 0
4747
alpha : float
48-
Regulatization term > 0
48+
Marginal relaxation term > 0
4949
method : str
5050
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
5151
'sinkhorn_epsilon_scaling', see those function for specific parameters
@@ -128,7 +128,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
128128
The function solves the following optimization problem:
129129
130130
.. math::
131-
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b)
131+
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
132132
133133
s.t.
134134
\gamma\geq 0
@@ -152,9 +152,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
152152
M : np.ndarray (ns,nt)
153153
loss matrix
154154
reg : float
155-
Regularization term > 0
156-
alpha: float
157-
Regularization term > 0
155+
Entropy regularization term > 0
156+
alpha : float
157+
Marginal relaxation term > 0
158158
method : str
159159
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
160160
'sinkhorn_epsilon_scaling', see those function for specific parameters
@@ -239,7 +239,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
239239
The function solves the following optimization problem:
240240
241241
.. math::
242-
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b)
242+
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
243243
244244
s.t.
245245
\gamma\geq 0
@@ -263,9 +263,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
263263
M : np.ndarray (ns,nt)
264264
loss matrix
265265
reg : float
266-
Regularization term > 0
267-
alpha: float
268-
Regularization term > 0
266+
Entropy regularization term > 0
267+
alpha : float
268+
Marginal relaxation term > 0
269269
numItermax : int, optional
270270
Max number of iterations
271271
stopThr : float, optional
@@ -410,7 +410,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
410410
411411
where :
412412
413-
- :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
413+
- :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
414414
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
415415
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
416416
- alpha is the marginal relaxation hyperparameter
@@ -423,9 +423,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
423423
M : np.ndarray (d,d)
424424
loss matrix for OT
425425
reg : float
426-
Regularization term > 0
426+
Entropy regularization term > 0
427427
alpha : float
428-
Regularization term > 0
428+
Marginal relaxation term > 0
429429
weights : np.ndarray (n,)
430430
Weights of each histogram a_i on the simplex (barycentric coodinates)
431431
numItermax : int, optional

0 commit comments

Comments
 (0)