131
131
str_type_error = "All array should be from the same type/backend. Current types are : {}"
132
132
133
133
134
- def get_backend_list ():
135
- """Returns the list of available backends"""
136
- lst = [ NumpyBackend (), ]
134
+ # Mapping between argument types and the existing backend
135
+ _BACKENDS = []
136
+
137
137
138
- if torch :
139
- lst .append (TorchBackend () )
138
+ def register_backend ( backend ) :
139
+ _BACKENDS .append (backend )
140
140
141
- if jax :
142
- lst .append (JaxBackend ())
143
141
144
- if cp : # pragma: no cover
145
- lst .append (CupyBackend ())
142
+ def get_backend_list ():
143
+ """Returns the list of available backends"""
144
+ return _BACKENDS
145
+
146
146
147
- if tf :
148
- lst .append (TensorflowBackend ())
147
+ def _check_args_backend (backend , args ):
148
+ is_instance = set (isinstance (a , backend .__type__ ) for a in args )
149
+ # check that all arguments matched or not the type
150
+ if len (is_instance ) == 1 :
151
+ return is_instance .pop ()
149
152
150
- return lst
153
+ # Oterwise return an error
154
+ raise ValueError (str_type_error .format ([type (a ) for a in args ]))
151
155
152
156
153
157
def get_backend (* args ):
@@ -158,22 +162,12 @@ def get_backend(*args):
158
162
# check that some arrays given
159
163
if not len (args ) > 0 :
160
164
raise ValueError (" The function takes at least one parameter" )
161
- # check all same type
162
- if not len (set (type (a ) for a in args )) == 1 :
163
- raise ValueError (str_type_error .format ([type (a ) for a in args ]))
164
-
165
- if isinstance (args [0 ], np .ndarray ):
166
- return NumpyBackend ()
167
- elif isinstance (args [0 ], torch_type ):
168
- return TorchBackend ()
169
- elif isinstance (args [0 ], jax_type ):
170
- return JaxBackend ()
171
- elif isinstance (args [0 ], cp_type ): # pragma: no cover
172
- return CupyBackend ()
173
- elif isinstance (args [0 ], tf_type ):
174
- return TensorflowBackend ()
175
- else :
176
- raise ValueError ("Unknown type of non implemented backend." )
165
+
166
+ for backend in _BACKENDS :
167
+ if _check_args_backend (backend , args ):
168
+ return backend
169
+
170
+ raise ValueError ("Unknown type of non implemented backend." )
177
171
178
172
179
173
def to_numpy (* args ):
@@ -1318,6 +1312,9 @@ def matmul(self, a, b):
1318
1312
return np .matmul (a , b )
1319
1313
1320
1314
1315
+ register_backend (NumpyBackend ())
1316
+
1317
+
1321
1318
class JaxBackend (Backend ):
1322
1319
"""
1323
1320
JAX implementation of the backend
@@ -1676,6 +1673,11 @@ def matmul(self, a, b):
1676
1673
return jnp .matmul (a , b )
1677
1674
1678
1675
1676
+ if jax :
1677
+ # Only register jax backend if it is installed
1678
+ register_backend (JaxBackend ())
1679
+
1680
+
1679
1681
class TorchBackend (Backend ):
1680
1682
"""
1681
1683
PyTorch implementation of the backend
@@ -2148,6 +2150,11 @@ def matmul(self, a, b):
2148
2150
return torch .matmul (a , b )
2149
2151
2150
2152
2153
+ if torch :
2154
+ # Only register torch backend if it is installed
2155
+ register_backend (TorchBackend ())
2156
+
2157
+
2151
2158
class CupyBackend (Backend ): # pragma: no cover
2152
2159
"""
2153
2160
CuPy implementation of the backend
@@ -2530,6 +2537,11 @@ def matmul(self, a, b):
2530
2537
return cp .matmul (a , b )
2531
2538
2532
2539
2540
+ if cp :
2541
+ # Only register cp backend if it is installed
2542
+ register_backend (CupyBackend ())
2543
+
2544
+
2533
2545
class TensorflowBackend (Backend ):
2534
2546
2535
2547
__name__ = "tf"
@@ -2930,3 +2942,8 @@ def detach(self, *args):
2930
2942
2931
2943
def matmul (self , a , b ):
2932
2944
return tnp .matmul (a , b )
2945
+
2946
+
2947
+ if tf :
2948
+ # Only register tensorflow backend if it is installed
2949
+ register_backend (TensorflowBackend ())
0 commit comments