Closed
Labels
Description
Describe the bug
To Reproduce
Steps to reproduce the behavior:
- Calculate the barycenter with the optional argument log=True. With log=True, gromov_wasserstein returns an additional log dictionary, similar to entropic_gromov_wasserstein.
Screenshots
Code sample
import networkx as nx
import numpy as np
from scipy.sparse.csgraph import shortest_path
from ot.gromov import gromov_barycenters
Gs = [nx.cycle_graph(4)]
Ds = [shortest_path(nx.adjacency_matrix(g)) for g in Gs]
ps = [np.ones(4) / 4]
lambdas = np.ones(len(Gs)) / len(Gs)
N = 4
p = np.ones(N) / N
C = gromov_barycenters(N, Ds, ps, p, lambdas, "square_loss", log=True)
Expected behavior
- The internal log information from gromov_barycenter is not necessary.
- Only the error at each iteration of the while loop needs to be recorded.
- Return C if log is False; return C, {"err": [...]} if log=True.
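The expected behavior above can be sketched without POT. This is a minimal illustration, not the library's implementation: inner_solver and outer_loop are hypothetical stand-ins for gromov_wasserstein and gromov_barycenters, and the update rule is a toy fixed-point iteration. The point is only that the boolean flag and the log dictionary live under separate names, and that the inner solver is always called with log=False.

```python
def inner_solver(x, log=False):
    """Hypothetical stand-in for an inner solver with an optional log output."""
    result = 0.5 * x + 1.0  # toy update whose fixed point is x = 2
    if log:
        return result, {"detail": "inner log"}
    return result


def outer_loop(x0, max_iter=100, tol=1e-9, log=False):
    """Hypothetical stand-in for the outer barycenter loop."""
    # Keep the boolean flag (log) and the dictionary (log_dict) separate.
    log_dict = {"err": []} if log else None
    x, err, it = x0, 1.0, 0
    while err > tol and it < max_iter:
        x_prev = x
        # The inner log is not needed, so always request a bare result.
        x = inner_solver(x_prev, log=False)
        err = abs(x - x_prev)
        if log:
            log_dict["err"].append(err)
        it += 1
    if log:
        return x, log_dict
    return x


value = outer_loop(0.0)
value_logged, info = outer_loop(0.0, log=True)
```

With log=False the call returns only the result; with log=True it returns the result together with a dictionary whose "err" entry lists the per-iteration errors.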
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux):
- Python version:
- How was POT installed (source, pip, conda):
- Build command you used (if compiling from source):
- Only for GPU related bugs:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
# output:
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
Python 3.7.10 (default, Feb 26 2021, 18:47:35)
[GCC 7.3.0]
NumPy 1.19.2
SciPy 1.6.2
POT 0.7.0
Additional context
The issue occurs in version 0.7.0. I also checked the code in the latest version (0.8.0), and the problem exists there as well.
The issue occurs in the following lines when log=True:
Lines 1506 to 1507 in cb51064
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
                        numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
Lines 1520 to 1521 in cb51064
if log:
    log['err'].append(err)
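The two snippets above suggest two separate hazards when log=True. The sketch below reproduces both in plain Python, without POT. It assumes, as the snippets indicate, that log is the caller's boolean flag; inner_solver is a hypothetical stand-in, and this is not claimed to be the exact traceback POT produces.

```python
def inner_solver(log=False):
    # Hypothetical stand-in for a solver with an optional log output.
    plan = [[0.25, 0.0], [0.0, 0.25]]
    return (plan, {"detail": "inner log"}) if log else plan


log = True  # the caller's boolean flag

# Hazard 1: forwarding log=True makes each entry a (plan, log_dict) tuple,
# so code that expects a list of bare plans receives tuples instead.
T = [inner_solver(log=log) for _ in range(2)]
first_is_tuple = isinstance(T[0], tuple)

# Hazard 2: indexing the boolean flag as if it were a dictionary fails.
try:
    log["err"].append(0.0)
    flag_error = None
except TypeError as exc:
    flag_error = type(exc).__name__
```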