Skip to content

Commit 71ad677

Browse files
committed
take comments into account$
1 parent 1122565 commit 71ad677

File tree

3 files changed

+11
-16
lines changed

3 files changed

+11
-16
lines changed

ot/gromov/_bregman.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,9 +826,7 @@ def entropic_fused_gromov_wasserstein2(
826826
logv['T'] = T
827827

828828
lin_term = nx.sum(T * M)
829-
gw_term = (logv['fgw_dist'] - (1 - alpha) * lin_term) / alpha
830-
831-
logv['quad_loss'] = gw_term * alpha
829+
logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term)
832830
logv['lin_loss'] = lin_term * (1 - alpha)
833831

834832
if log:

ot/gromov/_gw.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,9 +584,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
584584

585585
# compute separate terms for gradients and log
586586
lin_term = nx.sum(T * M)
587-
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
588-
589-
log_fgw['quad_loss'] = gw_term * alpha
587+
log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term)
590588
log_fgw['lin_loss'] = lin_term * (1 - alpha)
591589

592590
if loss_fun == 'square_loss':

ot/solvers.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
5959
Unbalanced penalization weight :math:`\lambda_u`, by default None
6060
(balanced OT)
6161
unbalanced_type : str, optional
62-
Type of unbalanced penalization unction :math:`U` either "KL", "L2", "TV", by default "KL"
62+
Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
6363
n_threads : int, optional
6464
Number of OMP threads for exact OT solver, by default 1
6565
max_iter : int, optional
66-
Maximum number of iteration, by default None (default values in each solvers)
66+
Maximum number of iterations, by default None (default values in each solvers)
6767
plan_init : array_like, shape (dim_a, dim_b), optional
6868
Initialization of the OT plan for iterative methods, by default None
6969
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
@@ -391,30 +391,29 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
391391
Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"``
392392
symmetric : bool, optional
393393
Use symmetric version of the Gromov-Wasserstein problem, by default None
394-
tests wether the matrices are symmetric or True/False to avoid the test.
394+
tests whether the matrices are symmetric or True/False to avoid the test.
395395
reg : float, optional
396396
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
397397
OT)
398398
reg_type : str, optional
399-
Type of regularization :math:`R`, by default "entropic" (only used when
399+
Type of regularization :math:`R`, by default "entropy" (only used when
400400
``reg!=None``)
401401
alpha : float, optional
402402
Weight the quadratic term (alpha*Gromov) and the linear term
403403
((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for
404404
Gromov problem (when M is not provided). By default ``alpha=None``
405-
corresponds to to
406-
``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused
407-
Gromov-Wasserstein problem (``M!=None``)
405+
corresponds to ``alpha=1`` for Gromov problem (``M==None``) and
406+
``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``)
408407
unbalanced : float, optional
409408
Unbalanced penalization weight :math:`\lambda_u`, by default None
410409
(balanced OT), Not implemented yet
411410
unbalanced_type : str, optional
412-
Type of unbalanced penalization unction :math:`U` either "KL", "semirelaxed",
413-
"partial", by default "KL" , Not implemented yet
411+
Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed",
412+
"partial", by default "KL" but note that it is not implemented yet.
414413
n_threads : int, optional
415414
Number of OMP threads for exact OT solver, by default 1
416415
max_iter : int, optional
417-
Maximum number of iteration, by default None (default values in each
416+
Maximum number of iterations, by default None (default values in each
418417
solvers)
419418
plan_init : array_like, shape (dim_a, dim_b), optional
420419
Initialization of the OT plan for iterative methods, by default None

0 commit comments

Comments
 (0)