Description
When running the example plot_UOT_barycenter_1D.py, the weights array is passed positionally to barycenter_unbalanced in unbalanced.py, where it is interpreted as the method argument. This raises the following error:
Traceback (most recent call last):
  File "<ipython-input-56-90bb6b274118>", line 1, in <module>
    runfile('.../plot_UOT_barycenter_1D(1).py', wdir='C:/Users/kvstr/Downloads')
  File "...\Anaconda3\envs\my_env\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 827, in runfile
    execfile(filename, namespace)
  File "...\Anaconda3\envs\my_env\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File ".../plot_UOT_barycenter_1D(1).py", line 80, in <module>
    bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
  File "...\Anaconda3\envs\my_env\lib\site-packages\ot\unbalanced.py", line 1003, in barycenter_unbalanced
    if method.lower() == 'sinkhorn':
AttributeError: 'numpy.ndarray' object has no attribute 'lower'
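The mechanism can be illustrated without POT. The sketch below uses a hypothetical stand-in, barycenter_stub, whose positional parameter order (A, M, reg, alpha, method, weights) is assumed from the traceback rather than copied from POT 0.6.0; only the method.lower() check mirrors the line shown at unbalanced.py:1003. A positional fifth argument lands in method, while the keyword form binds it to weights.

```python
import numpy as np


# Hypothetical stand-in for ot.unbalanced.barycenter_unbalanced.
# Assumption: `method` is the 5th positional parameter and `weights` comes after it,
# as suggested by the traceback. This is NOT the POT implementation.
def barycenter_stub(A, M, reg, alpha, method="sinkhorn", weights=None):
    # Same kind of check as in the traceback: fails if an ndarray is bound to `method`.
    if method.lower() == "sinkhorn":
        if weights is None:
            # Default to uniform weights over the input histograms (columns of A).
            weights = np.ones(A.shape[1]) / A.shape[1]
        return method, weights
    raise ValueError("Unknown method %r" % method)


A = np.ones((100, 2)) / 100   # two uniform histograms, one per column
M = np.ones((100, 100))       # placeholder cost matrix
weights = np.array([0.5, 0.5])

# Positional call, as in the example: `weights` is bound to `method`.
try:
    barycenter_stub(A, M, 1e-3, 1.0, weights)
except AttributeError as err:
    print("positional call failed:", err)

# Keyword call, as in the proposed fix: `weights` is bound to `weights`.
method_used, w = barycenter_stub(A, M, 1e-3, 1.0, weights=weights)
print(method_used, w)
```

Passing weights by keyword makes the call independent of where method sits in the signature.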
I have tested this on multiple machines with clean Python installations, mainly Windows 10 64-bit, and it also reproduces on Linux.
Windows-10-10.0.18362-SP0 (also reproduced on Linux)
Python 3.7.6 (default, Jan 8 2020, 20:23:39) [MSC v.1916 64 bit (AMD64)] (also tested in a Python 3.6 Anaconda environment, with the same result)
NumPy 1.18.1
SciPy 1.4.1
POT 0.6.0
The issue should be reproducible by simply downloading the example file from readthedocs and running it.
As far as I can tell, changing the function calls in lines 80 and 114 from
ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
to
ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights)
fixes this issue. I found a separate issue with the unbalanced.barycenter_unbalanced function that still prevents the example from working; I will report it separately.