Describe the bug
Calling ot.dist(a, b, metric="cityblock") fails: in https://github.com/PythonOT/POT/blob/master/ot/utils.py#L224 the keyword argument p is passed on to scipy's cdist, but the underlying scipy implementation (https://github.com/scipy/scipy/blob/v1.7.1/scipy/spatial/distance.py#L1798) does not appear to accept it for this metric. I believe the p argument was only recently added to ot.dist.
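The failure mode can be modelled in plain Python. The sketch below is a simplified, hypothetical stand-in for scipy's dispatch, not scipy's actual code: a generic cdist-like function looks up a metric-specific kernel and forwards every keyword argument to it, so a kernel whose signature has no p parameter (like cdist_cityblock in the traceback below) rejects the call. The names toy_cdist, _cityblock and _minkowski are illustrative only.

```python
# Simplified, hypothetical model of a cdist-style wrapper dispatching to
# per-metric kernels. This is NOT scipy's real implementation.


def _cityblock(x, y, w=None, out=None):
    # Kernel with a fixed signature and no `p` parameter (mirrors the
    # signature reported for cdist_cityblock in the traceback).
    return [[sum(abs(xi - yi) for xi, yi in zip(u, v)) for v in y] for u in x]


def _minkowski(x, y, p=2, w=None, out=None):
    # Kernel that does declare `p`.
    return [
        [sum(abs(xi - yi) ** p for xi, yi in zip(u, v)) ** (1.0 / p) for v in y]
        for u in x
    ]


_KERNELS = {"cityblock": _cityblock, "minkowski": _minkowski}


def toy_cdist(x, y, metric, **kwargs):
    # All keyword arguments are forwarded unchanged, so one the kernel
    # does not declare raises TypeError.
    return _KERNELS[metric](x, y, **kwargs)


a = [[0.0], [1.0]]
b = [[2.0], [-1.0]]

print(toy_cdist(a, b, "minkowski", p=2))  # accepted: kernel declares p
try:
    toy_cdist(a, b, "cityblock", p=2)     # rejected: kernel has no p
except TypeError as exc:
    print("TypeError:", exc)
```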
To Reproduce
Steps to reproduce the behavior:
import numpy as np
import scipy.spatial as spat
import ot

a = np.random.normal(size=(3, 1))
b = np.random.normal(size=(3, 1))

# works
print(spat.distance.cdist(a, b, metric="cityblock"))
# fails: cdist_cityblock() does not accept p
# print(spat.distance.cdist(a, b, metric="cityblock", p=2))

# works
print(ot.dist(a, b, metric="sqeuclidean"))
# fails with the error below
# print(ot.dist(a, b, metric="cityblock"))
Error:
Traceback (most recent call last):
  File "/home/yannik/tmp/ot_err.py", line 13, in <module>
    print(ot.dist(a, b, metric="cityblock"))
  File "/home/yannik/pyenv/env/lib/python3.9/site-packages/ot/utils.py", line 224, in dist
    return cdist(x1, x2, metric=metric, p=p)
  File "/home/yannik/pyenv/env/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2954, in cdist
    return cdist_fn(XA, XB, out=out, **kwargs)
TypeError: cdist_cityblock(): incompatible function arguments. The following argument types are supported:
    1. (x: object, y: object, w: object = None, out: object = None) -> numpy.ndarray
Invoked with: array([[ 0.60832356],
       [ 2.05290883],
       [-0.54136776]]), array([[ 1.04421186],
       [-1.4185131 ],
       [ 0.23999721]]); kwargs: out=None, p=2
Expected behavior
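The wrapper around scipy's cdist should follow its documented signature (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html), i.e. p should only be passed for metrics that accept it (p is documented for the Minkowski metric).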
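One possible shape for such a fix is sketched below. The helper cdist_kwargs is hypothetical and not part of POT; it builds the keyword arguments to forward to scipy.spatial.distance.cdist and includes p only for "minkowski". That this is the only built-in string metric needing p is an assumption based on the scipy documentation linked above.

```python
# Hypothetical helper sketching how a cdist wrapper could forward `p`
# only to metrics that accept it. Not POT's actual code.

# Assumption: among scipy's built-in string metrics, only "minkowski"
# takes the `p` argument.
_METRICS_WITH_P = frozenset({"minkowski"})


def cdist_kwargs(metric, p=2):
    """Return the keyword arguments to pass on to scipy's cdist."""
    kwargs = {"metric": metric}
    if metric in _METRICS_WITH_P:
        kwargs["p"] = p
    return kwargs


print(cdist_kwargs("cityblock"))       # p is dropped
print(cdist_kwargs("minkowski", p=3))  # p is forwarded
```

Inside the wrapper, the call would then become cdist(x1, x2, **cdist_kwargs(metric, p)), so metrics such as "cityblock" never receive an argument they do not declare.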
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.9
- How was POT installed (source, pip, conda): pip, version 0.8.0 (not occurring under 0.7.0)
- scipy version: 1.7.1