Skip to content

Commit effe32f

Browse files
committed
update documenation
1 parent e1ca42f commit effe32f

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

ot/gromov.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
338338
- :math:`\mathbf{q}`: distribution in the target space
339339
- `L`: loss function to account for the misfit between the similarity matrices
340340
341+
.. note:: This function is backend-compatible and will work on arrays
342+
from all compatible backends. But the algorithm uses the C++ CPU backend
343+
which can lead to copy overhead on GPU arrays.
344+
341345
Parameters
342346
----------
343347
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
436440
Note that when using backends, this loss function is differentiable wrt the
437441
marices and weights for quadratic loss using the gradients from [38]_.
438442
443+
.. note:: This function is backend-compatible and will work on arrays
444+
from all compatible backends. But the algorithm uses the C++ CPU backend
445+
which can lead to copy overhead on GPU arrays.
446+
439447
Parameters
440448
----------
441449
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
545553
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
546554
- `L` is a loss function to account for the misfit between the similarity matrices
547555
556+
.. note:: This function is backend-compatible and will work on arrays
557+
from all compatible backends. But the algorithm uses the C++ CPU backend
558+
which can lead to copy overhead on GPU arrays.
559+
548560
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
549561
550562
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
645657
The algorithm used for solving the problem is conditional gradient as
646658
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
647659
660+
.. note:: This function is backend-compatible and will work on arrays
661+
from all compatible backends. But the algorithm uses the C++ CPU backend
662+
which can lead to copy overhead on GPU arrays.
663+
648664
Note that when using backends, this loss function is differentiable wrt the
649665
marices and weights for quadratic loss using the gradients from [38]_.
650666

ot/lp/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
222222
format
223223
224224
.. note:: This function is backend-compatible and will work on arrays
225-
from all compatible backends.
225+
from all compatible backends. But the algorithm uses the C++ CPU backend
226+
which can lead to copy overhead on GPU arrays.
226227
227228
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
228229
@@ -360,7 +361,8 @@ def emd2(a, b, M, processes=1,
360361
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
361362
362363
.. note:: This function is backend-compatible and will work on arrays
363-
from all compatible backends.
364+
from all compatible backends. But the algorithm uses the C++ CPU backend
365+
which can lead to copy overhead on GPU arrays.
364366
365367
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
366368

ot/weak.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
3333
3434
3535
.. note:: This function is backend-compatible and will work on arrays
36-
from all compatible backends.
36+
from all compatible backends. But the algorithm uses the C++ CPU backend
37+
which can lead to copy overhead on GPU arrays.
3738
3839
Uses the conditional gradient algorithm to solve the problem proposed
3940
in :ref:`[39] <references-weak>`.

0 commit comments

Comments
 (0)