Description
While using ot.da.SinkhornL1l2Transport for a domain adaptation problem, I ran into the following error:
Datasets used:
**To Reproduce**
If the input files are downloaded to C:\, the code is:
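import numpy as np
import ot

# Input files (Xs.txt, Xt.txt, ys.txt) saved directly in C:\
Xs = np.loadtxt("C:/Xs.txt").reshape(604, 5)  # 604 source samples, 5 features
Xt = np.loadtxt("C:/Xt.txt").reshape(601, 5)  # 601 target samples, 5 features
ys = np.loadtxt("C:/ys.txt")                  # one class label per source sample

# Entropic (reg_e) + class-based group-lasso (reg_cl) regularized OT
ot_base = ot.da.SinkhornL1l2Transport(reg_e=10000, reg_cl=100, max_iter=100, verbose=True)
ot_base.fit(Xs=Xs, ys=ys, Xt=Xt)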
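Before calling fit, it may help to rule out malformed inputs, such as NaN/inf values in the text files or a label vector whose length does not match Xs. The helper below, check_da_inputs, is a hypothetical sketch (not part of POT), exercised here on synthetic data with the same shapes as in this report.

```python
import numpy as np


def check_da_inputs(Xs, Xt, ys):
    """Hypothetical pre-fit sanity check; returns a list of problems found."""
    problems = []
    if Xs.ndim != 2 or Xt.ndim != 2:
        problems.append("Xs and Xt must be 2-D (n_samples, n_features)")
    elif Xs.shape[1] != Xt.shape[1]:
        problems.append("Xs and Xt must have the same number of features")
    if ys.shape != (Xs.shape[0],):
        problems.append("ys must hold exactly one label per row of Xs")
    # Non-finite values would propagate into the cost matrix and the loss
    for name, arr in (("Xs", Xs), ("Xt", Xt), ("ys", ys)):
        if not np.all(np.isfinite(arr)):
            problems.append(f"{name} contains NaN or inf")
    return problems


# Synthetic stand-ins with the shapes used above (not the attached data)
rng = np.random.default_rng(0)
Xs_demo = rng.normal(size=(604, 5))
Xt_demo = rng.normal(size=(601, 5))
ys_demo = rng.integers(0, 3, size=604).astype(float)
print(check_da_inputs(Xs_demo, Xt_demo, ys_demo))  # → []
```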
Result and Error
It. |Loss |Relative loss|Absolute loss
------------------------------------------------
0|5.193677e+06|0.000000e+00|0.000000e+00
1|3.150847e+05|1.548343e+01|4.878593e+06
2|2.668420e+05|1.807914e-01|4.824274e+04
3|2.663638e+05|1.795333e-03|4.782117e+02
4|2.663590e+05|1.786689e-05|4.759007e+00
5|2.663590e+05|1.312580e-07|3.496174e-02
6|2.663588e+05|7.339658e-07|1.954982e-01
7|2.663106e+05|1.808094e-04|4.815146e+01
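The last two columns of the verbose log appear to be consistent with absolute loss = |loss[k-1] - loss[k]| and relative loss = absolute loss / |loss[k]|. These definitions are inferred from the logged numbers, not taken from POT's source. Recomputing them shows the loss had almost stopped changing by iterations 4 and 5 and then started moving again at iterations 6 and 7, just before the failure. Rows 5 and 6 cannot be reproduced accurately because the log keeps only seven significant digits.

```python
import numpy as np

# Loss column copied from the verbose log (iterations 0-7)
losses = np.array([5.193677e+06, 3.150847e+05, 2.668420e+05, 2.663638e+05,
                   2.663590e+05, 2.663590e+05, 2.663588e+05, 2.663106e+05])

# Assumed definitions, inferred from the log rather than from POT's code
abs_loss = np.abs(np.diff(losses))       # |loss[k-1] - loss[k]|
rel_loss = abs_loss / np.abs(losses[1:])  # relative to the new loss

for k, (r, a) in enumerate(zip(rel_loss, abs_loss), start=1):
    print(f"{k}|{losses[k]:e}|{r:e}|{a:e}")
```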
C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\optim.py:357: RuntimeWarning: invalid value encountered in log
  return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2021.1.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2021.1.2\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/enayat.aria/PycharmProjects/pythonProject/OT_for_DA.py", line 258, in <module>
    ot_base.fit(Xs=Xs, ys=ys, Xt=Xt)
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\da.py", line 1950, in fit
    returned_ = sinkhorn_l1l2_gl(
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\da.py", line 239, in sinkhorn_l1l2_gl
    return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\optim.py", line 388, in gcg
    G = G + alpha * deltaG
TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'
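The final TypeError says that the left operand of alpha * deltaG is None, i.e. the step size reaching line 388 of optim.py is None rather than a number. The snippet below uses hypothetical stand-ins for alpha and deltaG (not POT objects) to show that multiplying None by a float NumPy array raises a TypeError of this kind: NumPy falls back to element-wise object multiplication, so the message names NoneType and float.

```python
import numpy as np

# Hypothetical stand-ins for the variables on optim.py line 388
alpha = None                         # a step size that was never set to a number
deltaG = np.array([[0.1, -0.2],
                   [0.3, 0.0]])      # any float search direction

message = ""
try:
    G_update = alpha * deltaG
except TypeError as err:
    message = str(err)

print(message)
```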
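Inspecting the intermediate values, I found that the matrix G computed in ot/optim.py contains negative entries at the last iteration, introduced by the preceding update. Taking np.log of those entries triggers the RuntimeWarning above, and by the time gcg reaches G = G + alpha * deltaG, alpha is None, which raises the TypeError.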
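The loss expression printed in the RuntimeWarning can be reproduced in isolation to see why a single negative entry in G is enough to break it: np.log of a negative number is NaN, and NaN propagates through the sum. In this sketch the cost matrix, the G matrices, and the zero_regularizer placeholder for f are hypothetical toy values, and reg1/reg2 are assumed to correspond to reg_e=10000 and reg_cl=100 from the report.

```python
import warnings

import numpy as np


def gcg_style_loss(G, M, reg1, reg2, f):
    """Same expression as the line quoted in the warning (optim.py:357)."""
    return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)


def zero_regularizer(G):
    """Hypothetical placeholder for the class-based regularization term f."""
    return 0.0


M = np.array([[0.0, 1.0], [1.0, 0.0]])            # toy cost matrix
G_ok = np.array([[0.4, 0.1], [0.1, 0.4]])          # all entries positive
G_bad = np.array([[0.4, 0.1], [0.1, -1e-12]])      # one tiny negative entry

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_ok = gcg_style_loss(G_ok, M, reg1=10000, reg2=100, f=zero_regularizer)
    loss_bad = gcg_style_loss(G_bad, M, reg1=10000, reg2=100, f=zero_regularizer)

print(loss_ok, loss_bad)                # finite value, then nan
print([str(w.message) for w in caught])  # includes "invalid value encountered in log"
```

Even a negative entry of magnitude 1e-12 turns the whole loss into NaN, so any step that pushes an entry of G slightly below zero is enough to derail the subsequent iterations.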
Please let me know how to solve the problem, or if I should provide more information.
Best,
Environment:
- OS (e.g. MacOS, Windows, Linux): Windows 10
- Python version: 3.9
- How was POT installed (source, pip, conda): pip