From 2e944b0ad247e574ccaa1a7f89db6ae36cced250 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Mon, 29 Nov 2021 17:34:22 +0100
Subject: [PATCH 01/46] First batch of tf methods (to be continued)

---
 ot/backend.py        | 295 +++++++++++++++++++++++++++++++++++++++++++
 test/test_backend.py |  71 +++++++----
 2 files changed, 341 insertions(+), 25 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index fa164c39f..f8764f865 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -44,6 +44,13 @@
     jax = False
     jax_type = float
 
+try:
+    import tensorflow as tf
+    tf_type = tf.Tensor
+except ImportError:
+    tf = False
+    tf_type = float
+
 str_type_error = "All array should be from the same type/backend. Current types are : {}"
 
@@ -57,6 +64,9 @@ def get_backend_list():
     if jax:
         lst.append(JaxBackend())
 
+    if tf:
+        lst.append(TensorflowBackend())
+
     return lst
 
@@ -78,6 +88,8 @@ def get_backend(*args):
         return TorchBackend()
     elif isinstance(args[0], jax_type):
         return JaxBackend()
+    elif isinstance(args[0], tf_type):
+        return TensorflowBackend()
     else:
         raise ValueError("Unknown type of non implemented backend.")
 
@@ -665,6 +677,15 @@ def assert_same_dtype_device(self, a, b):
         """
         raise NotImplementedError()
 
+    def T(self, a):
+        r"""
+        Returns the transpose of a tensor.
+
+        This function follows the api from :any:`numpy.ndarray.T`.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.T.html
+        """
+
 
 class NumpyBackend(Backend):
     """
@@ -902,6 +923,9 @@ def assert_same_dtype_device(self, a, b):
         # numpy has implicit type conversion so we automatically validate the test
         pass
 
+    def T(self, a):
+        return a.T
+
 
 class JaxBackend(Backend):
     """
@@ -1162,6 +1186,9 @@ def assert_same_dtype_device(self, a, b):
         assert a_dtype == b_dtype, "Dtype discrepancy"
         assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
 
+    def T(self, a):
+        return a.T
+
 
 class TorchBackend(Backend):
     """
@@ -1500,3 +1527,271 @@ def assert_same_dtype_device(self, a, b):
         assert a_dtype == b_dtype, "Dtype discrepancy"
         assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+    def T(self, a):
+        return a.T
+
+
+class TensorflowBackend(Backend):
+
+    __name__ = "tf"
+    __type__ = tf_type
+    __type_list__ = None
+
+    rng_ = None
+
+    def __init__(self):
+        self.rng_ = tf.random.Generator.from_non_deterministic_state()
+
+        self.__type_list__ = [
+            tf.convert_to_tensor([1], dtype=tf.float32),
+            tf.convert_to_tensor([1], dtype=tf.float64)
+        ]
+
+    def to_numpy(self, a):
+        return a.numpy()
+
+    def from_numpy(self, a, type_as=None):
+        if type_as is None:
+            return tf.convert_to_tensor(a)
+        else:
+            return tf.convert_to_tensor(a, dtype=type_as.dtype)
+
+    def set_gradients(self, val, inputs, grads):
+        @tf.custom_gradient
+        def tmp(input):
+            def grad(upstream):
+                return grads
+            return val, grad
+        return tmp(inputs)
+
+    def zeros(self, shape, type_as=None):
+        if type_as is None:
+            return tf.zeros(shape)
+        else:
+            return tf.zeros(shape, dtype=type_as.dtype)
+
+    def ones(self, shape, type_as=None):
+        if type_as is None:
+            return tf.ones(shape)
+        else:
+            return tf.ones(shape, dtype=type_as.dtype)
+
+    def arange(self, stop, start=0, step=1, type_as=None):
+        return tf.range(start=start, limit=stop, delta=step)
+
+    def full(self, shape, fill_value, type_as=None):
+        if type_as is None:
+            return tf.fill(shape, fill_value)
+        else:
+            return tf.cast(tf.fill(shape, fill_value), dtype=type_as.dtype)
+
+    def eye(self, N, M=None, type_as=None):
+        if M is None:
+            M = N
+        if type_as is None:
+            return tf.eye(N, M)
+        else:
+            return tf.eye(N, M, dtype=type_as.dtype)
+
+    def sum(self, a, axis=None, keepdims=False):
+        if axis is None:
+            return tf.math.reduce_sum(a)
+        else:
+            return tf.math.reduce_sum(a, axis=axis, keepdims=keepdims)
+
+    def cumsum(self, a, axis=None):
+        if axis is None:
+            return tf.math.cumsum(tf.reshape(a, [-1]), axis=0)
+        else:
+            return tf.math.cumsum(a, axis=axis)
+
+    def max(self, a, axis=None, keepdims=False):
+        if axis is None:
+            return tf.math.reduce_max(a)
+        else:
+            return tf.math.reduce_max(a, axis=axis, keepdims=keepdims)
+
+    def min(self, a, axis=None, keepdims=False):
+        if axis is None:
+            return tf.math.reduce_min(a)
+        else:
+            return tf.math.reduce_min(a, axis=axis, keepdims=keepdims)
+
+    def maximum(self, a, b):
+        if isinstance(a, int) or isinstance(a, float):
+            a = tf.constant([float(a)], dtype=b.dtype)
+        if isinstance(b, int) or isinstance(b, float):
+            b = tf.constant([float(b)], dtype=a.dtype)
+        return tf.math.maximum(a, b)
+
+    def minimum(self, a, b):
+        if isinstance(a, int) or isinstance(a, float):
+            a = tf.constant([float(a)], dtype=b.dtype)
+        if isinstance(b, int) or isinstance(b, float):
+            b = tf.constant([float(b)], dtype=a.dtype)
+        return tf.math.minimum(a, b)
+
+    def dot(self, a, b):
+        if len(b.shape) == 1:
+            if len(a.shape) == 1:
+                # inner product
+                return tf.reduce_sum(tf.multiply(a, b))
+            else:
+                # matrix vector
+                return tf.linalg.matvec(a, b)
+        else:
+            if len(a.shape) == 1:
+                return self.T(tf.linalg.matvec(self.t(b), self.T(a)))
+            else:
+                return tf.matmul(a, b)
+
+    def abs(self, a):
+        return tf.math.abs(a)
+
+    def exp(self, a):
+        return tf.math.exp(a)
+
+    def log(self, a):
+        return tf.math.log(a)
+
+    def sqrt(self, a):
+        return tf.math.sqrt(a)
+
+    def power(self, a, exponents):
+        return tf.math.pow(a, exponents)
+
+    def norm(self, a):
+        return tf.norm(a, ord=2)
+
+    def any(self, a):
+        return tf.math.reduce_any(a)
+
+    def isnan(self, a):
+        return tf.math.is_nan(a)
+
+    def isinf(self, a):
+        return tf.math.is_inf(a)
+
+    def einsum(self, subscripts, *operands):
+        return tf.einsum(subscripts, *operands)
+
+    def sort(self, a, axis=-1):
+        return tf.sort(a, axis=axis)
+
+    def argsort(self, a, axis=-1):
+        return tf.argsort(a, axis=axis)
+
+    def searchsorted(self, a, v, side='left'):
+        return tf.searchsorted(a, v, side=side)
+
+    def flip(self, a, axis=None):
+        if axis is None:
+            return tf.reverse(a, tuple(i for i in range(len(a.shape))))
+        if isinstance(axis, int):
+            return tf.reverse(a, (axis,))
+        else:
+            return tf.reverse(a, axis=axis)
+
+    def outer(self, a, b):
+        return tf.einsum('i,j->ij', a, b)
+
+    def clip(self, a, a_min, a_max):
+        return tf.clip_by_value(a, a_min, a_max)
+
+    def repeat(self, a, repeats, axis=None):
+        return tf.repeat(a, repeats, axis=axis)
+
+    def take_along_axis(self, arr, indices, axis):
+        return tf.gather(arr, indices, axis=axis)
+
+    def concatenate(self, arrays, axis=0):
+        return tf.concat(arrays, axis=axis)
+
+    def zero_pad(self, a, pad_with):
+        return tf.pad(a, pad_with)
+
+    def argmax(self, a, axis=None):
+        if axis is None:
+            return tf.argmax(tf.reshape(a, [-1]))
+        else:
+            return tf.argmax(a, axis=axis)
+
+    def mean(self, a, axis=None):
+        return tf.math.reduce_mean(a, axis=axis)
+
+    def std(self, a, axis=None):
+        return tf.math.reduce_std(a, axis=axis)
+
+    def linspace(self, start, stop, num):
+        return tf.linspace(start, stop, num)
+
+    def meshgrid(self, a, b):
+        return tf.meshgrid(a, b)
+
+    def diag(self, a, k=0):
+        if len(a.shape) == 1:
+            return tf.linalg.diag(a, k=k)
+        else:
+            return tf.linalg.diag_part(a, k=k)
+
+    def unique(self, a):
+        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+
+    def logsumexp(self, a, axis=None):
+        return tf.math.reduce_logsumexp(a, axis=axis)
+
+    def stack(self, arrays, axis=0):
+        return tf.stack(arrays, axis=axis)
+
+    def reshape(self, a, shape):
+        return tf.reshape(a, shape)
+
+    def seed(self, seed=None):
+        if isinstance(seed, int):
+            self.rng_ = tf.random.Generator.from_seed(1234)
+        elif isinstance(seed, tf.random.Generator):
+            self.rng_ = seed
+        elif seed is None:
+            self.rng_ = tf.random.Generator.from_non_deterministic_state()
+        else:
+            raise ValueError("Non compatible seed : {}".format(seed))
+
+    def rand(self, *size, type_as=None):
+        raise NotImplementedError()
+
+    def randn(self, *size, type_as=None):
+        raise NotImplementedError()
+
+    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+        raise NotImplementedError()
+
+    def issparse(self, a):
+        raise NotImplementedError()
+
+    def tocsr(self, a):
+        raise NotImplementedError()
+
+    def eliminate_zeros(self, a, threshold=0.):
+        raise NotImplementedError()
+
+    def todense(self, a):
+        raise NotImplementedError()
+
+    def where(self, condition, x, y):
+        raise NotImplementedError()
+
+    def copy(self, a):
+        raise NotImplementedError()
+
+    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+        raise NotImplementedError()
+
+    def dtype_device(self, a):
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        raise NotImplementedError()
+
+    def T(self, a):
+        return tf.transpose(a)
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b9187..e5d1f56fb 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
 
 import ot
 import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, tf
 
 import pytest
 
@@ -87,6 +87,21 @@ def test_get_backend():
         with pytest.raises(ValueError):
             get_backend(A, B2)
 
+    if tf:
+        A2 = tf.convert_to_tensor(A)
+        B2 = tf.convert_to_tensor(B)
+
+        nx = get_backend(A2)
+        assert nx.__name__ == 'tensorflow'
+
+        nx = get_backend(A2, B2)
+        assert nx.__name__ == 'tensorflow'
+
+        # test not unique types in input
+        with pytest.raises(ValueError):
+            get_backend(A, B2)
+
+
 
 def test_convert_between_backends(nx):
 
@@ -228,6 +243,8 @@ def test_empty_backend():
         nx.copy(M)
     with pytest.raises(NotImplementedError):
         nx.allclose(M, M)
+    with pytest.raises(NotImplementedError):
+        nx.T(M)
 
 
 def test_func_backends(nx):
@@ -353,7 +370,7 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('dot(M,v)')
 
-        A = nx.dot(Mb, Mb.T)
+        A = nx.dot(Mb, nx.T(Mb))
         lst_b.append(nx.to_numpy(A))
         lst_name.append('dot(M,M)')
 
@@ -470,36 +487,40 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('reshape')
 
-        sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
-        nx.todense(Mb)
-        lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
-        lst_name.append('coo_matrix')
+        # sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+        # nx.todense(Mb)
+        # lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+        # lst_name.append('coo_matrix')
 
-        assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
-        assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+        # assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+        # assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
 
-        A = nx.tocsr(sp_Mb)
-        lst_b.append(nx.to_numpy(nx.todense(A)))
-        lst_name.append('tocsr')
+        # A = nx.tocsr(sp_Mb)
+        # lst_b.append(nx.to_numpy(nx.todense(A)))
+        # lst_name.append('tocsr')
 
-        A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
-        lst_b.append(nx.to_numpy(A))
-        lst_name.append('eliminate_zeros (dense)')
+        # A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+        # lst_b.append(nx.to_numpy(A))
+        # lst_name.append('eliminate_zeros (dense)')
 
-        A = nx.eliminate_zeros(sp_Mb)
-        lst_b.append(nx.to_numpy(nx.todense(A)))
-        lst_name.append('eliminate_zeros (sparse)')
+        # A = nx.eliminate_zeros(sp_Mb)
+        # lst_b.append(nx.to_numpy(nx.todense(A)))
+        # lst_name.append('eliminate_zeros (sparse)')
 
-        A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
-        lst_b.append(nx.to_numpy(A))
-        lst_name.append('where')
+        # A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+        # lst_b.append(nx.to_numpy(A))
+        # lst_name.append('where')
 
-        A = nx.copy(Mb)
-        lst_b.append(nx.to_numpy(A))
-        lst_name.append('copy')
+        # A = nx.copy(Mb)
+        # lst_b.append(nx.to_numpy(A))
+        # lst_name.append('copy')
 
-        assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
-        assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+        # assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+        # assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
+        A = nx.T(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('transpose')
 
         lst_tot.append(lst_b)
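Aside (illustration only, not from this patch series): `set_gradients` above relies on `tf.custom_gradient`, which lets a function return a forward value together with whatever gradient should be reported for it. A minimal standalone sketch of that mechanism, with made-up toy numbers:

    import tensorflow as tf

    @tf.custom_gradient
    def square_with_forced_grad(x):
        # Forward pass returns x**2, but we attach a hand-chosen gradient,
        # mirroring how set_gradients attaches externally computed grads.
        def grad(upstream):
            return upstream * 3.0
        return x ** 2, grad

    x = tf.constant(2.0)
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = square_with_forced_grad(x)
    print(tape.gradient(y, x))  # 3.0, the forced gradient, not the analytic 4.0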
From 1653b707af1114e46d8feb41cc3128e2fea1d7fa Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 30 Nov 2021 11:29:24 +0100
Subject: [PATCH 02/46] Second batch of methods (yet to debug)

---
 ot/backend.py        | 97 +++++++++++++++++++++++++++++++++++---------
 test/test_backend.py | 51 ++++++++++++-----------
 2 files changed, 102 insertions(+), 46 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index f8764f865..7d382db4c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -685,6 +685,7 @@ def T(self, a):
 
         See: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.T.html
         """
+        raise NotImplementedError()
 
 
 class NumpyBackend(Backend):
@@ -1541,13 +1542,19 @@ class TensorflowBackend(Backend):
 
     rng_ = None
 
     def __init__(self):
-        self.rng_ = tf.random.Generator.from_non_deterministic_state()
+        self.seed(None)
 
         self.__type_list__ = [
             tf.convert_to_tensor([1], dtype=tf.float32),
             tf.convert_to_tensor([1], dtype=tf.float64)
         ]
 
+    def _constant(self, val, type_as=None):
+        if type_as is None:
+            return tf.constant(val)
+        else:
+            return tf.cast(tf.constant(val), dtype=type_as.dtype)
+
     def to_numpy(self, a):
         return a.numpy()
 
@@ -1642,7 +1649,7 @@ def dot(self, a, b):
                 return tf.linalg.matvec(a, b)
         else:
             if len(a.shape) == 1:
-                return self.T(tf.linalg.matvec(self.t(b), self.T(a)))
+                return self.T(tf.linalg.matvec(self.T(b), self.T(a)))
             else:
                 return tf.matmul(a, b)
 
@@ -1749,7 +1756,7 @@ def reshape(self, a, shape):
 
     def seed(self, seed=None):
         if isinstance(seed, int):
-            self.rng_ = tf.random.Generator.from_seed(1234)
+            self.rng_ = tf.random.Generator.from_seed(seed)
         elif isinstance(seed, tf.random.Generator):
             self.rng_ = seed
         elif seed is None:
@@ -1758,40 +1765,90 @@ def seed(self, seed=None):
             raise ValueError("Non compatible seed : {}".format(seed))
 
     def rand(self, *size, type_as=None):
-        raise NotImplementedError()
-
+        if type_as is None:
+            return self.rng_.uniform(size, minval=0., maxval=1.)
+        else:
+            return self.rng_.uniform(
+                size, minval=0., maxval=1., dtype=type_as.dtype
+            )
+
     def randn(self, *size, type_as=None):
-        raise NotImplementedError()
+        if type_as is None:
+            return self.rng_.normal(size)
+        else:
+            return self.rng_.normal(size, dtype=type_as.dtype)
+
+    def _convert_to_index_for_coo(self, tensor):
+        if isinstance(tensor, self.__type__):
+            return int(self.max(tensor)) + 1
+        else:
+            return int(np.max(tensor)) + 1
 
     def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
-        raise NotImplementedError()
-
+        if shape is None:
+            shape = (
+                self._convert_to_index_for_coo(rows),
+                self._convert_to_index_for_coo(cols)
+            )
+        sparse_tensor = tf.sparse.SparseTensor(
+            indices=self.T(self.stack([rows, cols])),
+            values=data,
+            dense_shape=shape
+        )
+        if type_as is not None:
+            sparse_tensor = self.from_numpy(sparse_tensor)
+        return sparse_tensor
+
     def issparse(self, a):
-        raise NotImplementedError()
-
+        return isinstance(a, tf.sparse.SparseTensor)
+
     def tocsr(self, a):
-        raise NotImplementedError()
+        return tf.sparse.reorder(a)
 
     def eliminate_zeros(self, a, threshold=0.):
-        raise NotImplementedError()
+        threshold = self._constant(threshold, type_as=a)
+        if self.issparse(a):
+            values = a.values
+            if threshold > 0:
+                mask = self.abs(values) <= threshold
+            else:
+                mask = values == self._constant(0., type_as=a)
+            return tf.sparse.retain(a, ~mask)
+        else:
+            if threshold > 0:
+                a = tf.where(
+                    self.abs(a) <= threshold,
+                    self._constant(0., type_as=a),
+                    a
+                )
+            return a
 
     def todense(self, a):
-        raise NotImplementedError()
+        if self.issparse(a):
+            return tf.sparse.to_dense(tf.sparse.reorder(a))
+        else:
+            return a
 
     def where(self, condition, x, y):
-        raise NotImplementedError()
+        return tf.where(condition, x, y)
 
     def copy(self, a):
-        raise NotImplementedError()
-
+        return tf.identity(a)
+
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
-        raise NotImplementedError()
-
+        return tf.experimental.numpy.allclose(
+            a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
+        )
+
     def dtype_device(self, a):
-        raise NotImplementedError()
+        return a.dtype, a.device.split("device")[1]
 
     def assert_same_dtype_device(self, a, b):
-        raise NotImplementedError()
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
 
     def T(self, a):
         return tf.transpose(a)
diff --git a/test/test_backend.py b/test/test_backend.py
index e5d1f56fb..2fe661e98 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -92,17 +92,16 @@ def test_get_backend():
         B2 = tf.convert_to_tensor(B)
 
         nx = get_backend(A2)
-        assert nx.__name__ == 'tensorflow'
+        assert nx.__name__ == 'tf'
 
         nx = get_backend(A2, B2)
-        assert nx.__name__ == 'tensorflow'
+        assert nx.__name__ == 'tf'
 
         # test not unique types in input
         with pytest.raises(ValueError):
             get_backend(A, B2)
 
-
 
 def test_convert_between_backends(nx):
 
@@ -487,36 +486,36 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('reshape')
 
-        # sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
-        # nx.todense(Mb)
-        # lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
-        # lst_name.append('coo_matrix')
+        sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+        nx.todense(Mb)
+        lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+        lst_name.append('coo_matrix')
 
-        # assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
-        # assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+        assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+        assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
 
-        # A = nx.tocsr(sp_Mb)
-        # lst_b.append(nx.to_numpy(nx.todense(A)))
-        # lst_name.append('tocsr')
+        A = nx.tocsr(sp_Mb)
+        lst_b.append(nx.to_numpy(nx.todense(A)))
+        lst_name.append('tocsr')
 
-        # A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
-        # lst_b.append(nx.to_numpy(A))
-        # lst_name.append('eliminate_zeros (dense)')
+        A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('eliminate_zeros (dense)')
 
-        # A = nx.eliminate_zeros(sp_Mb)
-        # lst_b.append(nx.to_numpy(nx.todense(A)))
-        # lst_name.append('eliminate_zeros (sparse)')
+        A = nx.eliminate_zeros(sp_Mb)
+        lst_b.append(nx.to_numpy(nx.todense(A)))
+        lst_name.append('eliminate_zeros (sparse)')
 
-        # A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
-        # lst_b.append(nx.to_numpy(A))
-        # lst_name.append('where')
+        A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('where')
 
-        # A = nx.copy(Mb)
-        # lst_b.append(nx.to_numpy(A))
-        # lst_name.append('copy')
+        A = nx.copy(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('copy')
 
-        # assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
-        # assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+        assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+        assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
 
         A = nx.T(Mb)
         lst_b.append(nx.to_numpy(A))
         lst_name.append('transpose')
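Aside (illustration only, not from this patch series): `coo_matrix`, `todense` and `eliminate_zeros` above are built on TensorFlow's `tf.sparse` API. A minimal sketch of the calls involved, with toy data:

    import tensorflow as tf

    # A 2x3 matrix with two non-zero entries, in COO form.
    st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                                values=[1.0, 2.0],
                                dense_shape=[2, 3])
    # to_dense requires indices in canonical order, hence the reorder,
    # exactly as in the todense implementation above.
    print(tf.sparse.to_dense(tf.sparse.reorder(st)))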
From 2fe3441cbe185afa770f9caf6f5a6cb3bb8b7f97 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 30 Nov 2021 17:19:18 +0100
Subject: [PATCH 03/46] tensorflow for cpu

---
 ot/backend.py        | 230 ++++++++++++++++++++++---------------------
 ot/bregman.py        |  72 ++++++++------
 ot/gromov.py         |   2 +-
 ot/lp/solver_1d.py   |   4 +-
 test/conftest.py     |   6 +-
 test/test_backend.py |   9 +-
 test/test_bregman.py |  14 ++-
 7 files changed, 184 insertions(+), 153 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 7d382db4c..c9a7779a0 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -27,6 +27,7 @@
 import numpy as np
 import scipy.special as scipy
 from scipy.sparse import issparse, coo_matrix, csr_matrix
+import warnings
 
 try:
     import torch
@@ -46,6 +47,7 @@
 
 try:
     import tensorflow as tf
+    import tensorflow.experimental.numpy as tnp
     tf_type = tf.Tensor
 except ImportError:
     tf = False
@@ -62,7 +64,8 @@ def get_backend_list():
         lst.append(TorchBackend())
 
     if jax:
-        lst.append(JaxBackend())
+        pass
+        # lst.append(JaxBackend())
 
     if tf:
         lst.append(TensorflowBackend())
@@ -687,6 +690,16 @@ def T(self, a):
         """
         raise NotImplementedError()
 
+    def squeeze(self, a, axis=None):
+        r"""
+        Remove axes of length one from a.
+
+        This function follows the api from :any:`numpy.squeeze`.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -927,6 +940,9 @@ def assert_same_dtype_device(self, a, b):
     def T(self, a):
         return a.T
 
+    def squeeze(self, a, axis=None):
+        return np.squeeze(a, axis=axis)
+
 
 class JaxBackend(Backend):
     """
@@ -1190,6 +1206,9 @@ def assert_same_dtype_device(self, a, b):
     def T(self, a):
         return a.T
 
+    def squeeze(self, a, axis=None):
+        return jnp.squeeze(a, axis=axis)
+
 
 class TorchBackend(Backend):
     """
@@ -1532,6 +1551,9 @@ def assert_same_dtype_device(self, a, b):
     def T(self, a):
         return a.T
 
+    def squeeze(self, a, axis=None):
+        return torch.squeeze(a, dim=axis)
+
 
 class TensorflowBackend(Backend):
 
@@ -1549,20 +1571,31 @@ def __init__(self):
             tf.convert_to_tensor([1], dtype=tf.float64)
         ]
 
+        tmp = self.randn(15, 10)
+        try:
+            tmp.reshape((150, 1))
+        except AttributeError:
+            warnings.warn(
+                "To use TensorflowBackend, you need to activate the tensorflow "
+                "numpy API. You can activate it by running: \n"
+                "from tensorflow.python.ops.numpy_ops import np_config\n"
+                "np_config.enable_numpy_behavior()"
+            )
+
     def to_numpy(self, a):
         return a.numpy()
 
     def from_numpy(self, a, type_as=None):
-        if type_as is None:
-            return tf.convert_to_tensor(a)
+        if not isinstance(a, self.__type__):
+            if type_as is None:
+                return tf.convert_to_tensor(a)
+            else:
+                return tf.convert_to_tensor(a, dtype=type_as.dtype)
         else:
-            return tf.convert_to_tensor(a, dtype=type_as.dtype)
+            if type_as is None:
+                return a
+            else:
+                return tf.cast(a, dtype=type_as.dtype)
 
     def set_gradients(self, val, inputs, grads):
         @tf.custom_gradient
@@ -1574,173 +1607,141 @@ def grad(upstream):
 
     def zeros(self, shape, type_as=None):
         if type_as is None:
-            return tf.zeros(shape)
+            return tnp.zeros(shape)
         else:
-            return tf.zeros(shape, dtype=type_as.dtype)
+            return tnp.zeros(shape, dtype=type_as.dtype)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
-            return tf.ones(shape)
+            return tnp.ones(shape)
         else:
-            return tf.ones(shape, dtype=type_as.dtype)
+            return tnp.ones(shape, dtype=type_as.dtype)
 
     def arange(self, stop, start=0, step=1, type_as=None):
-        return tf.range(start=start, limit=stop, delta=step)
+        return tnp.arange(start, stop, step)
 
     def full(self, shape, fill_value, type_as=None):
         if type_as is None:
-            return tf.fill(shape, fill_value)
+            return tnp.full(shape, fill_value)
         else:
-            return tf.cast(tf.fill(shape, fill_value), dtype=type_as.dtype)
+            return tnp.full(shape, fill_value, dtype=type_as.dtype)
 
     def eye(self, N, M=None, type_as=None):
-        if M is None:
-            M = N
         if type_as is None:
-            return tf.eye(N, M)
+            return tnp.eye(N, M)
         else:
-            return tf.eye(N, M, dtype=type_as.dtype)
+            return tnp.eye(N, M, dtype=type_as.dtype)
 
     def sum(self, a, axis=None, keepdims=False):
-        if axis is None:
-            return tf.math.reduce_sum(a)
-        else:
-            return tf.math.reduce_sum(a, axis=axis, keepdims=keepdims)
+        return tnp.sum(a, axis, keepdims=keepdims)
 
     def cumsum(self, a, axis=None):
-        if axis is None:
-            return tf.math.cumsum(tf.reshape(a, [-1]), axis=0)
-        else:
-            return tf.math.cumsum(a, axis=axis)
+        return tnp.cumsum(a, axis)
 
     def max(self, a, axis=None, keepdims=False):
-        if axis is None:
-            return tf.math.reduce_max(a)
-        else:
-            return tf.math.reduce_max(a, axis=axis, keepdims=keepdims)
+        return tnp.max(a, axis, keepdims=keepdims)
 
     def min(self, a, axis=None, keepdims=False):
-        if axis is None:
-            return tf.math.reduce_min(a)
-        else:
-            return tf.math.reduce_min(a, axis=axis, keepdims=keepdims)
+        return tnp.min(a, axis, keepdims=keepdims)
 
     def maximum(self, a, b):
-        if isinstance(a, int) or isinstance(a, float):
-            a = tf.constant([float(a)], dtype=b.dtype)
-        if isinstance(b, int) or isinstance(b, float):
-            b = tf.constant([float(b)], dtype=a.dtype)
-        return tf.math.maximum(a, b)
+        return tnp.maximum(a, b)
 
     def minimum(self, a, b):
-        if isinstance(a, int) or isinstance(a, float):
-            a = tf.constant([float(a)], dtype=b.dtype)
-        if isinstance(b, int) or isinstance(b, float):
-            b = tf.constant([float(b)], dtype=a.dtype)
-        return tf.math.minimum(a, b)
+        return tnp.minimum(a, b)
 
     def dot(self, a, b):
-        if len(b.shape) == 1:
-            if len(a.shape) == 1:
-                # inner product
-                return tf.reduce_sum(tf.multiply(a, b))
-            else:
-                # matrix vector
-                return tf.linalg.matvec(a, b)
-        else:
-            if len(a.shape) == 1:
-                return self.T(tf.linalg.matvec(self.T(b), self.T(a)))
-            else:
-                return tf.matmul(a, b)
+        return tnp.dot(a, b)
+        # if len(b.shape) == 1:
+        #     if len(a.shape) == 1:
+        #         # inner product
+        #         return tf.reduce_sum(tf.multiply(a, b))
+        #     else:
+        #         # matrix vector
+        #         return tf.linalg.matvec(a, b)
+        # else:
+        #     if len(a.shape) == 1:
+        #         return self.T(tf.linalg.matvec(self.T(b), self.T(a)))
+        #     else:
+        #         return tf.matmul(a, b)
 
     def abs(self, a):
-        return tf.math.abs(a)
+        return tnp.abs(a)
 
     def exp(self, a):
-        return tf.math.exp(a)
+        return tnp.exp(a)
 
     def log(self, a):
-        return tf.math.log(a)
+        return tnp.log(a)
 
     def sqrt(self, a):
-        return tf.math.sqrt(a)
+        return tnp.sqrt(a)
 
     def power(self, a, exponents):
-        return tf.math.pow(a, exponents)
+        return tnp.power(a, exponents)
 
     def norm(self, a):
-        return tf.norm(a, ord=2)
+        return tnp.sqrt(tnp.sum(tnp.square(a)))
 
     def any(self, a):
-        return tf.math.reduce_any(a)
+        return tnp.any(a)
 
     def isnan(self, a):
-        return tf.math.is_nan(a)
+        return tnp.isnan(a)
 
     def isinf(self, a):
-        return tf.math.is_inf(a)
+        return tnp.isinf(a)
 
     def einsum(self, subscripts, *operands):
-        return tf.einsum(subscripts, *operands)
+        return tnp.einsum(subscripts, *operands)
 
     def sort(self, a, axis=-1):
-        return tf.sort(a, axis=axis)
+        return tnp.sort(a, axis)
 
     def argsort(self, a, axis=-1):
-        return tf.argsort(a, axis=axis)
+        return tnp.argsort(a, axis)
 
     def searchsorted(self, a, v, side='left'):
         return tf.searchsorted(a, v, side=side)
 
     def flip(self, a, axis=None):
-        if axis is None:
-            return tf.reverse(a, tuple(i for i in range(len(a.shape))))
-        if isinstance(axis, int):
-            return tf.reverse(a, (axis,))
-        else:
-            return tf.reverse(a, axis=axis)
+        return tnp.flip(a, axis)
 
     def outer(self, a, b):
-        return tf.einsum('i,j->ij', a, b)
+        return tnp.outer(a, b)
 
     def clip(self, a, a_min, a_max):
-        return tf.clip_by_value(a, a_min, a_max)
+        return tnp.clip(a, a_min, a_max)
 
     def repeat(self, a, repeats, axis=None):
-        return tf.repeat(a, repeats, axis=axis)
+        return tnp.repeat(a, repeats, axis)
 
     def take_along_axis(self, arr, indices, axis):
-        return tf.gather(arr, indices, axis=axis)
+        return tnp.take_along_axis(arr, indices, axis)
 
     def concatenate(self, arrays, axis=0):
-        return tf.concat(arrays, axis=axis)
+        return tnp.concatenate(arrays, axis)
 
-    def zero_pad(self, a, pad_with):
-        return tf.pad(a, pad_with)
+    def zero_pad(self, a, pad_width):
+        return tnp.pad(a, pad_width, mode="constant")
 
     def argmax(self, a, axis=None):
-        if axis is None:
-            return tf.argmax(tf.reshape(a, [-1]))
-        else:
-            return tf.argmax(a, axis=axis)
+        return tnp.argmax(a, axis=axis)
 
     def mean(self, a, axis=None):
-        return tf.math.reduce_mean(a, axis=axis)
+        return tnp.mean(a, axis=axis)
 
     def std(self, a, axis=None):
-        return tf.math.reduce_std(a, axis=axis)
+        return tnp.std(a, axis=axis)
 
     def linspace(self, start, stop, num):
-        return tf.linspace(start, stop, num)
+        return tnp.linspace(start, stop, num)
 
     def meshgrid(self, a, b):
-        return tf.meshgrid(a, b)
+        return tnp.meshgrid(a, b)
 
     def diag(self, a, k=0):
-        if len(a.shape) == 1:
-            return tf.linalg.diag(a, k=k)
-        else:
-            return tf.linalg.diag_part(a, k=k)
+        return tnp.diag(a, k)
 
     def unique(self, a):
         return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
 
     def stack(self, arrays, axis=0):
-        return tf.stack(arrays, axis=axis)
+        return tnp.stack(arrays, axis)
 
     def reshape(self, a, shape):
-        return tf.reshape(a, shape)
+        return tnp.reshape(a, shape)
 
     def seed(self, seed=None):
         if isinstance(seed, int):
@@ -1790,37 +1791,37 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
             self._convert_to_index_for_coo(rows),
             self._convert_to_index_for_coo(cols)
         )
+        if type_as is not None:
+            data = self.from_numpy(data, type_as=type_as)
+
         sparse_tensor = tf.sparse.SparseTensor(
-            indices=self.T(self.stack([rows, cols])),
+            indices=tnp.stack([rows, cols]).T,
+            # indices=tf.cast(self.T(self.stack([rows, cols])), dtype=tf.int64),
             values=data,
             dense_shape=shape
        )
-        if type_as is not None:
-            sparse_tensor = self.from_numpy(sparse_tensor)
-        return sparse_tensor
+        # if type_as is not None:
+        #     sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as)
+        # SparseTensor are not subscriptable so we use dense tensors
+        return self.todense(sparse_tensor)
 
     def issparse(self, a):
         return isinstance(a, tf.sparse.SparseTensor)
 
     def tocsr(self, a):
-        return tf.sparse.reorder(a)
+        return a
 
     def eliminate_zeros(self, a, threshold=0.):
-        threshold = self._constant(threshold, type_as=a)
         if self.issparse(a):
             values = a.values
             if threshold > 0:
                 mask = self.abs(values) <= threshold
             else:
-                mask = values == self._constant(0., type_as=a)
+                mask = values == 0
             return tf.sparse.retain(a, ~mask)
         else:
             if threshold > 0:
-                a = tf.where(
-                    self.abs(a) <= threshold,
-                    self._constant(0., type_as=a),
-                    a
-                )
+                a = tnp.where(self.abs(a) > threshold, a, 0.)
             return a
 
     def todense(self, a):
         if self.issparse(a):
             return tf.sparse.to_dense(tf.sparse.reorder(a))
         else:
             return a
 
     def where(self, condition, x, y):
-        return tf.where(condition, x, y)
+        return tnp.where(condition, x, y)
 
     def copy(self, a):
         return tf.identity(a)
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
-        return tf.experimental.numpy.allclose(
+        return tnp.allclose(
             a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
         )
 
     def dtype_device(self, a):
-        return a.dtype, a.device.split("device")[1]
+        return a.dtype, a.device.split("device:")[1]
 
     def assert_same_dtype_device(self, a, b):
         a_dtype, a_device = self.dtype_device(a)
         b_dtype, b_device = self.dtype_device(b)
 
         assert a_dtype == b_dtype, "Dtype discrepancy"
         assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
 
     def T(self, a):
-        return tf.transpose(a)
+        return a.T
+
+    def squeeze(self, a, axis=None):
+        return tnp.squeeze(a, axis=axis)
\ No newline at end of file
diff --git a/ot/bregman.py b/ot/bregman.py
index cce52e26e..fc2017537 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
     a, b, M = list_to_array(a, b, M)
 
     nx = get_backend(M, a, b)
-    if nx.__name__ == "jax":
-        raise TypeError("JAX arrays have been received. Greenkhorn is not "
-                        "compatible with JAX")
+    if nx.__name__ in ("jax", "tf"):
+        raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
+                        "compatible with either JAX or TF")
 
     if len(a) == 0:
         a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
@@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
 
         if m_viol_1 > m_viol_2:
             old_u = u[i_1]
-            new_u = a[i_1] / (K[i_1, :].dot(v))
+            new_u = a[i_1] / nx.dot(K[i_1, :], v)
             G[i_1, :] = new_u * K[i_1, :] * v
 
-            viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+            viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
             viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
             u[i_1] = new_u
         else:
             old_v = v[i_2]
-            new_v = b[i_2] / (K[:, i_2].T.dot(u))
+            new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
             G[:, i_2] = u * K[:, i_2] * new_v
             # aviol = (G@one_m - a)
             # aviol_2 = (G.T@one_n - b)
             viol += (-old_v + new_v) * K[:, i_2] * u
-            viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+            viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
             v[i_2] = new_v
 
         if stopThr_val <= stopThr:
@@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
 
     nx = get_backend(A, M)
 
-    if nx.__name__ == "jax":
-        raise NotImplementedError("Log-domain functions are not yet implemented"
-                                  " for Jax. Use numpy or torch arrays instead.")
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
 
     if weights is None:
         weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
     dim, n_hists = A.shape
 
     nx = get_backend(A, M)
-    if nx.__name__ == "jax":
-        raise NotImplementedError("Log-domain functions are not yet implemented"
-                                  " for Jax. Use numpy or torch arrays instead.")
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
 
     if weights is None:
         weights = nx.ones(n_hists, type_as=A) / n_hists
@@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
         log = {'err': []}
 
     bar = nx.ones(A.shape[1:], type_as=A)
-    bar /= bar.sum()
+    bar /= nx.sum(bar)
     U = nx.ones(A.shape, type_as=A)
     V = nx.ones(A.shape, type_as=A)
     err = 1
@@ -2069,9 +2073,11 @@ def convol_imgs(imgs):
         KV = convol_imgs(V)
         U = A / KV
         KU = convol_imgs(U)
-        bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+        bar = nx.exp(
+            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+        )
         if ii % 10 == 9:
-            err = (V * KU).std(axis=0).sum()
+            err = nx.sum(nx.std(V * KU, axis=0))
             # log and verbose print
             if log:
                 log['err'].append(err)
@@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
     A = list_to_array(A)
 
     nx = get_backend(A)
-    if nx.__name__ == "jax":
-        raise NotImplementedError("Log-domain functions are not yet implemented"
-                                  " for Jax. Use numpy or torch arrays instead.")
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
 
     n_hists, width, height = A.shape
 
@@ -2298,13 +2306,15 @@ def convol_imgs(imgs):
         KV = convol_imgs(V)
         U = A / KV
         KU = convol_imgs(U)
-        bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+        bar = c * nx.exp(
+            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
+        )
 
         for _ in range(10):
-            c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+            c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5
 
         if ii % 10 == 9:
-            err = (V * KU).std(axis=0).sum()
+            err = nx.sum(nx.std(V * KU, axis=0))
             # log and verbose print
             if log:
                 log['err'].append(err)
@@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
     A = list_to_array(A)
     n_hists, width, height = A.shape
     nx = get_backend(A)
-    if nx.__name__ == "jax":
-        raise NotImplementedError("Log-domain functions are not yet implemented"
-                                  " for Jax. Use numpy or torch arrays instead.")
+    if nx.__name__ in ("jax", "tf"):
+        raise NotImplementedError(
+            "Log-domain functions are not yet implemented"
+            " for Jax and TF. Use numpy or torch arrays instead."
+        )
     if weights is None:
         weights = nx.ones((n_hists,), type_as=A) / n_hists
     else:
@@ -2382,7 +2394,7 @@ def convol_img(log_img):
         c = 0.5 * (c + log_bar - convol_img(c))
 
         if ii % 10 == 9:
-            err = nx.exp(G + log_KU).std(axis=0).sum()
+            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
             # log and verbose print
             if log:
                 log['err'].append(err)
@@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
     a, b, M = list_to_array(a, b, M)
 
     nx = get_backend(M, a, b)
-    if nx.__name__ == "jax":
-        raise TypeError("JAX arrays have been received but screenkhorn is not "
-                        "compatible with JAX.")
+    if nx.__name__ in ("jax", "tf"):
+        raise TypeError("JAX or TF arrays have been received but screenkhorn is not "
+                        "compatible with either JAX or TF.")
 
     ns, nt = M.shape
 
@@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
     K = nx.exp(-M / reg)
 
     def projection(u, epsilon):
-        u[u <= epsilon] = epsilon
+        u = nx.maximum(u, epsilon)
         return u
 
     # ----------------------------------------------------------------------------------------------------------------#
diff --git a/ot/gromov.py b/ot/gromov.py
index ea667e414..209fe5561 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -942,7 +942,7 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
     for cpt in range(max_iter):
         index[0] = generator.choice(len_p, size=1, p=p)
         T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
-        index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+        index[1] = generator.choice(len_q, size=1, p=T_index0 / nx.sum(T_index0))
 
         if alpha == 1:
             T = nx.tocsr(
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 8b4d0c3be..43763a9bd 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
     m = v_values.shape[0]
 
     if u_weights is None:
-        u_weights = nx.full(u_values.shape, 1. / n)
+        u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
     elif u_weights.ndim != u_values.ndim:
         u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
     if v_weights is None:
-        v_weights = nx.full(v_values.shape, 1. / m)
+        v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
     elif v_weights.ndim != v_values.ndim:
         v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
 
diff --git a/test/conftest.py b/test/conftest.py
index 987d98e25..9b8a18270 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,7 @@
 # License: MIT License
 
 import pytest
-from ot.backend import jax
+from ot.backend import jax, tf
 from ot.backend import get_backend_list
 import functools
 
@@ -13,6 +13,10 @@
     from jax.config import config
     config.update("jax_enable_x64", True)
 
+if tf:
+    from tensorflow.python.ops.numpy_ops import np_config
+    np_config.enable_numpy_behavior()
+
 backend_list = get_backend_list()
 
diff --git a/test/test_backend.py b/test/test_backend.py
index 2fe661e98..4d2bd1dfd 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -244,6 +244,8 @@ def test_empty_backend():
         nx.allclose(M, M)
     with pytest.raises(NotImplementedError):
         nx.T(M)
+    with pytest.raises(NotImplementedError):
+        nx.squeeze(M)
 
 
 def test_func_backends(nx):
@@ -369,7 +371,7 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('dot(M,v)')
 
-        A = nx.dot(Mb, nx.T(Mb))
+        A = nx.dot(Mb, Mb.T)
         lst_b.append(nx.to_numpy(A))
         lst_name.append('dot(M,M)')
 
@@ -492,7 +494,7 @@ def test_func_backends(nx):
         lst_name.append('coo_matrix')
 
         assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
-        assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+        assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)'
 
         A = nx.tocsr(sp_Mb)
         lst_b.append(nx.to_numpy(nx.todense(A)))
@@ -521,6 +523,9 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('transpose')
 
+        A = nx.squeeze(nx.zeros((3, 1, 4, 1)))
+        assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze'
+
         lst_tot.append(lst_b)
 
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 830052dae..ab072291f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -248,6 +248,7 @@ def test_sinkhorn_empty():
     ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
 
 
+@pytest.skip_backend('tf')
 @pytest.skip_backend("jax")
 def test_sinkhorn_variants(nx):
     # test sinkhorn
@@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx):
                                     "sinkhorn_epsilon_scaling",
                                     "greenkhorn",
                                     "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str)
 @pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
 @pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
 def test_sinkhorn_variants_dtype_device(nx, method):
@@ -323,6 +326,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
         nx.assert_same_dtype_device(Mb, lossb)
 
 
+@pytest.skip_backend('tf')
 @pytest.skip_backend("jax")
 def test_sinkhorn_variants_multi_b(nx):
     # test sinkhorn
@@ -352,6 +356,7 @@ def test_sinkhorn_variants_multi_b(nx):
     np.testing.assert_allclose(G0, Gs, atol=1e-05)
 
 
+@pytest.skip_backend('tf')
 @pytest.skip_backend("jax")
 def test_sinkhorn2_variants_multi_b(nx):
     # test sinkhorn
@@ -454,7 +459,7 @@ def test_barycenter(nx, method, verbose, warn):
     weights_nx = nx.from_numpy(weights)
     reg = 1e-2
 
-    if nx.__name__ == "jax" and method == "sinkhorn_log":
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
         with pytest.raises(NotImplementedError):
             ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
     else:
@@ -495,7 +500,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
 
     # wasserstein
     reg = 1e-2
-    if nx.__name__ == "jax" and method == "sinkhorn_log":
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
         with pytest.raises(NotImplementedError):
             ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
     else:
@@ -597,7 +602,7 @@ def test_wasserstein_bary_2d(nx, method):
 
     # wasserstein
     reg = 1e-2
-    if nx.__name__ == "jax" and method == "sinkhorn_log":
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
         with pytest.raises(NotImplementedError):
             ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
     else:
@@ -629,7 +634,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
 
     # wasserstein
     reg = 1e-2
-    if nx.__name__ == "jax" and method == "sinkhorn_log":
+    if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
         with pytest.raises(NotImplementedError):
             ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
     else:
@@ -888,6 +893,7 @@ def test_implemented_methods():
             ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
 
 
+@pytest.skip_backend('tf')
 @pytest.skip_backend("jax")
 @pytest.mark.filterwarnings("ignore:Bottleneck")
 def test_screenkhorn(nx):
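Aside (illustration only, not from this patch series): the `AttributeError` probe in `TensorflowBackend.__init__` above detects whether TensorFlow's NumPy behavior is active, which conftest.py now enables globally. A minimal sketch of what enabling it changes:

    import tensorflow as tf
    from tensorflow.python.ops.numpy_ops import np_config

    a = tf.constant([[1., 2.], [3., 4.]])
    # Without NumPy behavior, tf.Tensor has no .reshape() or .T attributes.
    np_config.enable_numpy_behavior()
    print(a.reshape((4, 1)).shape)  # (4, 1)
    print(a.T)                      # NumPy-style transpose now works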
From 595564379fe7008b55fd4c9c6fa8311cec4eacbd Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 30 Nov 2021 17:21:20 +0100
Subject: [PATCH 04/46] add tf requirement

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 43532470a..6cd02920d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ scikit-learn
 torch
 jax
 jaxlib
+tensorflow
 pytest
\ No newline at end of file

From d988df045d46f07757f2f7698c2a75b6c1fa3a85 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 30 Nov 2021 17:30:58 +0100
Subject: [PATCH 05/46] pep8 + bug

---
 ot/backend.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index c9a7779a0..58f0a555c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1552,7 +1552,10 @@ def T(self, a):
         return a.T
 
     def squeeze(self, a, axis=None):
-        return torch.squeeze(a, dim=axis)
+        if axis is None:
+            return torch.squeeze(a)
+        else:
+            return torch.squeeze(a, dim=axis)
 
 
 class TensorflowBackend(Backend):
@@ -1855,4 +1858,4 @@ def T(self, a):
         return a.T
 
     def squeeze(self, a, axis=None):
-        return tnp.squeeze(a, axis=axis)
\ No newline at end of file
+        return tnp.squeeze(a, axis=axis)

From 86cdbc6469971883f7c6f66f0298409369bac2aa Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 11:21:16 +0100
Subject: [PATCH 06/46] small changes

---
 ot/backend.py       | 30 ++++++++++++++----------------
 test/conftest.py    |  6 +++---
 test/test_gromov.py |  3 +++
 3 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 58f0a555c..305784f22 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -64,8 +64,7 @@ def get_backend_list():
         lst.append(TorchBackend())
 
     if jax:
-        pass
-        # lst.append(JaxBackend())
+        lst.append(JaxBackend())
 
     if tf:
         lst.append(TensorflowBackend())
@@ -1654,19 +1653,18 @@ def minimum(self, a, b):
         return tnp.minimum(a, b)
 
     def dot(self, a, b):
-        return tnp.dot(a, b)
-        # if len(b.shape) == 1:
-        #     if len(a.shape) == 1:
-        #         # inner product
-        #         return tf.reduce_sum(tf.multiply(a, b))
-        #     else:
-        #         # matrix vector
-        #         return tf.linalg.matvec(a, b)
-        # else:
-        #     if len(a.shape) == 1:
-        #         return self.T(tf.linalg.matvec(self.T(b), self.T(a)))
-        #     else:
-        #         return tf.matmul(a, b)
+        if len(b.shape) == 1:
+            if len(a.shape) == 1:
+                # inner product
+                return tf.reduce_sum(tf.multiply(a, b))
+            else:
+                # matrix vector
+                return tf.linalg.matvec(a, b)
+        else:
+            if len(a.shape) == 1:
+                return tf.linalg.matvec(b.T, a.T).T
+            else:
+                return tf.matmul(a, b)
 
     def abs(self, a):
         return tnp.abs(a)
@@ -1684,7 +1682,7 @@ def power(self, a, exponents):
         return tnp.power(a, exponents)
 
     def norm(self, a):
-        return tnp.sqrt(tnp.sum(tnp.square(a)))
+        return tf.math.reduce_euclidean_norm(a)
 
     def any(self, a):
         return tnp.any(a)
diff --git a/test/conftest.py b/test/conftest.py
index 9b8a18270..c0db8abe2 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -28,16 +28,16 @@ def nx(request):
 
 
 def skip_arg(arg, value, reason=None, getter=lambda x: x):
-    if isinstance(arg, tuple) or isinstance(arg, list):
+    if isinstance(arg, (tuple, list)):
         n = len(arg)
     else:
         arg = (arg, )
         n = 1
-    if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+    if n != 1 and isinstance(value, (tuple, list)):
         pass
     else:
         value = (value, )
-    if isinstance(getter, tuple) or isinstance(value, list):
+    if isinstance(getter, (tuple, list)):
         pass
     else:
         getter = [getter] * n
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c4bc04c5e..690b19408 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -147,6 +147,7 @@ def test_gromov2_gradients():
 
 
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
 def test_entropic_gromov(nx):
     n_samples = 50  # nb samples
 
@@ -204,6 +205,7 @@ def test_entropic_gromov(nx):
 
 
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
 def test_entropic_gromov_dtype_device(nx):
     # setup
     n_samples = 50  # nb samples
@@ -302,6 +304,7 @@ def lossb(x, y):
     np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
 
 
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
 def test_sampled_gromov(nx):
     n_samples = 50  # nb samples

From 93bc691267637f7829202a2a1ac3bd05bb0c6970 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 11:24:39 +0100
Subject: [PATCH 07/46] attempt to solve pymanopt bug with tf2

---
 ot/dr.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ot/dr.py b/ot/dr.py
index c2f51f86c..45b342c57 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -181,6 +181,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     else:
         regmean = np.ones((len(xc), len(xc)))
 
+
+    @pymanopt.function.Autograd
     def cost(P):
         # wda loss
         loss_b = 0

From 8baeae9bbfd72d7e7032207a8e496333c45139cb Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 11:28:20 +0100
Subject: [PATCH 08/46] attempt #2

---
 ot/dr.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ot/dr.py b/ot/dr.py
index 45b342c57..1671ca0f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,6 +16,7 @@
 from scipy import linalg
 
 import autograd.numpy as np
+from pymanopt.function import Autograd
 from pymanopt.manifolds import Stiefel
 from pymanopt import Problem
 from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -181,8 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     else:
         regmean = np.ones((len(xc), len(xc)))
 
-
-    @pymanopt.function.Autograd
+    @Autograd
     def cost(P):
         # wda loss
         loss_b = 0

From 1881f8570d82df48649b272ea7e574530ab9066f Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 12:35:51 +0100
Subject: [PATCH 09/46] attempt #3

---
 ot/dr.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0f8..241b3ae0b 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,7 +16,7 @@
 from scipy import linalg
 
 import autograd.numpy as np
-from pymanopt.function import Autograd
+import pymanopt
 from pymanopt.manifolds import Stiefel
 from pymanopt import Problem
 from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -182,7 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     else:
         regmean = np.ones((len(xc), len(xc)))
 
-    @Autograd
+    @pymanopt.function.Autograd
     def cost(P):
         # wda loss
         loss_b = 0

From 249ea2fb46bdcb7c1be2bee70067eb21d183af5d Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 16:29:09 +0100
Subject: [PATCH 10/46] attempt 4

---
 ot/dr.py         | 4 ++--
 requirements.txt | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ot/dr.py b/ot/dr.py
index 241b3ae0b..1671ca0f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,7 +16,7 @@
 from scipy import linalg
 
 import autograd.numpy as np
-import pymanopt
+from pymanopt.function import Autograd
 from pymanopt.manifolds import Stiefel
 from pymanopt import Problem
 from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -182,7 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     else:
         regmean = np.ones((len(xc), len(xc)))
 
-    @pymanopt.function.Autograd
+    @Autograd
     def cost(P):
         # wda loss
         loss_b = 0
diff --git a/requirements.txt b/requirements.txt
index 6cd02920d..d191907fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ cython
 matplotlib
 autograd
 pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+git+https://github.com/pymanopt/pymanopt; python_version >= '3'
 cvxopt
 scikit-learn
 torch

From b8877465ea7a80efa7e8ac8bd89eed14c8d6b5e3 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 1 Dec 2021 17:23:47 +0100
Subject: [PATCH 11/46] docstring

---
 ot/backend.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 305784f22..0cc6c5ff2 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@
 Multi-lib backend for POT
 
 The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, or Tensorflow, POT code should work nonetheless.
 To achieve that, POT provides backend classes which implements functions in their respective backend imitating Numpy API.
 As a convention, we use nx instead of np to refer to the backend.
@@ -108,7 +108,8 @@ def to_numpy(*args):
 class Backend():
     """
     Backend abstract class.
-    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+    :py:class:`TensorflowBackend`
 
     - The `__name__` class attribute refers to the name of the backend.
     - The `__type__` class attribute refers to the data structure used by the backend.
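Aside (illustration only, not from this patch series): the docstring updated above describes the `nx` convention; a minimal sketch of what backend-agnostic POT code looks like from the caller's side, with toy arrays made up here:

    import numpy as np
    from ot.backend import get_backend

    a = np.random.rand(5)
    M = np.random.rand(5, 5)
    nx = get_backend(a, M)      # NumpyBackend for numpy inputs; torch, jax
                                # or tf tensors would select their backend
    print(nx.sum(M, axis=1))    # the same nx.* call works on every backend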
From a34aeffc5d83a55c7e4d48fd83b2ff1a810e7af0 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Fri, 3 Dec 2021 14:39:05 +0100
Subject: [PATCH 12/46] correct pep8 violation introduced in merge conflicts resolution

---
 ot/backend.py        | 2 +-
 test/test_backend.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index ad6651680..4296c7264 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1869,7 +1869,7 @@ class TensorflowBackend(Backend):
     __type_list__ = None
 
     rng_ = None
-    
+
     def __init__(self):
         self.seed(None)
 
diff --git a/test/test_backend.py b/test/test_backend.py
index cfc727ca6..6bf8ee840 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,6 @@
 
 import ot
 import ot.backend
-
 from ot.backend import torch, jax, cp, tf
 
 import pytest

From eae1c9af7516401d024a9edbe3a64be9128492d5 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Fri, 3 Dec 2021 14:45:31 +0100
Subject: [PATCH 13/46] attempt 5

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index d191907fe..d43be7ad8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ cython
 matplotlib
 autograd
 pymanopt==0.2.4; python_version <'3'
-git+https://github.com/pymanopt/pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
 cvxopt
 scikit-learn
 torch

From 01fce562d849dcee6c8582712ecca326a7fb764f Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Fri, 3 Dec 2021 15:04:59 +0100
Subject: [PATCH 14/46] attempt 6

---
 .github/requirements_test_windows.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index 331dd570a..b94392f8f 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -4,7 +4,7 @@ cython
 matplotlib
 autograd
 pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
 cvxopt
 scikit-learn
 pytest
\ No newline at end of file

From 8223e768bfe33635549fb66cca2267514a60ebbf Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Fri, 3 Dec 2021 15:15:59 +0100
Subject: [PATCH 15/46] just a random try

---
 .github/requirements_test_windows.txt | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index b94392f8f..dc25a648d 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3'
 pymanopt==0.2.6rc1; python_version >= '3'
 cvxopt
 scikit-learn
+jax
+jaxlib
+tensorflow
 pytest
\ No newline at end of file

From aaac0ee14499423776551ed54fed00cdd31bed40 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Fri, 3 Dec 2021 15:22:56 +0100
Subject: [PATCH 16/46] Revert "just a random try"

This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf.
---
 .github/requirements_test_windows.txt | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index dc25a648d..b94392f8f 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -7,7 +7,4 @@ pymanopt==0.2.4; python_version <'3'
 pymanopt==0.2.6rc1; python_version >= '3'
 cvxopt
 scikit-learn
-jax
-jaxlib
-tensorflow
 pytest
\ No newline at end of file

From 0feea713077aaf14a9f97fb3c74311e489b096f1 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Mon, 6 Dec 2021 10:29:34 +0100
Subject: [PATCH 17/46] GPU tests for tensorflow

---
 test/test_1d_solver.py | 68 ++++++++++++++++++++++++++++++++++++++++--
 test/test_bregman.py   | 31 ++++++++++++++++++-
 test/test_gromov.py    | 42 +++++++++++++++++++++++++-
 test/test_ot.py        | 36 +++++++++++++++++++++-
 test/test_sliced.py    | 57 +++++++++++++++++++++++++++++++++++
 5 files changed, 228 insertions(+), 6 deletions(-)

diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb959..6a42cfe83 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@
 import ot
 from ot.lp import wasserstein_1d
 
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
 from scipy.stats import wasserstein_distance
 
 backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
 
 
 def test_wasserstein_1d_type_devices(nx):
-
     rng = np.random.RandomState(0)
 
     n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
         nx.assert_same_dtype_device(xb, res)
 
 
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+    if not tf:
+        return
+    nx = ot.backend.TensorflowBackend()
+    rng = np.random.RandomState(0)
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        xb = nx.from_numpy(x)
+        rho_ub = nx.from_numpy(rho_u)
+        rho_vb = nx.from_numpy(rho_v)
+        res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+        nx.assert_same_dtype_device(xb, res)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        xb = nx.from_numpy(x)
+        rho_ub = nx.from_numpy(rho_u)
+        rho_vb = nx.from_numpy(rho_v)
+        res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+        nx.assert_same_dtype_device(xb, res)
+        assert nx.dtype_device(res)[1].startswith("GPU")
+
+
 def test_emd_1d_emd2_1d():
     # test emd1d gives similar results as emd
     n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
 
 
 def test_emd1d_type_devices(nx):
-
     rng = np.random.RandomState(0)
 
     n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
 
     nx.assert_same_dtype_device(xb, emd)
     nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+    nx = ot.backend.TensorflowBackend()
+    rng = np.random.RandomState(0)
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        xb = nx.from_numpy(x)
+        rho_ub = nx.from_numpy(rho_u)
+        rho_vb = nx.from_numpy(rho_v)
+        emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+        emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+        nx.assert_same_dtype_device(xb, emd)
+        nx.assert_same_dtype_device(xb, emd2)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        xb = nx.from_numpy(x)
+        rho_ub = nx.from_numpy(rho_u)
+        rho_vb = nx.from_numpy(rho_v)
+        emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+        emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+        nx.assert_same_dtype_device(xb, emd)
+        nx.assert_same_dtype_device(xb, emd2)
+        assert nx.dtype_device(emd)[1].startswith("GPU")
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 413933ad6..45db06925 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -12,7 +12,7 @@
 import pytest
 
 import ot
-from ot.backend import torch
+from ot.backend import torch, tf
 
 
 @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
@@ -326,6 +326,35 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
         nx.assert_same_dtype_device(Mb, lossb)
 
 
+@pytest.mark.skipif(not tf, reason="tf not installed")
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_device_tf(method):
+    nx = ot.backend.TensorflowBackend()
+    n = 100
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+    M = ot.dist(x, x)
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        ub = nx.from_numpy(u)
+        Mb = nx.from_numpy(M)
+        Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+        lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+        nx.assert_same_dtype_device(Mb, Gb)
+        nx.assert_same_dtype_device(Mb, lossb)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        ub = nx.from_numpy(u)
+        Mb = nx.from_numpy(M)
+        Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+        lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+        nx.assert_same_dtype_device(Mb, Gb)
+        nx.assert_same_dtype_device(Mb, lossb)
+        assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
 @pytest.skip_backend('tf')
 @pytest.skip_backend("jax")
 def test_sinkhorn_variants_multi_b(nx):
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 987dc395a..4200e2a05 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,7 +9,7 @@
 import numpy as np
 import ot
 from ot.backend import NumpyBackend
-from ot.backend import torch
+from ot.backend import torch, tf
 
 import pytest
 
@@ -113,6 +113,46 @@ def test_gromov_dtype_device(nx):
         nx.assert_same_dtype_device(C1b, gw_valb)
 
 
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_gromov_device_tf():
+    nx = ot.backend.TensorflowBackend()
+    n_samples = 50  # nb samples
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+    xt = xs[::-1].copy()
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        C1b = nx.from_numpy(C1)
+        C2b = nx.from_numpy(C2)
+        pb = nx.from_numpy(p)
+        qb = nx.from_numpy(q)
+        Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+        gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        C1b = nx.from_numpy(C1)
+        C2b = nx.from_numpy(C2)
+        pb = nx.from_numpy(p)
+        qb = nx.from_numpy(q)
+        Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+        gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+        assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
 def test_gromov2_gradients():
     n_samples = 50  # nb samples
 
diff --git a/test/test_ot.py b/test/test_ot.py
index c4d771332..53edf4f65 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@
 
 import ot
 from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
 
 
 def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx):
         nx.assert_same_dtype_device(Mb, w)
 
 
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+    if not tf:
+        return
+    nx = ot.backend.TensorflowBackend()
+
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+    M = ot.dist(x, y)
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        ab = nx.from_numpy(a)
+        Mb = nx.from_numpy(M)
+        Gb = ot.emd(ab, ab, Mb)
+        w = ot.emd2(ab, ab, Mb)
+        nx.assert_same_dtype_device(Mb, Gb)
+        nx.assert_same_dtype_device(Mb, w)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        ab = nx.from_numpy(a)
+        Mb = nx.from_numpy(M)
+        Gb = ot.emd(ab, ab, Mb)
+        w = ot.emd2(ab, ab, Mb)
+        nx.assert_same_dtype_device(Mb, Gb)
+        nx.assert_same_dtype_device(Mb, w)
+        assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
 def test_emd2_gradients():
     n_samples = 100
     n_features = 2
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 245202c86..91e09613a 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,6 +10,7 @@
 
 import ot
 from ot.sliced import get_random_projections
+from ot.backend import tf
 
 
 def test_get_random_projections():
@@ -161,6 +162,34 @@ def test_sliced_backend_type_devices(nx):
         nx.assert_same_dtype_device(xb, valb)
 
 
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_sliced_backend_device_tf():
+    nx = ot.backend.TensorflowBackend()
+    n = 100
+    rng = np.random.RandomState(0)
+    x = rng.randn(n, 2)
+    y = rng.randn(2 * n, 2)
+    P = rng.randn(2, 20)
+    P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        xb = nx.from_numpy(x)
+        yb = nx.from_numpy(y)
+        Pb = nx.from_numpy(P)
+        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+        nx.assert_same_dtype_device(xb, valb)
+
+    if len(tf.config.list_physical_devices('GPU')) > 0:
+        # Check that everything happens on the GPU
+        xb = nx.from_numpy(x)
+        yb = nx.from_numpy(y)
+        Pb = nx.from_numpy(P)
+        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+        nx.assert_same_dtype_device(xb, valb)
+        assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
 def test_max_sliced_backend(nx):
 
     n = 100
@@ -211,3 +240,31 @@ def test_max_sliced_backend_type_devices(nx):
         valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
 
         nx.assert_same_dtype_device(xb, valb)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_max_sliced_backend_device_tf():
+    nx = ot.backend.TensorflowBackend()
+    n = 100
+    rng = np.random.RandomState(0)
+    x = rng.randn(n, 2)
+    y = rng.randn(2 * n, 2)
+    P = rng.randn(2, 20)
+    P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+    # Check that everything stays on the CPU
+    with tf.device("/CPU:0"):
+        xb = nx.from_numpy(x)
+        yb = nx.from_numpy(y)
+        Pb = nx.from_numpy(P)
+        valb = ot.max_sliced_wasserstein_distance(xb, yb,
projections=Pb) + nx.assert_same_dtype_device(xb, valb) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) + nx.assert_same_dtype_device(xb, valb) + assert nx.dtype_device(valb)[1].startswith("GPU") From 32c2838eaf63aa9ed05a75c84474109a6cbd80b4 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Mon, 6 Dec 2021 10:48:45 +0100 Subject: [PATCH 18/46] pep8 --- test/test_bregman.py | 2 +- test/test_gromov.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 45db06925..6e90aa472 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -328,7 +328,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method): @pytest.mark.skipif(not tf, reason="tf not installed") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) -def test_sinkhorn2_variants_dtype_device(method): +def test_sinkhorn2_variants_device_tf(method): nx = ot.backend.TensorflowBackend() n = 100 x = np.random.randn(n, 2) diff --git a/test/test_gromov.py b/test/test_gromov.py index 4200e2a05..b4fb0800e 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -152,7 +152,6 @@ def test_gromov_device_tf(): assert nx.dtype_device(Gb)[1].startswith("GPU") - def test_gromov2_gradients(): n_samples = 50 # nb samples From 3cadd11e31a33533d298b5a7ff4e89287d7501a1 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Mon, 6 Dec 2021 12:18:27 +0100 Subject: [PATCH 19/46] attempt to solve issue with m2r2 --- README.md | 2 +- docs/requirements.txt | 1 + docs/requirements_rtd.txt | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 18064a398..172dde928 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. -* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. +* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. 
POT provides the following Machine Learning related solvers: diff --git a/docs/requirements.txt b/docs/requirements.txt index 95147d226..9c311053e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,4 +4,5 @@ numpydoc memory_profiler pillow networkx +mistune==0.8.4 m2r2 \ No newline at end of file diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 5963ea271..7a6dbf694 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -3,6 +3,7 @@ numpydoc memory_profiler pillow networkx +mistune==0.8.4 m2r2 numpy scipy>=1.0 From aaa7e4a51a84dd97fb66e333977d2a69835477cf Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Mon, 6 Dec 2021 13:24:58 +0100 Subject: [PATCH 20/46] Remove transpose backend method --- ot/backend.py | 25 ------------------------- test/test_backend.py | 6 ------ 2 files changed, 31 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 4296c7264..60a7a51b1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -694,16 +694,6 @@ def assert_same_dtype_device(self, a, b): """ raise NotImplementedError() - def T(self, a): - r""" - Returns the transpose of a tensor. - - This function follows the api from :any:`numpy.ndarray.T`. - - See: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.T.html - """ - raise NotImplementedError() - def squeeze(self, a, axis=None): r""" Remove axes of length one from a. @@ -951,9 +941,6 @@ def assert_same_dtype_device(self, a, b): # numpy has implicit type conversion so we automatically validate the test pass - def T(self, a): - return a.T - def squeeze(self, a, axis=None): return np.squeeze(a, axis=axis) @@ -1217,9 +1204,6 @@ def assert_same_dtype_device(self, a, b): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" - def T(self, a): - return a.T - def squeeze(self, a, axis=None): return jnp.squeeze(a, axis=axis) @@ -1562,9 +1546,6 @@ def assert_same_dtype_device(self, a, b): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" - def T(self, a): - return a.T - def squeeze(self, a, axis=None): if axis is None: return torch.squeeze(a) @@ -1855,9 +1836,6 @@ def assert_same_dtype_device(self, a, b): # we automatically validate the test for type assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" - def T(self, a): - return a.T - def squeeze(self, a, axis=None): return cp.squeeze(a, axis=axis) @@ -2156,8 +2134,5 @@ def assert_same_dtype_device(self, a, b): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. 
First input is on {str(a_device)}, whereas second input is on {str(b_device)}" - def T(self, a): - return a.T - def squeeze(self, a, axis=None): return tnp.squeeze(a, axis=axis) diff --git a/test/test_backend.py b/test/test_backend.py index 6bf8ee840..d9ac16a65 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -256,8 +256,6 @@ def test_empty_backend(): nx.copy(M) with pytest.raises(NotImplementedError): nx.allclose(M, M) - with pytest.raises(NotImplementedError): - nx.T(M) with pytest.raises(NotImplementedError): nx.squeeze(M) @@ -534,10 +532,6 @@ def test_func_backends(nx): assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' - A = nx.T(Mb) - lst_b.append(nx.to_numpy(A)) - lst_name.append('transpose') - A = nx.squeeze(nx.zeros((3, 1, 4, 1))) assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze' From 245a3c22434a6edd17559234f09d74cc17df70b0 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Mon, 6 Dec 2021 18:04:16 +0100 Subject: [PATCH 21/46] first draft of benchmarker (need to correct time measurement) --- benchmarks/__init__.py | 4 ++ benchmarks/benchmark.py | 67 +++++++++++++++++++++++++++++ benchmarks/sinkhorn_knopp.py | 34 +++++++++++++++ ot/backend.py | 82 ++++++++++++++++++++++++++++++++++++ 4 files changed, 187 insertions(+) create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/benchmark.py create mode 100644 benchmarks/sinkhorn_knopp.py diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..0ecbfad71 --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from . import benchmark +from . import sinkhorn_knopp + +__all__= ["benchmark", "sinkhorn_knopp"] diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 000000000..03a6ed6dd --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,67 @@ +# /usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +from ot.backend import get_backend_list, jax, tf + + +def setup_backends(): + if jax: + from jax.config import config + config.update("jax_enable_x64", True) + + if tf: + from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() + + +def exec_bench(setup, tested_function, param_list, n_runs): + backend_list = get_backend_list() + results = dict() + for param in param_list: + L = dict() + inputs = setup(param) + for nx in backend_list: + results_nx = nx._bench( + tested_function, + *inputs, + n_runs=n_runs + ) + L.update(results_nx) + results[param] = L + return results + + +def get_keys(d): + return sorted(list(d.keys())) + + +def convert_to_html_table(results, param_name): + string = "\n" + keys = get_keys(results) + print(results[keys[0]].keys()) + subkeys = get_keys(results[keys[0]]) + names, devices, bitsizes = zip(*subkeys) + + names = sorted(list(set(zip(names, devices)))) + length = len(names) + 1 + + for bitsize in sorted(list(set(bitsizes))): + string += f'\n' + string += f'\n" + + for key in keys: + subdict = results[key] + subkeys = get_keys(subdict) + string += f'' + for subkey in subkeys: + name, device, size = subkey + if size == bitsize: + string += f'' + string += "\n" + + string += "
<tr><th colspan="{length}">{bitsize} bits</th></tr>\n'
+        string += f'<tr><th>{param_name}</th>'
+        for name, device in names:
+            string += f'<th>{name} {device}</th>'
+        string += "</tr>\n"
+
+        for key in keys:
+            subdict = results[key]
+            subkeys = get_keys(subdict)
+            string += f'<tr><td>{key}</td>'
+            for subkey in subkeys:
+                name, device, size = subkey
+                if size == bitsize:
+                    string += f'<td>{subdict[subkey]:.4f}</td>'
+            string += "</tr>\n"
+
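+    # every bitsize section has been written at this point; close the table
+    string += "</table>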
" + return string diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py new file mode 100644 index 000000000..0297379fa --- /dev/null +++ b/benchmarks/sinkhorn_knopp.py @@ -0,0 +1,34 @@ +# /usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import ot +from .benchmark import ( + setup_backends, + exec_bench, + convert_to_html_table +) + + +def setup(n_samples): + rng = np.random.RandomState(123456789) + a = rng.rand(n_samples // 4, 100) + b = rng.rand(n_samples, 100) + + wa = ot.unif(n_samples // 4) + wb = ot.unif(n_samples) + + M = ot.dist(a.copy(), b.copy()) + return wa, wb, M + + +if __name__ == "__main__": + setup_backends() + results = exec_bench( + setup=setup, + tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7), + param_list=[50, 100, 500, 1000], #, 2000, 5000, 10000], + n_runs=10 + ) + + print(convert_to_html_table(results, param_name="Sample size")) diff --git a/ot/backend.py b/ot/backend.py index 60a7a51b1..ecd7e4161 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -28,6 +28,7 @@ import scipy.special as scipy from scipy.sparse import issparse, coo_matrix, csr_matrix import warnings +import time try: import torch @@ -944,6 +945,19 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return np.squeeze(a, axis=axis) + def _bench(self, callable, *args, n_runs=1): + results = dict() + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + key = ("Numpy", "CPU", type_as.itemsize * 8) + results[key] = (t1 - t0) / n_runs + return results + class JaxBackend(Backend): """ @@ -1207,6 +1221,21 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return jnp.squeeze(a, axis=axis) + def _bench(self, callable, *args, n_runs=1): + results = dict() + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + device = "CPU" if "cpu" in str(type_as.device_buffer.device()) else "GPU" + bitsize = type_as.dtype.itemsize * 8 + key = ("Jax", device, bitsize) + results[key] = (t1 - t0) / n_runs + return results + class TorchBackend(Backend): """ @@ -1552,6 +1581,21 @@ def squeeze(self, a, axis=None): else: return torch.squeeze(a, dim=axis) + def _bench(self, callable, *args, n_runs=1): + results = dict() + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + device = "CPU" if "cpu" in str(type_as.device) else "GPU" + bitsize = torch.finfo(type_as.dtype).bits + key = ("Pytorch", device, bitsize) + results[key] = (t1 - t0) / n_runs + return results + class CupyBackend(Backend): # pragma: no cover """ @@ -1839,6 +1883,19 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return cp.squeeze(a, axis=axis) + def _bench(self, callable, *args, n_runs=1): + results = dict() + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + key = ("Cupy", "GPU", type_as.itemsize * 8) + results[key] = (t1 - t0) / n_runs + return results + class 
TensorflowBackend(Backend): @@ -2136,3 +2193,28 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return tnp.squeeze(a, axis=axis) + + def _bench(self, callable, *args, n_runs=1): + results = dict() + with tf.device("/CPU:0"): + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + key = ("Tensorflow", "CPU", type_as.dtype.size * 8) + results[key] = (t1 - t0) / n_runs + + if len(tf.config.list_physical_devices('GPU')) > 0: + for type_as in self.__type_list__: + args = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*args) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*args) + t1 = time.perf_counter() + key = ("Tensorflow", "GPU", type_as.dtype.size * 8) + results[key] = (t1 - t0) / n_runs + return results From f269fdaa04005fef214a899920b5910a7ce38369 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 09:44:10 +0100 Subject: [PATCH 22/46] prettier bench table --- benchmarks/benchmark.py | 39 +++++++++++++++++++++++++++--------- benchmarks/sinkhorn_knopp.py | 6 +++++- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 03a6ed6dd..f45c4817e 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -22,6 +22,7 @@ def exec_bench(setup, tested_function, param_list, n_runs): L = dict() inputs = setup(param) for nx in backend_list: + print(param, nx) results_nx = nx._bench( tested_function, *inputs, @@ -36,23 +37,43 @@ def get_keys(d): return sorted(list(d.keys())) -def convert_to_html_table(results, param_name): +def convert_to_html_table(results, param_name, comments=None): string = "\n" keys = get_keys(results) - print(results[keys[0]].keys()) subkeys = get_keys(results[keys[0]]) names, devices, bitsizes = zip(*subkeys) - names = sorted(list(set(zip(names, devices)))) - length = len(names) + 1 + devices_names = sorted(list(set(zip(devices, names)))) + length = len(devices_names) + 1 + n_bitsizes = len(set(bitsizes)) + cpus_cols = list(devices).count("CPU") / n_bitsizes + gpus_cols = list(devices).count("GPU") / n_bitsizes + assert cpus_cols + gpus_cols == len(devices_names) - for bitsize in sorted(list(set(bitsizes))): - string += f'\n' - string += f'\n' + + # make device header + string += f'' + string += f'' + string += f'\n' + + # make param_name / backend header + string += f'' + for device, name in devices_names: + string += f'' string += "\n" + # make results rows for key in keys: subdict = results[key] subkeys = get_keys(subdict) diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index 0297379fa..39376d75f 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -31,4 +31,8 @@ def setup(n_samples): n_runs=10 ) - print(convert_to_html_table(results, param_name="Sample size")) + print(convert_to_html_table( + results, + param_name="Sample size", + comments="Sinkhorn Knopp" + )) From bdba755a5f586d4ecfc2391ede5cfdcded11da5d Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 10:55:10 +0100 Subject: [PATCH 23/46] Bitsize and prettier device methods --- ot/backend.py | 123 +++++++++++++++++++++++++++++-------------- test/test_backend.py | 15 ++++++ 2 files changed, 98 insertions(+), 40 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index ecd7e4161..33361b568 100644 --- a/ot/backend.py +++ 
b/ot/backend.py @@ -705,6 +705,24 @@ def squeeze(self, a, axis=None): """ raise NotImplementedError() + def bitsize(self, type_as): + r""" + Gives the number of bits used by the data type of the given tensor. + """ + raise NotImplementedError() + + def prettier_device(self, type_as): + r""" + Returns CPU or GPU depending on the device where the given tensor is located. + """ + raise NotImplementedError() + + def _bench(self, callable, *args, n_runs=1): + r""" + Executes a benchmark of the given callable with the given arguments. + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -945,16 +963,22 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return np.squeeze(a, axis=axis) + def bitsize(self, type_as): + return type_as.itemsize * 8 + + def prettier_device(self, type_as): + return "CPU" + def _bench(self, callable, *args, n_runs=1): results = dict() for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): - callable(*args) + callable(*inputs) t1 = time.perf_counter() - key = ("Numpy", "CPU", type_as.itemsize * 8) + key = ("Numpy", self.prettier_device(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -1221,18 +1245,22 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return jnp.squeeze(a, axis=axis) + def bitsize(self, type_as): + return type_as.dtype.itemsize * 8 + + def prettier_device(self, type_as): + return "CPU" if "cpu" in str(type_as.device_buffer.device()) else "GPU" + def _bench(self, callable, *args, n_runs=1): results = dict() for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): - callable(*args) + callable(*inputs) t1 = time.perf_counter() - device = "CPU" if "cpu" in str(type_as.device_buffer.device()) else "GPU" - bitsize = type_as.dtype.itemsize * 8 - key = ("Jax", device, bitsize) + key = ("Jax", self.prettier_device(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -1581,18 +1609,22 @@ def squeeze(self, a, axis=None): else: return torch.squeeze(a, dim=axis) + def bitsize(self, type_as): + return torch.finfo(type_as.dtype).bits + + def prettier_device(self, type_as): + return "CPU" if "cpu" in str(type_as.device) else "GPU" + def _bench(self, callable, *args, n_runs=1): results = dict() for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): - callable(*args) + callable(*inputs) t1 = time.perf_counter() - device = "CPU" if "cpu" in str(type_as.device) else "GPU" - bitsize = torch.finfo(type_as.dtype).bits - key = ("Pytorch", device, bitsize) + key = ("Pytorch", self.prettier_device(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -1883,16 +1915,22 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return cp.squeeze(a, axis=axis) + def bitsize(self, type_as): + return type_as.itemsize * 8 + + def prettier_device(self, type_as): + return "GPU" + def _bench(self, callable, *args, 
n_runs=1): results = dict() for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): - callable(*args) + callable(*inputs) t1 = time.perf_counter() - key = ("Cupy", "GPU", type_as.itemsize * 8) + key = ("Cupy", self.prettier_device(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -2194,27 +2232,32 @@ def assert_same_dtype_device(self, a, b): def squeeze(self, a, axis=None): return tnp.squeeze(a, axis=axis) + def bitsize(self, type_as): + return type_as.dtype.size * 8 + + def prettier_device(self, type_as): + return "CPU" if "CPU" in type_as.device else "GPU" + def _bench(self, callable, *args, n_runs=1): results = dict() - with tf.device("/CPU:0"): - for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) - t0 = time.perf_counter() - for _ in range(n_runs): - callable(*args) - t1 = time.perf_counter() - key = ("Tensorflow", "CPU", type_as.dtype.size * 8) - results[key] = (t1 - t0) / n_runs - + device_contexts = [tf.device("/CPU:0")] if len(tf.config.list_physical_devices('GPU')) > 0: - for type_as in self.__type_list__: - args = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*args) - t0 = time.perf_counter() - for _ in range(n_runs): - callable(*args) - t1 = time.perf_counter() - key = ("Tensorflow", "GPU", type_as.dtype.size * 8) - results[key] = (t1 - t0) / n_runs + device_contexts.append(tf.device("/GPU:0")) + + for device_context in device_contexts: + with device_context: + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + callable(*inputs) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*inputs) + t1 = time.perf_counter() + key = ( + "Tensorflow", + self.prettier_device(inputs[0]), + self.bitsize(type_as) + ) + results[key] = (t1 - t0) / n_runs + return results diff --git a/test/test_backend.py b/test/test_backend.py index d9ac16a65..1ba7efcd1 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -258,6 +258,12 @@ def test_empty_backend(): nx.allclose(M, M) with pytest.raises(NotImplementedError): nx.squeeze(M) + with pytest.raises(NotImplementedError): + nx.bitsize(M) + with pytest.raises(NotImplementedError): + nx.prettier_device(M) + with pytest.raises(NotImplementedError): + nx._bench(lambda x: x, M, n_runs=1) def test_func_backends(nx): @@ -535,6 +541,15 @@ def test_func_backends(nx): A = nx.squeeze(nx.zeros((3, 1, 4, 1))) assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze' + A = nx.bitsize(Mb) + lst_b.append(float(A)) + lst_name.append("bitsize") + + A = nx.prettier_device(Mb) + assert A in ("CPU", "GPU") + + nx._bench(lambda x: x, M, n_runs=1) + lst_tot.append(lst_b) lst_np = lst_tot[0] From 30e7ba783c647fe9f531e2fc6347ea9e5b36f73c Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 12:29:17 +0100 Subject: [PATCH 24/46] prettified table bench --- benchmarks/benchmark.py | 11 +++++++++-- benchmarks/sinkhorn_knopp.py | 10 ++++++---- ot/backend.py | 5 +++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index f45c4817e..b832f0688 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -37,7 +37,7 @@ def get_keys(d): return sorted(list(d.keys())) -def convert_to_html_table(results, 
param_name, comments=None):
+def convert_to_html_table(results, param_name, main_title=None, comments=None):
     string = "<table>\n"
     keys = get_keys(results)
     subkeys = get_keys(results[keys[0]])
     names, devices, bitsizes = zip(*subkeys)
@@ -50,8 +50,14 @@ def convert_to_html_table(results, param_name, comments=None):
     gpus_cols = list(devices).count("GPU") / n_bitsizes
     assert cpus_cols + gpus_cols == len(devices_names)
 
+    if main_title is not None:
+        string += f'<tr><th colspan="{length}">{str(main_title)}</th></tr>\n'
+
     for i, bitsize in enumerate(sorted(list(set(bitsizes)))):
+        if i != 0:
+            string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'
+
 
         # make bitsize header
         text = f"{bitsize} bits"
         if comments is not None:
@@ -60,7 +66,8 @@ def convert_to_html_table(results, param_name, comments=None):
             text += str(comments[i])
         else:
             text += str(comments)
-        string += f'<tr><th colspan="{length}">{text}</th></tr>\n'
+        string += f'<tr><th>Bitsize</th>'
+        string += f'<th colspan="{length - 1}">{text}</th></tr>\n'
 
         # make device header
         string += f'<tr><th>Devices</th>'
diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py
index 39376d75f..0eaf0fdf8 100644
--- a/benchmarks/sinkhorn_knopp.py
+++ b/benchmarks/sinkhorn_knopp.py
@@ -23,16 +23,18 @@ def setup(n_samples):
 
 
 if __name__ == "__main__":
+    n_runs = 100
+    param_list = [50]#100, 500]#, 1000, 2000, 5000, 10000]
+
     setup_backends()
     results = exec_bench(
         setup=setup,
         tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7),
-        param_list=[50, 100, 500, 1000], #, 2000, 5000, 10000],
-        n_runs=10
+        param_list=param_list,
+        n_runs=n_runs
     )
-
     print(convert_to_html_table(
         results,
         param_name="Sample size",
-        comments="Sinkhorn Knopp"
+        main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
     ))
diff --git a/ot/backend.py b/ot/backend.py
index 33361b568..8bb23ff92 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -980,6 +980,7 @@ def _bench(self, callable, *args, n_runs=1):
             t1 = time.perf_counter()
             key = ("Numpy", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
+            del inputs
         return results
 
 
@@ -1262,6 +1263,7 @@ def _bench(self, callable, *args, n_runs=1):
             t1 = time.perf_counter()
             key = ("Jax", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
+            del inputs
         return results
 
 
@@ -1626,6 +1628,7 @@ def _bench(self, callable, *args, n_runs=1):
             t1 = time.perf_counter()
             key = ("Pytorch", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
+            del inputs
         return results
 
 
@@ -1932,6 +1935,7 @@ def _bench(self, callable, *args, n_runs=1):
             t1 = time.perf_counter()
             key = ("Cupy", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
+            del inputs
         return results
 
 
@@ -2259,5 +2263,6 @@ def _bench(self, callable, *args, n_runs=1):
                         self.bitsize(type_as)
                     )
                     results[key] = (t1 - t0) / n_runs
+                    del inputs
 
         return results

From 689ae01eda5704f7bc57117477448b88a8e02b34 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 7 Dec 2021 13:23:54 +0100
Subject: [PATCH 25/46] Bug corrected (results were mixed up in the final table)

---
 benchmarks/benchmark.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
index b832f0688..78c2235b4 100644
--- a/benchmarks/benchmark.py
+++ b/benchmarks/benchmark.py
@@ -83,12 +83,10 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None):
         # make results rows
         for key in keys:
             subdict = results[key]
-            subkeys = get_keys(subdict)
             string += f'<tr><td>{key}</td>'
-            for subkey in subkeys:
-                name, device, size = subkey
-                if size == bitsize:
-                    string += f'<td>{subdict[subkey]:.4f}</td>'
+            for device, name in devices_names:
+                subkey = (name, device, bitsize)
+                string += f'<td>{subdict[subkey]:.4f}</td>'
             string += "</tr>\n"
 
     string += "</table>"

From 1347387a49a3d9c2f062a9fa075871690eb15a57 Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Tue, 7 Dec 2021 13:24:22 +0100
Subject: [PATCH 26/46] Better perf counter (for GPU support)

---
 ot/backend.py | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 8bb23ff92..bcd534157 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -980,7 +980,6 @@ def _bench(self, callable, *args, n_runs=1):
             t1 = time.perf_counter()
             key = ("Numpy", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
-            del inputs
         return results
 
 
@@ -1254,16 +1254,21 @@ def prettier_device(self, type_as):
 
     def _bench(self, callable, *args, n_runs=1):
         results = dict()
+        @jax.jit
+        def add_one(M):
+            return M + 1
         for type_as in self.__type_list__:
             inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
             callable(*inputs)
             t0 = time.perf_counter()
             for _ in range(n_runs):
-                callable(*inputs)
+                # We are technically doing more calculations but adding one
+                # is expected to be very quick and allows us to access the
+                # block_until_ready method to measure asynchronous calculations
+                add_one(callable(*inputs)).block_until_ready()
             t1 = time.perf_counter()
             key = ("Jax", self.prettier_device(type_as), self.bitsize(type_as))
             results[key] = (t1 - t0) / n_runs
-            del inputs
         return results
 
 
@@ -1622,13 +1626,16 @@ def _bench(self, callable, *args, n_runs=1):
         for type_as in self.__type_list__:
             inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
             callable(*inputs)
-            t0 = time.perf_counter()
+            torch.cuda.synchronize()
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
             for _ in range(n_runs):
                 callable(*inputs)
-            t1 = time.perf_counter()
+            end.record()
+            torch.cuda.synchronize()
             key = ("Pytorch", self.prettier_device(type_as), self.bitsize(type_as))
-            results[key] = (t1 - t0) / n_runs
-            del inputs
+            results[key] = start.elapsed_time(end) / 1000. / n_runs
         return results
 
 
@@ -1928,14 +1935,18 @@ def _bench(self, callable, *args, n_runs=1):
         results = dict()
         for type_as in self.__type_list__:
             inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+            start_gpu = cp.cuda.Event()
+            end_gpu = cp.cuda.Event()
             callable(*inputs)
-            t0 = time.perf_counter()
+            start_gpu.synchronize()
+            start_gpu.record()
             for _ in range(n_runs):
                 callable(*inputs)
-            t1 = time.perf_counter()
+            end_gpu.record()
+            end_gpu.synchronize()
             key = ("Cupy", self.prettier_device(type_as), self.bitsize(type_as))
-            results[key] = (t1 - t0) / n_runs
-            del inputs
+            t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.
+ results[key] = t_gpu / n_runs return results @@ -2263,6 +2274,5 @@ def _bench(self, callable, *args, n_runs=1): self.bitsize(type_as) ) results[key] = (t1 - t0) / n_runs - del inputs return results From 96488ce74784a7cf7daaef48024092d89dc996ad Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 13:26:01 +0100 Subject: [PATCH 27/46] pep8 --- ot/backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ot/backend.py b/ot/backend.py index bcd534157..135037675 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1253,9 +1253,11 @@ def prettier_device(self, type_as): def _bench(self, callable, *args, n_runs=1): results = dict() + @jax.jit def add_one(M): return M + 1 + for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] callable(*inputs) From 59ea42ec0ba6b4e0165ce62e2a8609a2f6d8d5f8 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 13:33:00 +0100 Subject: [PATCH 28/46] EMD bench --- benchmarks/emd.py | 38 ++++++++++++++++++++++++++++++++++++ benchmarks/sinkhorn_knopp.py | 2 +- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 benchmarks/emd.py diff --git a/benchmarks/emd.py b/benchmarks/emd.py new file mode 100644 index 000000000..189b4fec0 --- /dev/null +++ b/benchmarks/emd.py @@ -0,0 +1,38 @@ +# /usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import ot +from .benchmark import ( + setup_backends, + exec_bench, + convert_to_html_table +) + + +def setup(n_samples): + rng = np.random.RandomState(789465132) + x = rng.randn(n_samples, 2) + y = rng.randn(n_samples, 2) + + a = ot.utils.unif(n_samples) + M = ot.dist(x, y) + return a, M + + +if __name__ == "__main__": + n_runs = 100 + param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] + + setup_backends() + results = exec_bench( + setup=setup, + tested_function=lambda a, M: ot.emd(a, a, M), + param_list=param_list, + n_runs=n_runs + ) + print(convert_to_html_table( + results, + param_name="Sample size", + main_title=f"EMD - Averaged on {n_runs} runs" + )) diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index 0eaf0fdf8..e4ca027cd 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -24,7 +24,7 @@ def setup(n_samples): if __name__ == "__main__": n_runs = 100 - param_list = [50]#100, 500]#, 1000, 2000, 5000, 10000] + param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] setup_backends() results = exec_bench( From 22c4d0c271f752e94179ac6e6d1739b930ccee0e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 13:40:31 +0100 Subject: [PATCH 29/46] solve bug if no GPU available --- ot/backend.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 135037675..5a743247b 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -17,6 +17,13 @@ ... nx = get_backend(a, b) # infer the backend from the arguments ... c = nx.dot(a, b) # now use the backend to do any calculation ... return c + + +.. note:: +Tensorflow only works with the numpy API. 
To activate it, please run the following: + +>>> from tensorflow.python.ops.numpy_ops import np_config +>>> np_config.enable_numpy_behavior() """ # Author: Remi Flamary @@ -1628,16 +1635,24 @@ def _bench(self, callable, *args, n_runs=1): for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] callable(*inputs) - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() + if self.prettier_device(type_as) == "GPU": + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + else: + start = time.perf_counter() for _ in range(n_runs): callable(*inputs) - end.record() - torch.cuda.synchronize() + if self.prettier_device(type_as) == "GPU": + end.record() + torch.cuda.synchronize() + duration = start.elapsed_time(end) / 1000. + else: + end = time.perf_counter() + duration = end - start key = ("Pytorch", self.prettier_device(type_as), self.bitsize(type_as)) - results[key] = start.elapsed_time(end) / 1000. / n_runs + results[key] = duration / n_runs return results From eca3f804b79cfdb842d865abdf7cbd1c66579034 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 13:42:02 +0100 Subject: [PATCH 30/46] pep8 --- ot/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 5a743247b..0a9ee6472 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1644,7 +1644,7 @@ def _bench(self, callable, *args, n_runs=1): start = time.perf_counter() for _ in range(n_runs): callable(*inputs) - if self.prettier_device(type_as) == "GPU": + if self.prettier_device(type_as) == "GPU": end.record() torch.cuda.synchronize() duration = start.elapsed_time(end) / 1000. From aa5257adfe0829171af806b99bc74b537da3120e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 13:47:28 +0100 Subject: [PATCH 31/46] warning about tensorflow numpy api being required in the backend.py docstring --- ot/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 0a9ee6472..a6de7ff7c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -19,11 +19,11 @@ ... return c -.. note:: -Tensorflow only works with the numpy API. To activate it, please run the following: +.. warning:: + Tensorflow only works with the Numpy API. To activate it, please run the following: ->>> from tensorflow.python.ops.numpy_ops import np_config ->>> np_config.enable_numpy_behavior() + >>> from tensorflow.python.ops.numpy_ops import np_config + >>> np_config.enable_numpy_behavior() """ # Author: Remi Flamary From 1f6266071d3f824f09fcab634546e71ff6239260 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 14:12:30 +0100 Subject: [PATCH 32/46] Bug solve in backend docstring --- ot/backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index a6de7ff7c..fa3680500 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -22,8 +22,10 @@ .. warning:: Tensorflow only works with the Numpy API. To activate it, please run the following: - >>> from tensorflow.python.ops.numpy_ops import np_config - >>> np_config.enable_numpy_behavior() + .. 
code-block:: + + from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() """ # Author: Remi Flamary From 0f1d299360432677a8f1c6cf8c533735b92a5618 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 14:23:50 +0100 Subject: [PATCH 33/46] not covering code which requires a GPU --- ot/backend.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index fa3680500..22fefe3f7 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -85,7 +85,7 @@ def get_backend_list(): if jax: lst.append(JaxBackend()) - if cp: + if cp: # pragma: no cover lst.append(CupyBackend()) if tf: @@ -112,7 +112,7 @@ def get_backend(*args): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() - elif isinstance(args[0], cp_type): + elif isinstance(args[0], cp_type): # pragma: no cover return CupyBackend() elif isinstance(args[0], tf_type): return TensorflowBackend() @@ -1637,7 +1637,7 @@ def _bench(self, callable, *args, n_runs=1): for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] callable(*inputs) - if self.prettier_device(type_as) == "GPU": + if self.prettier_device(type_as) == "GPU": # pragma: no cover torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -1646,7 +1646,7 @@ def _bench(self, callable, *args, n_runs=1): start = time.perf_counter() for _ in range(n_runs): callable(*inputs) - if self.prettier_device(type_as) == "GPU": + if self.prettier_device(type_as) == "GPU": # pragma: no cover end.record() torch.cuda.synchronize() duration = start.elapsed_time(end) / 1000. @@ -2275,7 +2275,7 @@ def prettier_device(self, type_as): def _bench(self, callable, *args, n_runs=1): results = dict() device_contexts = [tf.device("/CPU:0")] - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover device_contexts.append(tf.device("/GPU:0")) for device_context in device_contexts: From 09840004f453535a2d7a9f58ffc77182fef68f4f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 14:58:50 +0100 Subject: [PATCH 34/46] Tensorflow gradients manipulation tested --- test/test_backend.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_backend.py b/test/test_backend.py index 1ba7efcd1..15a9dea35 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -624,3 +624,17 @@ def fun(a, b, d): np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4) np.testing.assert_allclose(grad_val[0], v, atol=1e-4) np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4) + + if tf: + nx = ot.backend.TensorflowBackend() + w = tf.Variable(tf.random.normal((3, 2)), name='w') + b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b') + x = tf.random.normal((1, 3), dtype=tf.float32) + + with tf.GradientTape() as tape: + y = x @ w + b + loss = tf.reduce_mean(y ** 2) + manipulated_loss = nx.set_gradients(loss, (w, b), (w, b)) + [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b]) + assert nx.allclose(dl_dw, w) + assert nx.allclose(dl_db, b) From 9e56ccd1922fd53d70d27fbdeafcc260f81b8234 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 15:31:50 +0100 Subject: [PATCH 35/46] Number of warmup runs is now customizable --- benchmarks/benchmark.py | 5 +++-- benchmarks/emd.py | 4 +++- benchmarks/sinkhorn_knopp.py | 4 +++- ot/backend.py | 27 ++++++++++++++++----------- 
4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 78c2235b4..af3bbd8e4 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -15,7 +15,7 @@ def setup_backends(): np_config.enable_numpy_behavior() -def exec_bench(setup, tested_function, param_list, n_runs): +def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs): backend_list = get_backend_list() results = dict() for param in param_list: @@ -26,7 +26,8 @@ def exec_bench(setup, tested_function, param_list, n_runs): results_nx = nx._bench( tested_function, *inputs, - n_runs=n_runs + n_runs=n_runs, + warmup_runs=warmup_runs ) L.update(results_nx) results[param] = L diff --git a/benchmarks/emd.py b/benchmarks/emd.py index 189b4fec0..c87372113 100644 --- a/benchmarks/emd.py +++ b/benchmarks/emd.py @@ -22,6 +22,7 @@ def setup(n_samples): if __name__ == "__main__": n_runs = 100 + warmup_runs = 10 param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] setup_backends() @@ -29,7 +30,8 @@ def setup(n_samples): setup=setup, tested_function=lambda a, M: ot.emd(a, a, M), param_list=param_list, - n_runs=n_runs + n_runs=n_runs, + warmup_runs=warmup_runs ) print(convert_to_html_table( results, diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index e4ca027cd..a05075e99 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -24,6 +24,7 @@ def setup(n_samples): if __name__ == "__main__": n_runs = 100 + warmup_runs = 10 param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] setup_backends() @@ -31,7 +32,8 @@ def setup(n_samples): setup=setup, tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7), param_list=param_list, - n_runs=n_runs + n_runs=n_runs, + warmup_runs=warmup_runs ) print(convert_to_html_table( results, diff --git a/ot/backend.py b/ot/backend.py index 22fefe3f7..1ac60e9fb 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -726,7 +726,7 @@ def prettier_device(self, type_as): """ raise NotImplementedError() - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): r""" Executes a benchmark of the given callable with the given arguments. 
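+        The reported value is the average runtime over ``n_runs`` executions,
+        measured after ``warmup_runs`` preliminary calls that are excluded
+        from the timing.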
""" @@ -978,11 +978,12 @@ def bitsize(self, type_as): def prettier_device(self, type_as): return "CPU" - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*inputs) + for _ in range(warmup_runs): + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): callable(*inputs) @@ -1260,7 +1261,7 @@ def bitsize(self, type_as): def prettier_device(self, type_as): return "CPU" if "cpu" in str(type_as.device_buffer.device()) else "GPU" - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() @jax.jit @@ -1269,7 +1270,8 @@ def add_one(M): for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*inputs) + for _ in range(warmup_runs): + add_one(callable(*inputs)).block_until_ready() t0 = time.perf_counter() for _ in range(n_runs): # We are technically doing more calculations but adding one @@ -1632,11 +1634,12 @@ def bitsize(self, type_as): def prettier_device(self, type_as): return "CPU" if "cpu" in str(type_as.device) else "GPU" - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*inputs) + for _ in range(warmup_runs): + callable(*inputs) if self.prettier_device(type_as) == "GPU": # pragma: no cover torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) @@ -1950,13 +1953,14 @@ def bitsize(self, type_as): def prettier_device(self, type_as): return "GPU" - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] start_gpu = cp.cuda.Event() end_gpu = cp.cuda.Event() - callable(*inputs) + for _ in range(warmup_runs): + callable(*inputs) start_gpu.synchronize() start_gpu.record() for _ in range(n_runs): @@ -2272,7 +2276,7 @@ def bitsize(self, type_as): def prettier_device(self, type_as): return "CPU" if "CPU" in type_as.device else "GPU" - def _bench(self, callable, *args, n_runs=1): + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() device_contexts = [tf.device("/CPU:0")] if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover @@ -2282,7 +2286,8 @@ def _bench(self, callable, *args, n_runs=1): with device_context: for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] - callable(*inputs) + for _ in range(warmup_runs): + callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): callable(*inputs) From 84fb00277fd7ad01e9289f4ea65963f8e1b10e1c Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 15:35:32 +0100 Subject: [PATCH 36/46] typo --- benchmarks/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index af3bbd8e4..37dbf92fd 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -71,7 +71,7 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None): string += f'{text}\n' # make device header - string += f'Devices' + string += f'Device' string += f'CPU' string += f'GPU\n' From 
9a7fd033308846c495ae10938be2d3b578a709e7 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 16:04:11 +0100 Subject: [PATCH 37/46] Remove some warnings while building docs --- ot/da.py | 44 ++++++++++++++++++++++---------------------- ot/datasets.py | 2 +- ot/plot.py | 4 ++-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/ot/da.py b/ot/da.py index 4fd97dfa6..841f31a30 100644 --- a/ot/da.py +++ b/ot/da.py @@ -906,7 +906,7 @@ def df(G): def distribution_estimation_uniform(X): - """estimates a uniform distribution from an array of samples :math:`\mathbf{X}` + r"""estimates a uniform distribution from an array of samples :math:`\mathbf{X}` Parameters ---------- @@ -950,7 +950,7 @@ class BaseTransport(BaseEstimator): """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1010,7 +1010,7 @@ class label return self def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` @@ -1038,7 +1038,7 @@ class label return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1105,7 +1105,7 @@ class label return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in + r"""Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in :ref:`[27] `. 
Parameters @@ -1152,7 +1152,7 @@ def transform_labels(self, ys=None): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1218,7 +1218,7 @@ class label return transp_Xt def inverse_transform_labels(self, yt=None): - """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels :math:`\mathbf{y_s}` Parameters @@ -1307,7 +1307,7 @@ def __init__(self, reg=1e-8, bias=True, log=False, self.distribution_estimation = distribution_estimation def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1354,7 +1354,7 @@ class label return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1387,7 +1387,7 @@ class label def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1493,7 +1493,7 @@ def __init__(self, reg_e=1., max_iter=1000, self.out_of_sample_map = out_of_sample_map def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1592,7 +1592,7 @@ def __init__(self, metric="sqeuclidean", norm=None, log=False, self.max_iter = max_iter def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1711,7 +1711,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1839,7 +1839,7 @@ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1962,7 +1962,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, 
\mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2088,7 +2088,7 @@ def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean", self.verbose2 = verbose2 def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Builds an optimal coupling and estimates the associated mapping + r"""Builds an optimal coupling and estimates the associated mapping from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` @@ -2146,7 +2146,7 @@ class label return self def transform(self, Xs): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2261,7 +2261,7 @@ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2373,7 +2373,7 @@ def __init__(self, reg_e=.1, max_iter=10, self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Building coupling matrices from a list of source and target sets of samples + r"""Building coupling matrices from a list of source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2419,7 +2419,7 @@ class label return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2491,7 +2491,7 @@ class label return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in + r"""Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in :ref:`[27] ` Parameters @@ -2542,7 +2542,7 @@ def transform_labels(self, ys=None): return yt.T def inverse_transform_labels(self, yt=None): - """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels :math:`\mathbf{y_s}` Parameters diff --git a/ot/datasets.py b/ot/datasets.py index ad6390c20..a83907464 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma): def make_2D_samples_gauss(n, m, sigma, random_state=None): - """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` + r"""Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` Parameters ---------- diff --git a/ot/plot.py b/ot/plot.py index 3e3bed71e..2208c9088 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -18,7 +18,7 @@ def plot1D_mat(a, b, M, title=''): - """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution + r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between. 
@@ -61,7 +61,7 @@ def plot1D_mat(a, b, M, title=''): def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): - """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values + r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color proportional to the value of the matrix :math:`\mathbf{G}` between samples. From acc9474602cf88105d2f1ab341ff70182ee6925f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 17:20:35 +0100 Subject: [PATCH 38/46] Change prettier_device to device_type in backend --- ot/backend.py | 28 ++++++++++++++-------------- test/test_backend.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 1ac60e9fb..f86c9e871 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -720,7 +720,7 @@ def bitsize(self, type_as): """ raise NotImplementedError() - def prettier_device(self, type_as): + def device_type(self, type_as): r""" Returns CPU or GPU depending on the device where the given tensor is located. """ @@ -975,7 +975,7 @@ def squeeze(self, a, axis=None): def bitsize(self, type_as): return type_as.itemsize * 8 - def prettier_device(self, type_as): + def device_type(self, type_as): return "CPU" def _bench(self, callable, *args, n_runs=1, warmup_runs=1): @@ -988,7 +988,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): for _ in range(n_runs): callable(*inputs) t1 = time.perf_counter() - key = ("Numpy", self.prettier_device(type_as), self.bitsize(type_as)) + key = ("Numpy", self.device_type(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -1258,8 +1258,8 @@ def squeeze(self, a, axis=None): def bitsize(self, type_as): return type_as.dtype.itemsize * 8 - def prettier_device(self, type_as): - return "CPU" if "cpu" in str(type_as.device_buffer.device()) else "GPU" + def device_type(self, type_as): + return self.dtype_device(type_as)[1].platform.upper() def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() @@ -1279,7 +1279,7 @@ def add_one(M): # block_until_ready method to measure asynchronous calculations add_one(callable(*inputs)).block_until_ready() t1 = time.perf_counter() - key = ("Jax", self.prettier_device(type_as), self.bitsize(type_as)) + key = ("Jax", self.device_type(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs return results @@ -1631,7 +1631,7 @@ def squeeze(self, a, axis=None): def bitsize(self, type_as): return torch.finfo(type_as.dtype).bits - def prettier_device(self, type_as): + def device_type(self, type_as): return "CPU" if "cpu" in str(type_as.device) else "GPU" def _bench(self, callable, *args, n_runs=1, warmup_runs=1): @@ -1640,7 +1640,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] for _ in range(warmup_runs): callable(*inputs) - if self.prettier_device(type_as) == "GPU": # pragma: no cover + if self.device_type(type_as) == "GPU": # pragma: no cover torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -1649,14 +1649,14 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): start = time.perf_counter() for _ in range(n_runs): callable(*inputs) - if self.prettier_device(type_as) == "GPU": # pragma: no cover + if self.device_type(type_as) == "GPU": # pragma: no cover end.record() torch.cuda.synchronize() duration = start.elapsed_time(end) / 1000. 
else: end = time.perf_counter() duration = end - start - key = ("Pytorch", self.prettier_device(type_as), self.bitsize(type_as)) + key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as)) results[key] = duration / n_runs return results @@ -1950,7 +1950,7 @@ def squeeze(self, a, axis=None): def bitsize(self, type_as): return type_as.itemsize * 8 - def prettier_device(self, type_as): + def device_type(self, type_as): return "GPU" def _bench(self, callable, *args, n_runs=1, warmup_runs=1): @@ -1967,7 +1967,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): callable(*inputs) end_gpu.record() end_gpu.synchronize() - key = ("Cupy", self.prettier_device(type_as), self.bitsize(type_as)) + key = ("Cupy", self.device_type(type_as), self.bitsize(type_as)) t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000. results[key] = t_gpu / n_runs return results @@ -2273,7 +2273,7 @@ def squeeze(self, a, axis=None): def bitsize(self, type_as): return type_as.dtype.size * 8 - def prettier_device(self, type_as): + def device_type(self, type_as): return "CPU" if "CPU" in type_as.device else "GPU" def _bench(self, callable, *args, n_runs=1, warmup_runs=1): @@ -2294,7 +2294,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): t1 = time.perf_counter() key = ( "Tensorflow", - self.prettier_device(inputs[0]), + self.device_type(inputs[0]), self.bitsize(type_as) ) results[key] = (t1 - t0) / n_runs diff --git a/test/test_backend.py b/test/test_backend.py index 15a9dea35..027c4cdeb 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -261,7 +261,7 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.bitsize(M) with pytest.raises(NotImplementedError): - nx.prettier_device(M) + nx.device_type(M) with pytest.raises(NotImplementedError): nx._bench(lambda x: x, M, n_runs=1) @@ -545,7 +545,7 @@ def test_func_backends(nx): lst_b.append(float(A)) lst_name.append("bitsize") - A = nx.prettier_device(Mb) + A = nx.device_type(Mb) assert A in ("CPU", "GPU") nx._bench(lambda x: x, M, n_runs=1) From 0593f87cff24407a9acf5eb4feb42468a9365401 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 17:21:29 +0100 Subject: [PATCH 39/46] Correct JAX mistakes preventing to see the CPU if a GPU is present --- ot/backend.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index f86c9e871..058e6cc30 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1010,9 +1010,12 @@ class JaxBackend(Backend): def __init__(self): self.rng_ = jax.random.PRNGKey(42) - for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), - jax.device_put(jnp.array(1, dtype=jnp.float64), d)] + self.__type_list__ = [] + for d in jax.devices("cpu") + jax.devices("gpu"): + self.__type_list__ += [ + jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d) + ] def to_numpy(self, a): return np.array(a) From 321597f1e2ef38dc7d7927ba42425e7a6cd7cc77 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Tue, 7 Dec 2021 17:33:31 +0100 Subject: [PATCH 40/46] Attempt to solve JAX bug in case no GPU is found --- ot/backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 058e6cc30..f63e3afda 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -50,6 +50,7 @@ import jax import jax.numpy as jnp import jax.scipy.special as jscipy + from jax.lib import xla_bridge jax_type = jax.numpy.ndarray except 
ImportError: jax = False @@ -1011,7 +1012,10 @@ def __init__(self): self.rng_ = jax.random.PRNGKey(42) self.__type_list__ = [] - for d in jax.devices("cpu") + jax.devices("gpu"): + available_devices = jax.devices("cpu") + if xla_bridge.get_backend().platform == "gpu": + available_devices += jax.devices("gpu") + for d in available_devices: self.__type_list__ += [ jax.device_put(jnp.array(1, dtype=jnp.float32), d), jax.device_put(jnp.array(1, dtype=jnp.float64), d) From ec72f30bd7859e4cfe8583ee7b53c6972195f12c Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 8 Dec 2021 11:29:27 +0100 Subject: [PATCH 41/46] Reworked benchmarks order and results storage & clear GPU after usage by benchmark --- benchmarks/__init__.py | 3 +- benchmarks/benchmark.py | 63 +++++++++++++++++++++--------------- benchmarks/emd.py | 2 +- benchmarks/sinkhorn_knopp.py | 2 +- ot/backend.py | 13 ++++++-- 5 files changed, 52 insertions(+), 31 deletions(-) diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 0ecbfad71..37f5e569a 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -1,4 +1,5 @@ from . import benchmark from . import sinkhorn_knopp +from . import emd -__all__= ["benchmark", "sinkhorn_knopp"] +__all__= ["benchmark", "sinkhorn_knopp", "emd"] diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 37dbf92fd..85b5562cd 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -1,8 +1,8 @@ # /usr/bin/env python3 # -*- coding: utf-8 -*- -import numpy as np from ot.backend import get_backend_list, jax, tf +import gc def setup_backends(): @@ -17,44 +17,56 @@ def setup_backends(): def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs): backend_list = get_backend_list() + for i, nx in enumerate(backend_list): + if nx.__name__ == "tf" and i < len(backend_list) - 1: + # Tensorflow should be the last one to be benchmarked because + # as far as I'm aware, there is no way to force it to release + # GPU memory. Hence, if any other backend is benchmarked after + # Tensorflow and requires the usage of a GPU, it will not have the + # full memory available and you may have a GPU Out Of Memory error + # even though your GPU can technically hold your tensors in memory. 
+                backend_list.pop(i)
+                backend_list.append(nx)
+                break
+
+    inputs = [setup(param) for param in param_list]
     results = dict()
-    for param in param_list:
-        L = dict()
-        inputs = setup(param)
-        for nx in backend_list:
-            print(param, nx)
+    for nx in backend_list:
+        for i in range(len(param_list)):
+            print(nx, param_list[i])
+            args = inputs[i]
             results_nx = nx._bench(
                 tested_function,
-                *inputs,
+                *args,
                 n_runs=n_runs,
                 warmup_runs=warmup_runs
             )
-            L.update(results_nx)
-        results[param] = L
+            gc.collect()
+            results_nx_with_param_in_key = dict()
+            for key in results_nx:
+                new_key = (param_list[i], *key)
+                results_nx_with_param_in_key[new_key] = results_nx[key]
+            results.update(results_nx_with_param_in_key)
     return results
 
 
-def get_keys(d):
-    return sorted(list(d.keys()))
-
-
 def convert_to_html_table(results, param_name, main_title=None, comments=None):
     string = "<table>\n"
-    keys = get_keys(results)
-    subkeys = get_keys(results[keys[0]])
-    names, devices, bitsizes = zip(*subkeys)
+    keys = list(results.keys())
+    params, names, devices, bitsizes = zip(*keys)
     devices_names = sorted(list(set(zip(devices, names))))
+    params = sorted(list(set(params)))
+    bitsizes = sorted(list(set(bitsizes)))
     length = len(devices_names) + 1
-    n_bitsizes = len(set(bitsizes))
-    cpus_cols = list(devices).count("CPU") / n_bitsizes
-    gpus_cols = list(devices).count("GPU") / n_bitsizes
+    cpus_cols = list(devices).count("CPU") / len(bitsizes) / len(params)
+    gpus_cols = list(devices).count("GPU") / len(bitsizes) / len(params)
     assert cpus_cols + gpus_cols == len(devices_names)
     if main_title is not None:
         string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
 
-    for i, bitsize in enumerate(sorted(list(set(bitsizes)))):
+    for i, bitsize in enumerate(bitsizes):
 
         if i != 0:
             string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'
@@ -63,7 +75,7 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None):
         text = f"{bitsize} bits"
         if comments is not None:
             text += " - "
-            if isinstance(comments, (tuple, list)) and len(comments) == n_bitsizes:
+            if isinstance(comments, (tuple, list)) and len(comments) == len(bitsizes):
                 text += str(comments[i])
             else:
                 text += str(comments)
@@ -82,12 +94,11 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None):
         string += "</tr>\n"
 
         # make results rows
-        for key in keys:
-            subdict = results[key]
-            string += f'<tr><td align="center">{key}</td>'
+        for param in params:
+            string += f'<tr><td align="center">{param}</td>'
             for device, name in devices_names:
-                subkey = (name, device, bitsize)
-                string += f'<td align="center">{subdict[subkey]:.4f}</td>'
+                key = (param, name, device, bitsize)
+                string += f'<td align="center">{results[key]:.4f}</td>'
             string += "</tr>\n"
 
     string += "</table>"
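
With this rework, `exec_bench` now returns a single flat dictionary keyed by `(param, backend_name, device, bitsize)` tuples, which `convert_to_html_table` consumes directly. A minimal sketch of the new layout follows; the timing values are made up for illustration, and it assumes the snippet is run from the repository root so that `benchmarks` is importable::

    from benchmarks.benchmark import convert_to_html_table

    # Keys follow the (param, name, device, bitsize) layout produced by
    # exec_bench after this patch; the float timings below are hypothetical.
    results = {
        (50, "Numpy", "CPU", 64): 0.0008,
        (50, "Pytorch", "GPU", 64): 0.0051,
        (100, "Numpy", "CPU", 64): 0.0005,
        (100, "Pytorch", "GPU", 64): 0.0029,
    }
    html = convert_to_html_table(results, param_name="Sample size",
                                 main_title="Sinkhorn Knopp - Averaged on 100 runs")
    print(html)  # emits one <table> with a row per sample size
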
" diff --git a/benchmarks/emd.py b/benchmarks/emd.py index c87372113..9f6486300 100644 --- a/benchmarks/emd.py +++ b/benchmarks/emd.py @@ -23,7 +23,7 @@ def setup(n_samples): if __name__ == "__main__": n_runs = 100 warmup_runs = 10 - param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] + param_list = [50, 100, 500, 1000, 2000, 5000] setup_backends() results = exec_bench( diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index a05075e99..3a1ef3f37 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -25,7 +25,7 @@ def setup(n_samples): if __name__ == "__main__": n_runs = 100 warmup_runs = 10 - param_list = [50, 100, 500] # 1000, 2000, 5000, 10000] + param_list = [50, 100, 500, 1000, 2000, 5000] setup_backends() results = exec_bench( diff --git a/ot/backend.py b/ot/backend.py index f63e3afda..ef0c2d6db 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1012,7 +1012,8 @@ def __init__(self): self.rng_ = jax.random.PRNGKey(42) self.__type_list__ = [] - available_devices = jax.devices("cpu") + # available_devices = jax.devices("cpu") + available_devices = [] if xla_bridge.get_backend().platform == "gpu": available_devices += jax.devices("gpu") for d in available_devices: @@ -1665,6 +1666,8 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): duration = end - start key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as)) results[key] = duration / n_runs + if torch.cuda.is_available(): + torch.cuda.empty_cache() return results @@ -1961,6 +1964,9 @@ def device_type(self, type_as): return "GPU" def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + results = dict() for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] @@ -1977,6 +1983,8 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): key = ("Cupy", self.device_type(type_as), self.bitsize(type_as)) t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000. results[key] = t_gpu / n_runs + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() return results @@ -2297,7 +2305,8 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): callable(*inputs) t0 = time.perf_counter() for _ in range(n_runs): - callable(*inputs) + res = callable(*inputs) + _ = res.numpy() t1 = time.perf_counter() key = ( "Tensorflow", From 9a80c7a3ac0e05550fc8a1cf225e0adc5e241d23 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 8 Dec 2021 12:39:32 +0100 Subject: [PATCH 42/46] Add bench to backend docstring --- ot/backend.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index ef0c2d6db..1269319ea 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -18,7 +18,6 @@ ... c = nx.dot(a, b) # now use the backend to do any calculation ... return c - .. warning:: Tensorflow only works with the Numpy API. To activate it, please run the following: @@ -26,6 +25,60 @@ from tensorflow.python.ops.numpy_ops import np_config np_config.enable_numpy_behavior() + +Performance +-------- + +- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz +- GPU: Tesla V100-SXM2-32GB +- Date of the benchmark: December 8th, 2021 +- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316 + +.. raw:: html + + + +
+    <table>
+    <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr>
+    <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr>
+    <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
+    <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+    <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr>
+    <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr>
+    <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr>
+    <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+    <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr>
+    <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr>
+    <tr><td colspan="8">&nbsp;</td></tr>
+    <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr>
+    <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
+    <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+    <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr>
+    <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr>
+    <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+    <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr>
+    <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr>
+    <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr>
+    </table>
""" # Author: Remi Flamary From 21d34eaf296c39ad010891bda88cd2d93b7dabca Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 8 Dec 2021 13:26:18 +0100 Subject: [PATCH 43/46] better benchs --- ot/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 1269319ea..7bbb27ad5 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1338,7 +1338,8 @@ def add_one(M): # We are technically doing more calculations but adding one # is expected to be very quick and allows us to access the # block_until_ready method to measure asynchronous calculations - add_one(callable(*inputs)).block_until_ready() + a = callable(*inputs) + add_one(a).block_until_ready() t1 = time.perf_counter() key = ("Jax", self.device_type(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs From c507d3b45590d3721626fbbb15c0aab56e7c83c5 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 8 Dec 2021 13:51:34 +0100 Subject: [PATCH 44/46] remove useless stuff --- ot/backend.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 7bbb27ad5..bdea3de8b 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1325,21 +1325,15 @@ def device_type(self, type_as): def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() - @jax.jit - def add_one(M): - return M + 1 - for type_as in self.__type_list__: inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] for _ in range(warmup_runs): - add_one(callable(*inputs)).block_until_ready() + a = callable(*inputs) + a.block_until_ready() t0 = time.perf_counter() for _ in range(n_runs): - # We are technically doing more calculations but adding one - # is expected to be very quick and allows us to access the - # block_until_ready method to measure asynchronous calculations a = callable(*inputs) - add_one(a).block_until_ready() + a.block_until_ready() t1 = time.perf_counter() key = ("Jax", self.device_type(type_as), self.bitsize(type_as)) results[key] = (t1 - t0) / n_runs From d02f71a4e1ed9afbf228e773fd4dd4eb57d0e549 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 8 Dec 2021 15:48:16 +0100 Subject: [PATCH 45/46] Better device_type --- ot/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index bdea3de8b..58b652b8f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1688,7 +1688,7 @@ def bitsize(self, type_as): return torch.finfo(type_as.dtype).bits def device_type(self, type_as): - return "CPU" if "cpu" in str(type_as.device) else "GPU" + return type_as.device.type.replace("cuda", "gpu").upper() def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() @@ -2337,7 +2337,7 @@ def bitsize(self, type_as): return type_as.dtype.size * 8 def device_type(self, type_as): - return "CPU" if "CPU" in type_as.device else "GPU" + return self.dtype_device(type_as)[1].split(":")[0] def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() From 86faaa42678bec21b06d2127b5ce4d8f2f361837 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Thu, 9 Dec 2021 10:29:10 +0100 Subject: [PATCH 46/46] Now using MYST_PARSER and solving links issue in the README.md / online docs --- README.md | 6 +++--- benchmarks/benchmark.py | 4 ++-- docs/requirements.txt | 3 +-- docs/requirements_rtd.txt | 3 +-- docs/source/.github/CODE_OF_CONDUCT.rst | 6 ++++++ docs/source/.github/CONTRIBUTING.rst | 6 ++++++ docs/source/code_of_conduct.rst | 1 - docs/source/conf.py | 2 +- 
docs/source/contributing.rst | 1 - docs/source/index.rst | 9 ++++----- 10 files changed, 24 insertions(+), 17 deletions(-) create mode 100644 docs/source/.github/CODE_OF_CONDUCT.rst create mode 100644 docs/source/.github/CONTRIBUTING.rst delete mode 100644 docs/source/code_of_conduct.rst delete mode 100644 docs/source/contributing.rst diff --git a/README.md b/README.md index 172dde928..17fbe81f7 100644 --- a/README.md +++ b/README.md @@ -202,12 +202,12 @@ This toolbox benefit a lot from open source research and we would like to thank * [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab) * [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT) -* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD) +* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD) * [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda) ## Contributions and code of conduct -Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/code_of_conduct.html). +Every contribution is welcome and should respect the [contribution guidelines](.github/CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](.github/CODE_OF_CONDUCT.md). ## Support @@ -217,7 +217,7 @@ You can ask questions and join the development discussion: * On the POT [gitter channel](https://gitter.im/PythonOT/community) * On the POT [mailing list](https://mail.python.org/mm3/mailman3/lists/pot.python.org/) -You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](https://pythonot.github.io/contributing.html) first. +You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](.github/CONTRIBUTING.md) first. ## References diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 85b5562cd..7973c6b91 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -84,8 +84,8 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None): # make device header string += f'Device' - string += f'CPU' - string += f'GPU\n' + string += f'CPU' + string += f'GPU\n' # make param_name / backend header string += f'{param_name}' diff --git a/docs/requirements.txt b/docs/requirements.txt index 9c311053e..2e060b93c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,5 +4,4 @@ numpydoc memory_profiler pillow networkx -mistune==0.8.4 -m2r2 \ No newline at end of file +myst-parser \ No newline at end of file diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 7a6dbf694..11957fbe2 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -3,8 +3,7 @@ numpydoc memory_profiler pillow networkx -mistune==0.8.4 -m2r2 +myst-parser numpy scipy>=1.0 cython diff --git a/docs/source/.github/CODE_OF_CONDUCT.rst b/docs/source/.github/CODE_OF_CONDUCT.rst new file mode 100644 index 000000000..d4c5cec49 --- /dev/null +++ b/docs/source/.github/CODE_OF_CONDUCT.rst @@ -0,0 +1,6 @@ +Code of Conduct +=============== + +.. 
include:: ../../../.github/CODE_OF_CONDUCT.md + :parser: myst_parser.sphinx_ + :start-line: 2 diff --git a/docs/source/.github/CONTRIBUTING.rst b/docs/source/.github/CONTRIBUTING.rst new file mode 100644 index 000000000..aef24e92b --- /dev/null +++ b/docs/source/.github/CONTRIBUTING.rst @@ -0,0 +1,6 @@ +Contributing to POT +=================== + +.. include:: ../../../.github/CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :start-line: 3 diff --git a/docs/source/code_of_conduct.rst b/docs/source/code_of_conduct.rst deleted file mode 100644 index b37ba7b5a..000000000 --- a/docs/source/code_of_conduct.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../../.github/CODE_OF_CONDUCT.md \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 1320afaac..849e97c74 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -69,7 +69,7 @@ def __getattr__(cls, name): 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'sphinx_gallery.gen_gallery', - 'm2r2' + 'myst_parser' ] autosummary_generate = True diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst deleted file mode 100644 index dc81e757b..000000000 --- a/docs/source/contributing.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../../.github/CONTRIBUTING.md \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 7aaa524ef..8de31aecc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,12 +17,11 @@ Contents all auto_examples/index releases - contributing - Code of Conduct - -.. mdinclude:: ../../README.md - :start-line: 2 + .github/CONTRIBUTING + .github/CODE_OF_CONDUCT +.. include:: ../../README.md + :parser: myst_parser.sphinx_ Indices and tables
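
Taken together, these patches leave every backend exposing `device_type` (always returning "CPU" or "GPU", as asserted in `test_func_backends`) alongside `bitsize`. A minimal usage sketch for the Numpy backend, with expected outputs in comments; the other backends follow the same API::

    import numpy as np
    from ot.backend import get_backend

    a = np.ones((3, 3), dtype=np.float64)
    nx = get_backend(a)        # resolves to the Numpy backend for an ndarray
    print(nx.device_type(a))   # "CPU": numpy arrays always live on the host
    print(nx.bitsize(a))       # 64, i.e. itemsize * 8 for float64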