Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,13 @@ List of compatible Backends
- `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs)
- `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only)

The library automatically detects which backends are available for use. A backend
is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations.
You can also disable the import of a specific backend library (e.g., to accelerate
loading of `ot` library) using the environment variable `POT_BACKEND_DISABLE_<NAME>`.
For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
It's important to note that the `numpy` backend cannot be disabled.


List of compatible modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
120 changes: 78 additions & 42 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,43 +87,67 @@
# License: MIT License

import numpy as np
import os
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import scipy.special as special
import time
import warnings


DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'


try:
import torch
torch_type = torch.Tensor
except ImportError:
if not os.environ.get(DISABLE_TORCH_KEY, False):
try:
import torch
torch_type = torch.Tensor
except ImportError:
torch = False
torch_type = float
else:
torch = False
torch_type = float

try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
if not os.environ.get(DISABLE_JAX_KEY, False):
try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
jax_type = float
else:
jax = False
jax_type = float

try:
import cupy as cp
import cupyx
cp_type = cp.ndarray
except ImportError:
if not os.environ.get(DISABLE_CUPY_KEY, False):
try:
import cupy as cp
import cupyx
cp_type = cp.ndarray
except ImportError:
cp = False
cp_type = float
else:
cp = False
cp_type = float

try:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
tf_type = tf.Tensor
except ImportError:
if not os.environ.get(DISABLE_TF_KEY, False):
try:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
tf_type = tf.Tensor
except ImportError:
tf = False
tf_type = float
else:
tf = False
tf_type = float

Expand All @@ -132,26 +156,38 @@


# Mapping between argument types and the existing backend
_BACKENDS = []
_BACKEND_IMPLEMENTATIONS = []
_BACKENDS = {}


def register_backend(backend):
_BACKENDS.append(backend)
def get_backend_list():
"""Returns the list of already instantiated backends."""
return list(_BACKENDS.values())


def get_backend_list():
"""Returns the list of available backends"""
return _BACKENDS
def get_available_backend_implementations():
"""Returns the list of available backend implementations."""
return _BACKEND_IMPLEMENTATIONS


def _register_backend_implementation(backend_impl):
_BACKEND_IMPLEMENTATIONS.append(backend_impl)


def _get_backend_instance(backend_impl):
if backend_impl.__name__ not in _BACKENDS:
_BACKENDS[backend_impl.__name__] = backend_impl()
return _BACKENDS[backend_impl.__name__]


def _check_args_backend(backend, args):
is_instance = set(isinstance(a, backend.__type__) for a in args)
def _check_args_backend(backend_impl, args):
is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
# check that all arguments matched or not the type
if len(is_instance) == 1:
return is_instance.pop()

# Oterwise return an error
raise ValueError(str_type_error.format([type(a) for a in args]))
# Otherwise return an error
raise ValueError(str_type_error.format([type(arg) for arg in args]))


def get_backend(*args):
Expand All @@ -160,12 +196,12 @@ def get_backend(*args):
Also raises TypeError if all arrays are not from the same backend
"""
# check that some arrays given
if not len(args) > 0:
if len(args) == 0:
raise ValueError(" The function takes at least one parameter")

for backend in _BACKENDS:
if _check_args_backend(backend, args):
return backend
for backend_impl in _BACKEND_IMPLEMENTATIONS:
if _check_args_backend(backend_impl, args):
return _get_backend_instance(backend_impl)

raise ValueError("Unknown type of non implemented backend.")

Expand Down Expand Up @@ -1337,7 +1373,7 @@ def matmul(self, a, b):
return np.matmul(a, b)


register_backend(NumpyBackend())
_register_backend_implementation(NumpyBackend)


class JaxBackend(Backend):
Expand Down Expand Up @@ -1706,7 +1742,7 @@ def matmul(self, a, b):

if jax:
# Only register jax backend if it is installed
register_backend(JaxBackend())
_register_backend_implementation(JaxBackend)


class TorchBackend(Backend):
Expand Down Expand Up @@ -2189,7 +2225,7 @@ def matmul(self, a, b):

if torch:
# Only register torch backend if it is installed
register_backend(TorchBackend())
_register_backend_implementation(TorchBackend)


class CupyBackend(Backend): # pragma: no cover
Expand Down Expand Up @@ -2582,7 +2618,7 @@ def matmul(self, a, b):

if cp:
# Only register cp backend if it is installed
register_backend(CupyBackend())
_register_backend_implementation(CupyBackend)


class TensorflowBackend(Backend):
Expand Down Expand Up @@ -2995,4 +3031,4 @@ def matmul(self, a, b):

if tf:
# Only register tensorflow backend if it is installed
register_backend(TensorflowBackend())
_register_backend_implementation(TensorflowBackend)
32 changes: 29 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,45 @@

# License: MIT License

import pytest
from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
import os
import pytest

from ot.backend import (
get_available_backend_implementations,
get_backend_list,
_get_backend_instance,
jax,
tf
)


if jax:
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax.config import config
config.update("jax_enable_x64", True)

if tf:
# make sure TF doesn't allocate entire GPU
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
try:
tf.config.experimental.set_memory_growth(device, True)
except Exception:
pass

# allow numpy API for TF
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()


# before taking list of backends, we need to make sure all
# available implementations are instantiated. looks somewhat hacky,
# but hopefully it won't be needed for a common library use
for backend_impl in get_available_backend_implementations():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should do that in get_backend_list ? I mean this funtiinn should return all bacjen even if the backend was not used in the past... Of course we need to put more detail in the doc of get_backend_list to be clear about the potentiel memory problems

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This depends mainly on the use case get_backend_list was originally dedicated for. get_available_backend_implementations clearly conveys the intention: those are backends "available" to be used. But I'm not sure what would be a use case for get_backend_list except for running tests on all backends at once. If we want get_backend_list to always force instantiation of all them - sure, we can do that. Just being very careful with the documentation.

_get_backend_instance(backend_impl)

backend_list = get_backend_list()


Expand Down