@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
338
338
- :math:`\mathbf{q}`: distribution in the target space
339
339
- `L`: loss function to account for the misfit between the similarity matrices
340
340
341
+ .. note:: This function is backend-compatible and will work on arrays
342
+ from all compatible backends. But the algorithm uses the C++ CPU backend
343
+ which can lead to copy overhead on GPU arrays.
344
+
341
345
Parameters
342
346
----------
343
347
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
436
440
Note that when using backends, this loss function is differentiable wrt the
437
441
marices and weights for quadratic loss using the gradients from [38]_.
438
442
443
+ .. note:: This function is backend-compatible and will work on arrays
444
+ from all compatible backends. But the algorithm uses the C++ CPU backend
445
+ which can lead to copy overhead on GPU arrays.
446
+
439
447
Parameters
440
448
----------
441
449
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
545
553
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
546
554
- `L` is a loss function to account for the misfit between the similarity matrices
547
555
556
+ .. note:: This function is backend-compatible and will work on arrays
557
+ from all compatible backends. But the algorithm uses the C++ CPU backend
558
+ which can lead to copy overhead on GPU arrays.
559
+
548
560
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
549
561
550
562
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
645
657
The algorithm used for solving the problem is conditional gradient as
646
658
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
647
659
660
+ .. note:: This function is backend-compatible and will work on arrays
661
+ from all compatible backends. But the algorithm uses the C++ CPU backend
662
+ which can lead to copy overhead on GPU arrays.
663
+
648
664
Note that when using backends, this loss function is differentiable wrt the
649
665
marices and weights for quadratic loss using the gradients from [38]_.
650
666
0 commit comments