diff --git a/RELEASES.md b/RELEASES.md
index 4ee33917d..e70cec198 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -8,6 +8,7 @@
#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
+- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
## 0.9.1
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cd41a95d4..1f1c69398 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -961,6 +961,13 @@ List of compatible Backends
- `Tensorflow `_ (all outputs differentiable w.r.t. inputs)
- `Cupy `_ (no differentiation, GPU only)
+The library automatically detects which backends are available for use. A backend
+is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations.
+You can also disable the import of a specific backend library (e.g., to accelerate
+loading of `ot` library) using the environment variable `POT_BACKEND_DISABLE_` with in (TORCH,TENSORFLOW,CUPY,JAX).
+For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
+It's important to note that the `numpy` backend cannot be disabled.
+
List of compatible modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/ot/backend.py b/ot/backend.py
index 288224d7c..e9750ee0a 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -87,43 +87,67 @@
# License: MIT License
import numpy as np
+import os
import scipy
import scipy.linalg
-import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
-import warnings
+import scipy.special as special
import time
+import warnings
+
+
+DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
+DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
+DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
+DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
+
-try:
- import torch
- torch_type = torch.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TORCH_KEY, False):
+ try:
+ import torch
+ torch_type = torch.Tensor
+ except ImportError:
+ torch = False
+ torch_type = float
+else:
torch = False
torch_type = float
-try:
- import jax
- import jax.numpy as jnp
- import jax.scipy.special as jspecial
- from jax.lib import xla_bridge
- jax_type = jax.numpy.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_JAX_KEY, False):
+ try:
+ import jax
+ import jax.numpy as jnp
+ import jax.scipy.special as jspecial
+ from jax.lib import xla_bridge
+ jax_type = jax.numpy.ndarray
+ except ImportError:
+ jax = False
+ jax_type = float
+else:
jax = False
jax_type = float
-try:
- import cupy as cp
- import cupyx
- cp_type = cp.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_CUPY_KEY, False):
+ try:
+ import cupy as cp
+ import cupyx
+ cp_type = cp.ndarray
+ except ImportError:
+ cp = False
+ cp_type = float
+else:
cp = False
cp_type = float
-try:
- import tensorflow as tf
- import tensorflow.experimental.numpy as tnp
- tf_type = tf.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TF_KEY, False):
+ try:
+ import tensorflow as tf
+ import tensorflow.experimental.numpy as tnp
+ tf_type = tf.Tensor
+ except ImportError:
+ tf = False
+ tf_type = float
+else:
tf = False
tf_type = float
@@ -132,26 +156,51 @@
# Mapping between argument types and the existing backend
-_BACKENDS = []
+_BACKEND_IMPLEMENTATIONS = []
+_BACKENDS = {}
-def register_backend(backend):
- _BACKENDS.append(backend)
+def _register_backend_implementation(backend_impl):
+ _BACKEND_IMPLEMENTATIONS.append(backend_impl)
-def get_backend_list():
- """Returns the list of available backends"""
- return _BACKENDS
+def _get_backend_instance(backend_impl):
+ if backend_impl.__name__ not in _BACKENDS:
+ _BACKENDS[backend_impl.__name__] = backend_impl()
+ return _BACKENDS[backend_impl.__name__]
-def _check_args_backend(backend, args):
- is_instance = set(isinstance(a, backend.__type__) for a in args)
+def _check_args_backend(backend_impl, args):
+ is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
# check that all arguments matched or not the type
if len(is_instance) == 1:
return is_instance.pop()
- # Oterwise return an error
- raise ValueError(str_type_error.format([type(a) for a in args]))
+ # Otherwise return an error
+ raise ValueError(str_type_error.format([type(arg) for arg in args]))
+
+
+def get_backend_list():
+ """Returns instances of all available backends.
+
+ Note that the function forces all detected implementations
+ to be instantiated even if specific backend was not use before.
+ Be careful as instantiation of the backend might lead to side effects,
+ like GPU memory pre-allocation. See the documentation for more details.
+ If you only need to know which implementations are available,
+ use `:py:func:`ot.backend.get_available_backend_implementations`,
+ which does not force instance of the backend object to be created.
+ """
+ return [
+ _get_backend_instance(backend_impl)
+ for backend_impl
+ in get_available_backend_implementations()
+ ]
+
+
+def get_available_backend_implementations():
+ """Returns the list of available backend implementations."""
+ return _BACKEND_IMPLEMENTATIONS
def get_backend(*args):
@@ -167,9 +216,9 @@ def get_backend(*args):
if not len(args) > 0:
raise ValueError(" The function takes at least one (non-None) parameter")
- for backend in _BACKENDS:
- if _check_args_backend(backend, args):
- return backend
+ for backend_impl in _BACKEND_IMPLEMENTATIONS:
+ if _check_args_backend(backend_impl, args):
+ return _get_backend_instance(backend_impl)
raise ValueError("Unknown type of non implemented backend.")
@@ -1341,7 +1390,7 @@ def matmul(self, a, b):
return np.matmul(a, b)
-register_backend(NumpyBackend())
+_register_backend_implementation(NumpyBackend)
class JaxBackend(Backend):
@@ -1710,7 +1759,7 @@ def matmul(self, a, b):
if jax:
# Only register jax backend if it is installed
- register_backend(JaxBackend())
+ _register_backend_implementation(JaxBackend)
class TorchBackend(Backend):
@@ -2193,7 +2242,7 @@ def matmul(self, a, b):
if torch:
# Only register torch backend if it is installed
- register_backend(TorchBackend())
+ _register_backend_implementation(TorchBackend)
class CupyBackend(Backend): # pragma: no cover
@@ -2586,7 +2635,7 @@ def matmul(self, a, b):
if cp:
# Only register cp backend if it is installed
- register_backend(CupyBackend())
+ _register_backend_implementation(CupyBackend)
class TensorflowBackend(Backend):
@@ -3006,4 +3055,4 @@ def matmul(self, a, b):
if tf:
# Only register tensorflow backend if it is installed
- register_backend(TensorflowBackend())
+ _register_backend_implementation(TensorflowBackend)
diff --git a/test/conftest.py b/test/conftest.py
index c0db8abe2..0303ed9f2 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -4,19 +4,33 @@
# License: MIT License
-import pytest
-from ot.backend import jax, tf
-from ot.backend import get_backend_list
import functools
+import os
+import pytest
+
+from ot.backend import get_backend_list, jax, tf
+
if jax:
+ os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax.config import config
config.update("jax_enable_x64", True)
if tf:
+ # make sure TF doesn't allocate entire GPU
+ import tensorflow as tf
+ physical_devices = tf.config.list_physical_devices('GPU')
+ for device in physical_devices:
+ try:
+ tf.config.experimental.set_memory_growth(device, True)
+ except Exception:
+ pass
+
+ # allow numpy API for TF
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
+
backend_list = get_backend_list()