Skip to content

Commit a879690

Browse files
authored
[MTN] more permissive check_backend (#494)
* MTN more permissive check_backend
* TST simplify tests for get_backend and test more
* FIX pep8
1 parent f4e9995 commit a879690

File tree

4 files changed

+70
-89
lines changed

4 files changed

+70
-89
lines changed

ot/backend.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -131,23 +131,27 @@
131131
str_type_error = "All array should be from the same type/backend. Current types are : {}"
132132

133133

134-
def get_backend_list():
135-
"""Returns the list of available backends"""
136-
lst = [NumpyBackend(), ]
134+
# Mapping between argument types and the existing backend
135+
_BACKENDS = []
136+
137137

138-
if torch:
139-
lst.append(TorchBackend())
138+
def register_backend(backend):
139+
_BACKENDS.append(backend)
140140

141-
if jax:
142-
lst.append(JaxBackend())
143141

144-
if cp: # pragma: no cover
145-
lst.append(CupyBackend())
142+
def get_backend_list():
143+
"""Returns the list of available backends"""
144+
return _BACKENDS
145+
146146

147-
if tf:
148-
lst.append(TensorflowBackend())
147+
def _check_args_backend(backend, args):
148+
is_instance = set(isinstance(a, backend.__type__) for a in args)
149+
# check that all arguments matched or not the type
150+
if len(is_instance) == 1:
151+
return is_instance.pop()
149152

150-
return lst
153+
# Otherwise return an error
154+
raise ValueError(str_type_error.format([type(a) for a in args]))
151155

152156

153157
def get_backend(*args):
@@ -158,22 +162,12 @@ def get_backend(*args):
158162
# check that some arrays given
159163
if not len(args) > 0:
160164
raise ValueError(" The function takes at least one parameter")
161-
# check all same type
162-
if not len(set(type(a) for a in args)) == 1:
163-
raise ValueError(str_type_error.format([type(a) for a in args]))
164-
165-
if isinstance(args[0], np.ndarray):
166-
return NumpyBackend()
167-
elif isinstance(args[0], torch_type):
168-
return TorchBackend()
169-
elif isinstance(args[0], jax_type):
170-
return JaxBackend()
171-
elif isinstance(args[0], cp_type): # pragma: no cover
172-
return CupyBackend()
173-
elif isinstance(args[0], tf_type):
174-
return TensorflowBackend()
175-
else:
176-
raise ValueError("Unknown type of non implemented backend.")
165+
166+
for backend in _BACKENDS:
167+
if _check_args_backend(backend, args):
168+
return backend
169+
170+
raise ValueError("Unknown type of non implemented backend.")
177171

178172

179173
def to_numpy(*args):
@@ -1318,6 +1312,9 @@ def matmul(self, a, b):
13181312
return np.matmul(a, b)
13191313

13201314

1315+
register_backend(NumpyBackend())
1316+
1317+
13211318
class JaxBackend(Backend):
13221319
"""
13231320
JAX implementation of the backend
@@ -1676,6 +1673,11 @@ def matmul(self, a, b):
16761673
return jnp.matmul(a, b)
16771674

16781675

1676+
if jax:
1677+
# Only register jax backend if it is installed
1678+
register_backend(JaxBackend())
1679+
1680+
16791681
class TorchBackend(Backend):
16801682
"""
16811683
PyTorch implementation of the backend
@@ -2148,6 +2150,11 @@ def matmul(self, a, b):
21482150
return torch.matmul(a, b)
21492151

21502152

2153+
if torch:
2154+
# Only register torch backend if it is installed
2155+
register_backend(TorchBackend())
2156+
2157+
21512158
class CupyBackend(Backend): # pragma: no cover
21522159
"""
21532160
CuPy implementation of the backend
@@ -2530,6 +2537,11 @@ def matmul(self, a, b):
25302537
return cp.matmul(a, b)
25312538

25322539

2540+
if cp:
2541+
# Only register cp backend if it is installed
2542+
register_backend(CupyBackend())
2543+
2544+
25332545
class TensorflowBackend(Backend):
25342546

25352547
__name__ = "tf"
@@ -2930,3 +2942,8 @@ def detach(self, *args):
29302942

29312943
def matmul(self, a, b):
29322944
return tnp.matmul(a, b)
2945+
2946+
2947+
if tf:
2948+
# Only register tensorflow backend if it is installed
2949+
register_backend(TensorflowBackend())

ot/partial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
873873

874874
log_e = {'err': []}
875875

876-
if type(a) == type(b) == type(M) == np.ndarray:
876+
if nx.__name__ == "numpy":
877877
# Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
878878
K = np.empty(M.shape, dtype=M.dtype)
879879
np.divide(M, -reg, out=K)

ot/stochastic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def c_transform_entropic(b, M, reg, beta):
258258

259259

260260
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
261-
log=False):
261+
log=False):
262262
r'''
263263
Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem
264264

test/test_backend.py

Lines changed: 23 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import ot
99
import ot.backend
10-
from ot.backend import torch, jax, cp, tf
10+
from ot.backend import torch, jax, tf
1111

1212
import pytest
1313

@@ -37,17 +37,7 @@ def test_to_numpy(nx):
3737
assert isinstance(M2, np.ndarray)
3838

3939

40-
def test_get_backend():
41-
42-
A = np.zeros((3, 2))
43-
B = np.zeros((3, 1))
44-
45-
nx = get_backend(A)
46-
assert nx.__name__ == 'numpy'
47-
48-
nx = get_backend(A, B)
49-
assert nx.__name__ == 'numpy'
50-
40+
def test_get_backend_invalid():
5141
# error if no parameters
5242
with pytest.raises(ValueError):
5343
get_backend()
@@ -56,64 +46,38 @@ def test_get_backend():
5646
with pytest.raises(ValueError):
5747
get_backend(1, 2.0)
5848

59-
# test torch
60-
if torch:
6149

62-
A2 = torch.from_numpy(A)
63-
B2 = torch.from_numpy(B)
50+
def test_get_backend(nx):
6451

65-
nx = get_backend(A2)
66-
assert nx.__name__ == 'torch'
67-
68-
nx = get_backend(A2, B2)
69-
assert nx.__name__ == 'torch'
70-
71-
# test not unique types in input
72-
with pytest.raises(ValueError):
73-
get_backend(A, B2)
74-
75-
if jax:
76-
77-
A2 = jax.numpy.array(A)
78-
B2 = jax.numpy.array(B)
79-
80-
nx = get_backend(A2)
81-
assert nx.__name__ == 'jax'
82-
83-
nx = get_backend(A2, B2)
84-
assert nx.__name__ == 'jax'
52+
A = np.zeros((3, 2))
53+
B = np.zeros((3, 1))
8554

86-
# test not unique types in input
87-
with pytest.raises(ValueError):
88-
get_backend(A, B2)
55+
nx_np = get_backend(A)
56+
assert nx_np.__name__ == 'numpy'
8957

90-
if cp:
91-
A2 = cp.asarray(A)
92-
B2 = cp.asarray(B)
58+
A2, B2 = nx.from_numpy(A, B)
9359

94-
nx = get_backend(A2)
95-
assert nx.__name__ == 'cupy'
60+
effective_nx = get_backend(A2)
61+
assert effective_nx.__name__ == nx.__name__
9662

97-
nx = get_backend(A2, B2)
98-
assert nx.__name__ == 'cupy'
63+
effective_nx = get_backend(A2, B2)
64+
assert effective_nx.__name__ == nx.__name__
9965

100-
# test not unique types in input
66+
if nx.__name__ != "numpy":
67+
# test that types matching different backends in input raise an error
10168
with pytest.raises(ValueError):
10269
get_backend(A, B2)
70+
else:
71+
# Check that subclassing a numpy array does not break get_backend
72+
# note: This is only tested for numpy as this is hard to be consistent
73+
# with other backends
74+
class nx_subclass(nx.__type__):
75+
pass
10376

104-
if tf:
105-
A2 = tf.convert_to_tensor(A)
106-
B2 = tf.convert_to_tensor(B)
107-
108-
nx = get_backend(A2)
109-
assert nx.__name__ == 'tf'
77+
A3 = nx_subclass(0)
11078

111-
nx = get_backend(A2, B2)
112-
assert nx.__name__ == 'tf'
113-
114-
# test not unique types in input
115-
with pytest.raises(ValueError):
116-
get_backend(A, B2)
79+
effective_nx = get_backend(A3, B2)
80+
assert effective_nx.__name__ == nx.__name__
11781

11882

11983
def test_convert_between_backends(nx):

0 commit comments

Comments (0)