152 changes: 152 additions & 0 deletions ot/batch/_quadratic.py
@@ -152,6 +152,90 @@ def h2(C2):
return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric)


def div_to_product_batch(
T, a, b, T1=None, T2=None, divergence="kl", mass=True, nx=None
):
r"""Fast computation of the Bregman divergence between a batch of arbitrary measures and a product measures.
Only support for Kullback-Leibler and half-squared L2 divergences.

- For half-squared L2 divergence:

.. math::
\frac{1}{2} || \pi - a \otimes b ||^2
= \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big]

- For Kullback-Leibler divergence:

.. math::
KL(\pi | a \otimes b)
= \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle
- \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b)

where :

- :math:`\pi` is the (`dim_a`, `dim_b`) transport plan
- :math:`\pi_1` and :math:`\pi_2` are the marginal distributions
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- :math:`m` denotes the mass of the measure

Parameters
----------
    T : array-like (B, n, m)
        Transport plan for each problem in the batch
    a : array-like (B, n)
        Unnormalized histogram of dimension `n` for each problem in the batch
    b : array-like (B, m)
        Unnormalized histogram of dimension `m` for each problem in the batch
T1 : array-like (B, n), optional (default = None)
Marginal distribution with respect to the first dimension of the transport plan for each problem in the batch
Only used in case of Kullback-Leibler divergence.
T2 : array-like (B, m), optional (default = None)
Marginal distribution with respect to the second dimension of the transport plan for each problem in the batch
Only used in case of Kullback-Leibler divergence.
divergence : string, default = "kl"
Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
    mass : bool, optional. Default is True.
        Only used in case of Kullback-Leibler divergence.
        If False, calculate the relative entropy.
        If True, calculate the Kullback-Leibler divergence.
    nx : backend, optional
        If left to its default value None, the backend is inferred from the input arrays.

    Returns
    -------
    res : array-like (B,)
        Bregman divergence between the measure and the product measure for each problem in the batch.
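
    Examples
    --------
    A minimal check with uniform plans (assuming the function is exported from :mod:`ot.batch`):

    >>> import numpy as np
    >>> from ot.batch import div_to_product_batch
    >>> T = np.ones((2, 3, 4)) / 12  # batch of 2 uniform plans
    >>> a, b = T.sum(axis=2), T.sum(axis=1)  # marginals of each plan
    >>> # A uniform plan equals the product of its marginals, so both
    >>> # divergences to the product measure vanish
    >>> bool(np.allclose(div_to_product_batch(T, a, b, divergence="kl"), 0.0))
    True
    >>> bool(np.allclose(div_to_product_batch(T, a, b, divergence="l2"), 0.0))
    True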
"""

    if nx is None:
        nx = get_backend(T, a, b, T1, T2)

if divergence == "kl":
if T1 is None:
T1 = nx.sum(T, 2)
if T2 is None:
T2 = nx.sum(T, 1)

if divergence == "kl":
res = (
nx.sum((T * nx.log(T + 1.0 * (T == 0))), (1, 2))
- nx.sum(T1 * nx.log(a), 1)
- nx.sum(T2 * nx.log(b), 1)
)
if mass:
res = res - nx.sum(T1, 1) + nx.sum(a, 1) * nx.sum(b, 1)

elif divergence == "l2":
res = (
nx.sum(T**2, (1, 2))
+ nx.sum(a**2, 1) * nx.sum(b**2, 1)
- 2 * nx.sum((a * (T @ b[:, :, None]).squeeze(-1)), 1)
) / 2

return res


def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None):
r"""
    Computes the Gromov-Wasserstein cost given a cost tensor and a transport plan. Batched version.
@@ -266,6 +350,74 @@ def loss_quadratic_samples_batch(
)


def loss_fugw_batch(
L, M, T, alpha=0.5, reg_marginals=1, symmetric=True, divergence="kl", nx=None
):
r"""
    Computes the Fused Unbalanced Gromov-Wasserstein (FUGW) cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term), and a transport plan. Batched version.
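
    The returned cost combines the three terms computed by the helpers in this module (a sketch; see :func:`loss_linear_batch`, :func:`loss_quadratic_batch` and :func:`div_to_product_batch` for the exact definitions):

    .. math::
        (1 - \alpha) \, W(\mathbf{M}, \mathbf{T})
        + \alpha \, \mathrm{GW}(\mathbf{L}, \mathbf{T})
        + \rho \, \mathrm{D}(\mathbf{T} | \mathbf{T}_1 \otimes \mathbf{T}_2)

    where :math:`W` is the linear (Wasserstein) term, :math:`\mathrm{GW}` the quadratic (Gromov) term, :math:`\rho` the marginal relaxation weight `reg_marginals`, and :math:`\mathbf{T}_1, \mathbf{T}_2` the marginals of the plan :math:`\mathbf{T}`.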

Parameters
----------
L : dict
Cost tensor as returned by `tensor_batch`.
M : array-like, shape (B, n, m)
Cost matrix between features across domains.
T : array-like, shape (B, n, m)
Transport plan.
    alpha : float or array-like (B,), optional
        Weight between the quadratic term (alpha * Gromov) and the linear term
        ((1 - alpha) * Wasserstein) in the Fused Gromov-Wasserstein problem. If alpha is
        a scalar, it is used for all problems in the batch.
    reg_marginals : float or array-like (B,), optional
        Marginal relaxation term. If reg_marginals is
        a scalar, it is used for all problems in the batch.
symmetric : bool, optional
Whether to use symmetric version. Default is True.
divergence : string, default = "kl"
Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
    nx : module, optional
        Backend to use. Default is None.

    Returns
    -------
    res : array-like (B,)
        FUGW cost for each problem in the batch.

Examples
--------
>>> import numpy as np
>>> from ot.batch import tensor_batch, loss_quadratic_batch
>>> # Create batch of cost matrices
>>> C1 = np.random.rand(3, 5, 5) # 3 problems, 5x5 source matrices
>>> C2 = np.random.rand(3, 4, 4) # 3 problems, 4x4 target matrices
>>> a = np.ones((3, 5)) / 5 # Uniform source distributions
>>> b = np.ones((3, 4)) / 4 # Uniform target distributions
>>> L = tensor_batch(a, b, C1, C2, loss='sqeuclidean')
>>> # Use the uniform transport plan for testing
>>> T = np.ones((3, 5, 4)) / (5 * 4)
>>> loss = loss_quadratic_batch(L, T, recompute_const=True)
>>> loss.shape
(3,)
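    >>> # The cost decomposes into the three terms computed by the helper
    >>> # functions (assuming they are exported from ot.batch):
    >>> from ot.batch import loss_quadratic_batch, loss_linear_batch, div_to_product_batch
    >>> Q = loss_quadratic_batch(L, T, recompute_const=True)
    >>> lin = loss_linear_batch(M, T)
    >>> div = div_to_product_batch(T, T.sum(axis=2), T.sum(axis=1)) # KL by default
    >>> bool(np.allclose(loss_fugw_batch(L, M, T), 0.5 * lin + 0.5 * Q + div))
    True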

See Also
--------
    ot.batch.tensor_batch : For computing the cost tensor L.
ot.batch.solve_gromov_batch : For finding the optimal transport plan T.
"""
if nx is None:
nx = get_backend(T)

    # Quadratic (Gromov) term
    Q = loss_quadratic_batch(L, T, recompute_const=True, symmetric=symmetric, nx=nx)

    # Linear (Wasserstein) term; use a new name to avoid shadowing the cost tensor L
    lin = loss_linear_batch(M, T, nx=nx)

    # Marginal relaxation term: divergence between the plan and the
    # product of its marginals, reusing the precomputed marginals
    T1 = nx.sum(T, axis=2)
    T2 = nx.sum(T, axis=1)
    unbalanced = div_to_product_batch(
        T,
        a=T1,
        b=T2,
        T1=T1,
        T2=T2,
        divergence=divergence,
        mass=True,
        nx=nx,
    )

    return (1 - alpha) * lin + alpha * Q + reg_marginals * unbalanced


def solve_gromov_batch(
C1,
C2,