diff --git a/RELEASES.md b/RELEASES.md index 4ee33917d..e70cec198 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,7 @@ #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) +- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) ## 0.9.1 diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cd41a95d4..1f1c69398 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -961,6 +961,13 @@ List of compatible Backends - `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs) - `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only) +The library automatically detects which backends are available for use. A backend +is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations. +You can also disable the import of a specific backend library (e.g., to accelerate +loading of the `ot` library) using the environment variable `POT_BACKEND_DISABLE_<NAME>` with <NAME> in (PYTORCH,TENSORFLOW,CUPY,JAX). +For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`. +It's important to note that the `numpy` backend cannot be disabled. 
+ List of compatible modules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/ot/backend.py b/ot/backend.py index 288224d7c..e9750ee0a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -87,43 +87,67 @@ # License: MIT License import numpy as np +import os import scipy import scipy.linalg -import scipy.special as special from scipy.sparse import issparse, coo_matrix, csr_matrix -import warnings +import scipy.special as special import time +import warnings + + +DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH' +DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX' +DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY' +DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW' + -try: - import torch - torch_type = torch.Tensor -except ImportError: +if not os.environ.get(DISABLE_TORCH_KEY, False): + try: + import torch + torch_type = torch.Tensor + except ImportError: + torch = False + torch_type = float +else: torch = False torch_type = float -try: - import jax - import jax.numpy as jnp - import jax.scipy.special as jspecial - from jax.lib import xla_bridge - jax_type = jax.numpy.ndarray -except ImportError: +if not os.environ.get(DISABLE_JAX_KEY, False): + try: + import jax + import jax.numpy as jnp + import jax.scipy.special as jspecial + from jax.lib import xla_bridge + jax_type = jax.numpy.ndarray + except ImportError: + jax = False + jax_type = float +else: jax = False jax_type = float -try: - import cupy as cp - import cupyx - cp_type = cp.ndarray -except ImportError: +if not os.environ.get(DISABLE_CUPY_KEY, False): + try: + import cupy as cp + import cupyx + cp_type = cp.ndarray + except ImportError: + cp = False + cp_type = float +else: cp = False cp_type = float -try: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - tf_type = tf.Tensor -except ImportError: +if not os.environ.get(DISABLE_TF_KEY, False): + try: + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + tf_type = tf.Tensor + except ImportError: + tf = False + tf_type = float +else: tf = 
False tf_type = float @@ -132,26 +156,51 @@ # Mapping between argument types and the existing backend -_BACKENDS = [] +_BACKEND_IMPLEMENTATIONS = [] +_BACKENDS = {} -def register_backend(backend): - _BACKENDS.append(backend) +def _register_backend_implementation(backend_impl): + _BACKEND_IMPLEMENTATIONS.append(backend_impl) -def get_backend_list(): - """Returns the list of available backends""" - return _BACKENDS +def _get_backend_instance(backend_impl): + if backend_impl.__name__ not in _BACKENDS: + _BACKENDS[backend_impl.__name__] = backend_impl() + return _BACKENDS[backend_impl.__name__] -def _check_args_backend(backend, args): - is_instance = set(isinstance(a, backend.__type__) for a in args) +def _check_args_backend(backend_impl, args): + is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args) # check that all arguments matched or not the type if len(is_instance) == 1: return is_instance.pop() - # Oterwise return an error - raise ValueError(str_type_error.format([type(a) for a in args])) + # Otherwise return an error + raise ValueError(str_type_error.format([type(arg) for arg in args])) + + +def get_backend_list(): + """Returns instances of all available backends. + + Note that the function forces all detected implementations + to be instantiated even if a specific backend was not used before. + Be careful as instantiation of the backend might lead to side effects, + like GPU memory pre-allocation. See the documentation for more details. + If you only need to know which implementations are available, + use :py:func:`ot.backend.get_available_backend_implementations`, + which does not force an instance of the backend object to be created. 
+ """ + return [ + _get_backend_instance(backend_impl) + for backend_impl + in get_available_backend_implementations() + ] + + +def get_available_backend_implementations(): + """Returns the list of available backend implementations.""" + return _BACKEND_IMPLEMENTATIONS def get_backend(*args): @@ -167,9 +216,9 @@ def get_backend(*args): if not len(args) > 0: raise ValueError(" The function takes at least one (non-None) parameter") - for backend in _BACKENDS: - if _check_args_backend(backend, args): - return backend + for backend_impl in _BACKEND_IMPLEMENTATIONS: + if _check_args_backend(backend_impl, args): + return _get_backend_instance(backend_impl) raise ValueError("Unknown type of non implemented backend.") @@ -1341,7 +1390,7 @@ def matmul(self, a, b): return np.matmul(a, b) -register_backend(NumpyBackend()) +_register_backend_implementation(NumpyBackend) class JaxBackend(Backend): @@ -1710,7 +1759,7 @@ def matmul(self, a, b): if jax: # Only register jax backend if it is installed - register_backend(JaxBackend()) + _register_backend_implementation(JaxBackend) class TorchBackend(Backend): @@ -2193,7 +2242,7 @@ def matmul(self, a, b): if torch: # Only register torch backend if it is installed - register_backend(TorchBackend()) + _register_backend_implementation(TorchBackend) class CupyBackend(Backend): # pragma: no cover @@ -2586,7 +2635,7 @@ def matmul(self, a, b): if cp: # Only register cp backend if it is installed - register_backend(CupyBackend()) + _register_backend_implementation(CupyBackend) class TensorflowBackend(Backend): @@ -3006,4 +3055,4 @@ def matmul(self, a, b): if tf: # Only register tensorflow backend if it is installed - register_backend(TensorflowBackend()) + _register_backend_implementation(TensorflowBackend) diff --git a/test/conftest.py b/test/conftest.py index c0db8abe2..0303ed9f2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,19 +4,33 @@ # License: MIT License -import pytest -from ot.backend import jax, tf -from ot.backend 
import get_backend_list import functools +import os +import pytest + +from ot.backend import get_backend_list, jax, tf + if jax: + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' from jax.config import config config.update("jax_enable_x64", True) if tf: + # make sure TF doesn't allocate entire GPU + import tensorflow as tf + physical_devices = tf.config.list_physical_devices('GPU') + for device in physical_devices: + try: + tf.config.experimental.set_memory_growth(device, True) + except Exception: + pass + + # allow numpy API for TF from tensorflow.python.ops.numpy_ops import np_config np_config.enable_numpy_behavior() + backend_list = get_backend_list()