
Inconsistent handling of log in entropic_gromov_barycenters and gromov_barycenters when computing GW with log information #317

@cshjin

Description


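Describe the bug

When log=True, gromov_barycenters and entropic_gromov_barycenters forward log to the inner gromov_wasserstein / entropic_gromov_wasserstein calls, which then return (T, log) tuples instead of coupling matrices. The barycenter functions also call log['err'].append(err) on the boolean log argument.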

To Reproduce

Steps to reproduce the behavior:

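  1. Compute a barycenter with the optional argument log=True.
  2. Inside the barycenter loop, log=True is forwarded to gromov_wasserstein, which then returns an additional log dictionary (as entropic_gromov_wasserstein does), so each element of T is a (T, log) tuple instead of a coupling matrix.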
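A minimal sketch of why step 2 breaks the loop, assuming only POT's return convention (gromov_wasserstein returns T when log=False and a (T, log) pair when log=True). `toy_gw` is a hypothetical stand-in, not POT's solver; it returns the independent coupling p q^T just to show the return types.

```python
import numpy as np


def toy_gw(C1, C2, p, q, log=False):
    """Hypothetical stand-in that follows gromov_wasserstein's log convention."""
    T = np.outer(p, q)  # independent coupling; enough to show the return types
    if log:
        return T, {}  # placeholder log dictionary
    return T


p = np.ones(4) / 4
C = np.zeros((4, 4))
Cs = [np.zeros((4, 4)), np.zeros((4, 4))]

# Forwarding log=True to the inner solver, as the barycenter loop does:
T = [toy_gw(Cs[s], C, p, p, log=True) for s in range(len(Cs))]
print(type(T[0]))  # <class 'tuple'>, not a coupling matrix

# An update such as T[s].T @ Cs[s] @ T[s] then fails on the tuple.
try:
    T[0].T @ Cs[0] @ T[0]
except AttributeError as exc:
    print("AttributeError:", exc)  # 'tuple' object has no attribute 'T'
```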


Code sample

import networkx as nx
import numpy as np
from scipy.sparse.csgraph import shortest_path
from ot.gromov import gromov_barycenters

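# A single input graph: the 4-cycle, described by its shortest-path matrix
Gs = [nx.cycle_graph(4)]
Ds = [shortest_path(nx.adjacency_matrix(g)) for g in Gs]
ps = [np.ones(4) / 4]  # uniform node weights for each input graph
lambdas = np.ones(len(Gs)) / len(Gs)  # barycenter weights
N = 4  # number of nodes in the barycenter
p = np.ones(N) / N

# Fails with log=True in POT 0.7.0 / 0.8.0
C = gromov_barycenters(N, Ds, ps, p, lambdas, "square_loss", log=True)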

Expected behavior

  • The internal log information from the inner gromov_wasserstein calls inside gromov_barycenters is not needed.
  • Only the error computed in each iteration of the while loop needs to be recorded.
  • Return C if log=False, and C, {"err": [...]} if log=True.
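The requested contract can be sketched with a toy outer loop. This is not POT's implementation: `inner_solver` is a hypothetical placeholder for the inner GW solve (it returns the diagonal coupling diag(p1)), while `square_loss_update` follows the standard square-loss barycenter update. The inner solver is always called with log=False, the loop keeps its own error list, and the log dictionary is returned only when log=True. For simplicity the error is recorded at every iteration.

```python
import numpy as np


def inner_solver(C1, C2, p1, p2, log=False):
    """Hypothetical stand-in for an inner GW solve.

    Returns the diagonal coupling diag(p1), which assumes p1 == p2.
    """
    T = np.diag(p1)
    return (T, {}) if log else T


def square_loss_update(p, lambdas, T, Cs):
    # Square-loss barycenter update: sum_s lambda_s T_s^T C_s T_s / (p p^T)
    tmpsum = sum(lam * T_s.T @ C_s @ T_s for lam, T_s, C_s in zip(lambdas, T, Cs))
    return tmpsum / np.outer(p, p)


def toy_barycenter(C0, Cs, ps, p, lambdas, max_iter=100, tol=1e-9, log=False):
    """Toy outer loop showing the requested return contract."""
    C = C0
    errors = []  # outer-loop errors only; inner logs are never requested
    err, cpt = 1.0, 0
    while err > tol and cpt < max_iter:
        C_prev = C
        # The outer log flag is NOT forwarded to the inner solver.
        T = [inner_solver(C_s, C, p_s, p, log=False) for C_s, p_s in zip(Cs, ps)]
        C = square_loss_update(p, lambdas, T, Cs)
        err = float(np.linalg.norm(C - C_prev))
        errors.append(err)
        cpt += 1
    if log:
        return C, {"err": errors}
    return C


# Usage: shortest-path matrix of the 4-cycle, uniform weights.
D = np.array([[0., 1., 2., 1.],
              [1., 0., 1., 2.],
              [2., 1., 0., 1.],
              [1., 2., 1., 0.]])
p = np.ones(4) / 4
C_only = toy_barycenter(np.zeros((4, 4)), [D], [p], p, [1.0])
C, log_ = toy_barycenter(np.zeros((4, 4)), [D], [p], p, [1.0], log=True)
```

With log=False the caller gets a single array; with log=True it gets a pair whose dictionary holds only the outer-loop errors.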

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.7.10
  • How was POT installed (source, pip, conda):
  • Build command you used (if compiling from source):

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

# output:
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
Python 3.7.10 (default, Feb 26 2021, 18:47:35) 
[GCC 7.3.0]
NumPy 1.19.2
SciPy 1.6.2
POT 0.7.0

Additional context

The issue occurs in version 0.7.0. I also checked the code in the latest version (0.8.0), and the problem exists there as well.

The issue occurs in the following lines when log=True:

POT/ot/gromov.py

Lines 1506 to 1507 in cb51064

T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
                        numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]

POT/ot/gromov.py

Lines 1520 to 1521 in cb51064

if log:
    log['err'].append(err)
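Here `log` is the boolean keyword argument (it is passed as log=log to the inner call), so when it is True, `log['err']` subscripts a bool. A minimal sketch of the problem and of the usual pattern of keeping the flag and the dictionary under different names; `log_` is a hypothetical name chosen for illustration, not taken from POT.

```python
log = True  # the boolean flag passed to the barycenter function

# Current pattern: the flag itself is indexed as if it were a dict.
try:
    log["err"].append(0.1)
except TypeError as exc:
    print("TypeError:", exc)  # 'bool' object is not subscriptable

# Keep the flag and the log dictionary under different names.
log_ = {"err": []}
for err in [0.5, 0.1, 0.01]:  # errors from successive convergence checks
    if log:
        log_["err"].append(err)
print(log_)  # {'err': [0.5, 0.1, 0.01]}
```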
