diff --git a/tensornetwork/backends/base_backend.py b/tensornetwork/backends/base_backend.py index 55b9f40ea..1d9d249b7 100644 --- a/tensornetwork/backends/base_backend.py +++ b/tensornetwork/backends/base_backend.py @@ -265,6 +265,27 @@ def randn(self, raise NotImplementedError("Backend '{}' has not implemented randn.".format( self.name)) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[Type[np.number]] = None, + seed: Optional[int] = None) -> Tensor: + """Return a random uniform matrix of dimension `dim`. + Depending on specific backends, `dim` has to be either an int + (numpy, torch, tensorflow) or a `ShapeType` object + (for block-sparse backends). Block-sparse + behavior is currently not supported + Args: + shape (int): The dimension of the returned matrix. + boundaries (tuple): The boundaries of the uniform distribution. + dtype: The dtype of the returned matrix. + seed: The seed for the random number generator + Returns: + Tensor : random uniform initialized tensor. + """ + raise NotImplementedError(("Backend '{}' has not implemented " + "random_uniform.").format(self.name)) + def conj(self, tensor: Tensor) -> Tensor: """ Return the complex conjugate of `tensor` diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index f9235391a..9773f6026 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -72,6 +72,42 @@ def cmplx_randn(complex_dtype, real_dtype): return self.jax.random.normal(key, shape).astype(dtype) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[np.dtype] = None, + seed: Optional[int] = None) -> Tensor: + if not seed: + seed = np.random.randint(0, 2**63) + key = self.jax.random.PRNGKey(seed) + + dtype = dtype if dtype is not None else np.dtype(np.float64) + + def cmplx_random_uniform(complex_dtype, real_dtype): + real_dtype = np.dtype(real_dtype) + complex_dtype = np.dtype(complex_dtype) + + key_2 = self.jax.random.PRNGKey(seed + 1) + + real_part = self.jax.random.uniform(key, shape, dtype=real_dtype, + minval=boundaries[0], + maxval=boundaries[1]) + complex_part = self.jax.random.uniform(key_2, shape, dtype=real_dtype, + minval=boundaries[0], + maxval=boundaries[1]) + unit = ( + np.complex64(1j) + if complex_dtype == np.dtype(np.complex64) else np.complex128(1j)) + return real_part + unit * complex_part + + if np.dtype(dtype) is np.dtype(self.np.complex128): + return cmplx_random_uniform(dtype, self.np.float64) + if np.dtype(dtype) is np.dtype(self.np.complex64): + return cmplx_random_uniform(dtype, self.np.float32) + + return self.jax.random.uniform(key, shape, minval=boundaries[0], + maxval=boundaries[1]).astype(dtype) + def eigs(self, A: Callable, initial_state: Optional[Tensor] = None, diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index fbed4eebe..08b21059d 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -161,6 +161,13 @@ def test_randn(dtype): assert a.shape == (4, 4) +@pytest.mark.parametrize("dtype", np_randn_dtypes) +def test_random_uniform(dtype): + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), dtype=dtype) + assert a.shape == (4, 4) + + @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) def test_randn_non_zero_imag(dtype): backend = jax_backend.JaxBackend() @@ -168,6 +175,13 @@ def test_randn_non_zero_imag(dtype): assert np.linalg.norm(np.imag(a)) != 0.0 +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) +def test_random_uniform_non_zero_imag(dtype): + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), dtype=dtype) + assert np.linalg.norm(np.imag(a)) != 0.0 + + @pytest.mark.parametrize("dtype", np_dtypes) def test_eye_dtype(dtype): backend = jax_backend.JaxBackend() @@ -196,6 +210,13 @@ def test_randn_dtype(dtype): assert a.dtype == dtype +@pytest.mark.parametrize("dtype", np_randn_dtypes) +def test_random_uniform_dtype(dtype): + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), dtype=dtype) + assert a.dtype == dtype + + @pytest.mark.parametrize("dtype", np_randn_dtypes) def test_randn_seed(dtype): backend = jax_backend.JaxBackend() @@ -204,6 +225,34 @@ def test_randn_seed(dtype): np.testing.assert_allclose(a, b) +@pytest.mark.parametrize("dtype", np_randn_dtypes) +def test_random_uniform_seed(dtype): + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), seed=10, dtype=dtype) + np.testing.assert_allclose(a, b) + + +@pytest.mark.parametrize("dtype", np_randn_dtypes) +def test_random_uniform_boundaries(dtype): + lb = 1.2 + ub = 4.8 + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype) + assert((a >= 0).all() and (a <= 1).all() and + (b >= lb).all() and (b <= ub).all()) + + +def test_random_uniform_behavior(): + seed = 10 + key = jax.random.PRNGKey(seed) + backend = jax_backend.JaxBackend() + a = backend.random_uniform((4, 4), seed=seed) + b = jax.random.uniform(key, (4, 4)) + np.testing.assert_allclose(a, b) + + def test_conj(): backend = jax_backend.JaxBackend() real = np.random.rand(2, 2, 2) diff --git a/tensornetwork/backends/numpy/numpy_backend.py b/tensornetwork/backends/numpy/numpy_backend.py index 7d0527b83..e7c4e8ecd 100644 --- a/tensornetwork/backends/numpy/numpy_backend.py +++ b/tensornetwork/backends/numpy/numpy_backend.py @@ -132,6 +132,24 @@ def randn(self, dtype) + 1j * self.np.random.randn(*shape).astype(dtype) return self.np.random.randn(*shape).astype(dtype) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[numpy.dtype] = None, + seed: Optional[int] = None) -> Tensor: + + if seed: + self.np.random.seed(seed) + dtype = dtype if dtype is not None else self.np.float64 + if ((self.np.dtype(dtype) is self.np.dtype(self.np.complex128)) or + (self.np.dtype(dtype) is self.np.dtype(self.np.complex64))): + return self.np.random.uniform(boundaries[0], boundaries[1], shape).astype( + dtype) + 1j * self.np.random.uniform(boundaries[0], + boundaries[1], + shape).astype(dtype) + return self.np.random.uniform(boundaries[0], + boundaries[1], shape).astype(dtype) + def conj(self, tensor: Tensor) -> Tensor: return self.np.conj(tensor) diff --git a/tensornetwork/backends/numpy/numpy_backend_test.py b/tensornetwork/backends/numpy/numpy_backend_test.py index 55fe9edb6..49645d876 100644 --- a/tensornetwork/backends/numpy/numpy_backend_test.py +++ b/tensornetwork/backends/numpy/numpy_backend_test.py @@ -159,6 +159,13 @@ def test_randn(dtype): assert a.shape == (4, 4) +@pytest.mark.parametrize("dtype", np_dtypes) +def test_random_uniform(dtype): + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert a.shape == (4, 4) + + @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) def test_randn_non_zero_imag(dtype): backend = numpy_backend.NumPyBackend() @@ -166,6 +173,13 @@ def test_randn_non_zero_imag(dtype): assert np.linalg.norm(np.imag(a)) != 0.0 +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) +def test_random_uniform_non_zero_imag(dtype): + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert np.linalg.norm(np.imag(a)) != 0.0 + + @pytest.mark.parametrize("dtype", np_dtypes) def test_eye_dtype(dtype): backend = numpy_backend.NumPyBackend() @@ -194,6 +208,13 @@ def test_randn_dtype(dtype): assert a.dtype == dtype +@pytest.mark.parametrize("dtype", np_dtypes) +def test_random_uniform_dtype(dtype): + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert a.dtype == dtype + + @pytest.mark.parametrize("dtype", np_randn_dtypes) def test_randn_seed(dtype): backend = numpy_backend.NumPyBackend() @@ -202,6 +223,33 @@ def test_randn_seed(dtype): np.testing.assert_allclose(a, b) +@pytest.mark.parametrize("dtype", np_dtypes) +def test_random_uniform_seed(dtype): + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), seed=10, dtype=dtype) + np.testing.assert_allclose(a, b) + + +@pytest.mark.parametrize("dtype", np_randn_dtypes) +def test_random_uniform_boundaries(dtype): + lb = 1.2 + ub = 4.8 + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype) + assert((a >= 0).all() and (a <= 1).all() and + (b >= lb).all() and (b <= ub).all()) + + +def test_random_uniform_behavior(): + backend = numpy_backend.NumPyBackend() + a = backend.random_uniform((4, 4), seed=10) + np.random.seed(10) + b = np.random.uniform(size=(4, 4)) + np.testing.assert_allclose(a, b) + + def test_conj(): backend = numpy_backend.NumPyBackend() real = np.random.rand(2, 2, 2) diff --git a/tensornetwork/backends/pytorch/pytorch_backend.py b/tensornetwork/backends/pytorch/pytorch_backend.py index 9a979b1c8..0caba598a 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend.py +++ b/tensornetwork/backends/pytorch/pytorch_backend.py @@ -128,6 +128,16 @@ def randn(self, dtype = dtype if dtype is not None else self.torch.float64 return self.torch.randn(shape, dtype=dtype) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[Any] = None, + seed: Optional[int] = None) -> Tensor: + if seed: + self.torch.manual_seed(seed) + dtype = dtype if dtype is not None else self.torch.float64 + return self.torch.empty(shape, dtype=dtype).uniform_(*boundaries) + def conj(self, tensor: Tensor) -> Tensor: return tensor #pytorch does not support complex dtypes diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py index 5e3ead3f9..ca0cd92f3 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend_test.py +++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py @@ -148,6 +148,13 @@ def test_randn(dtype): assert a.shape == (4, 4) +@pytest.mark.parametrize("dtype", torch_randn_dtypes) +def test_random_uniform(dtype): + backend = pytorch_backend.PyTorchBackend() + a = backend.random_uniform((4, 4), dtype=dtype) + assert a.shape == (4, 4) + + @pytest.mark.parametrize("dtype", torch_eye_dtypes) def test_eye_dtype(dtype): backend = pytorch_backend.PyTorchBackend() @@ -176,6 +183,13 @@ def test_randn_dtype(dtype): assert a.dtype == dtype +@pytest.mark.parametrize("dtype", torch_randn_dtypes) +def test_random_uniform_dtype(dtype): + backend = pytorch_backend.PyTorchBackend() + a = backend.random_uniform((4, 4), dtype=dtype) + assert a.dtype == dtype + + @pytest.mark.parametrize("dtype", torch_randn_dtypes) def test_randn_seed(dtype): backend = pytorch_backend.PyTorchBackend() @@ -184,6 +198,33 @@ def test_randn_seed(dtype): np.testing.assert_allclose(a, b) +@pytest.mark.parametrize("dtype", torch_randn_dtypes) +def test_random_uniform_seed(dtype): + backend = pytorch_backend.PyTorchBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), seed=10, dtype=dtype) + torch.allclose(a, b) + + +@pytest.mark.parametrize("dtype", torch_randn_dtypes) +def test_random_uniform_boundaries(dtype): + lb = 1.2 + ub = 4.8 + backend = pytorch_backend.PyTorchBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype) + assert(torch.ge(a, 0).byte().all() and torch.le(a, 1).byte().all() and + torch.ge(b, lb).byte().all() and torch.le(b, ub).byte().all()) + + +def test_random_uniform_behavior(): + backend = pytorch_backend.PyTorchBackend() + a = backend.random_uniform((4, 4), seed=10) + torch.manual_seed(10) + b = torch.empty((4, 4), dtype=torch.float64).uniform_() + torch.allclose(a, b) + + def test_conj(): backend = pytorch_backend.PyTorchBackend() real = np.random.rand(2, 2, 2) diff --git a/tensornetwork/backends/shell/shell_backend.py b/tensornetwork/backends/shell/shell_backend.py index 4e73f638d..3365fae5e 100644 --- a/tensornetwork/backends/shell/shell_backend.py +++ b/tensornetwork/backends/shell/shell_backend.py @@ -207,6 +207,13 @@ def randn(self, seed: Optional[int] = None) -> Tensor: return ShellTensor(shape) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[Type[np.number]] = None, + seed: Optional[int] = None) -> Tensor: + return ShellTensor(shape) + def conj(self, tensor: Tensor) -> Tensor: return tensor diff --git a/tensornetwork/backends/shell/shell_backend_test.py b/tensornetwork/backends/shell/shell_backend_test.py index 45a713965..3974dc1f7 100644 --- a/tensornetwork/backends/shell/shell_backend_test.py +++ b/tensornetwork/backends/shell/shell_backend_test.py @@ -157,6 +157,11 @@ def test_randn(): assertBackendsAgree("randn", args) +def test_random_uniform(): + args = {"shape": (10, 4)} + assertBackendsAgree("random_uniform", args) + + def test_eigsh_lanczos_1(): backend = shell_backend.ShellBackend() D = 16 diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend.py b/tensornetwork/backends/tensorflow/tensorflow_backend.py index 2602984ed..5f7cd1201 100644 --- a/tensornetwork/backends/tensorflow/tensorflow_backend.py +++ b/tensornetwork/backends/tensorflow/tensorflow_backend.py @@ -131,6 +131,26 @@ def randn(self, self.tf.random.normal(shape=shape, dtype=dtype.real_dtype)) return self.tf.random.normal(shape=shape, dtype=dtype) + def random_uniform(self, + shape: Tuple[int, ...], + boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), + dtype: Optional[Type[np.number]] = None, + seed: Optional[int] = None) -> Tensor: + if seed: + self.tf.random.set_seed(seed) + + dtype = dtype if dtype is not None else self.tf.float64 + if (dtype is self.tf.complex128) or (dtype is self.tf.complex64): + return self.tf.complex( + self.tf.random.uniform(shape=shape, minval=boundaries[0], + maxval=boundaries[1], dtype=dtype.real_dtype), + self.tf.random.uniform(shape=shape, minval=boundaries[0], + maxval=boundaries[1], dtype=dtype.real_dtype)) + self.tf.random.set_seed(10) + a = self.tf.random.uniform(shape=shape, minval=boundaries[0], + maxval=boundaries[1], dtype=dtype) + return a + def conj(self, tensor: Tensor) -> Tensor: return self.tf.math.conj(tensor) diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py index 8df3fcce0..25110d66c 100644 --- a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py +++ b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py @@ -151,6 +151,13 @@ def test_randn(dtype): assert a.shape == (4, 4) +@pytest.mark.parametrize("dtype", tf_dtypes) +def test_random_uniform(dtype): + backend = tensorflow_backend.TensorFlowBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert a.shape == (4, 4) + + @pytest.mark.parametrize("dtype", [tf.complex64, tf.complex128]) def test_randn_non_zero_imag(dtype): backend = tensorflow_backend.TensorFlowBackend() @@ -158,6 +165,13 @@ def test_randn_non_zero_imag(dtype): assert tf.math.greater(tf.linalg.norm(tf.math.imag(a)), 0.0) +@pytest.mark.parametrize("dtype", [tf.complex64, tf.complex128]) +def test_random_uniform_non_zero_imag(dtype): + backend = tensorflow_backend.TensorFlowBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert tf.math.greater(tf.linalg.norm(tf.math.imag(a)), 0.0) + + @pytest.mark.parametrize("dtype", tf_dtypes) def test_eye_dtype(dtype): backend = tensorflow_backend.TensorFlowBackend() @@ -186,6 +200,13 @@ def test_randn_dtype(dtype): assert a.dtype == dtype +@pytest.mark.parametrize("dtype", tf_dtypes) +def test_random_uniform_dtype(dtype): + backend = tensorflow_backend.TensorFlowBackend() + a = backend.random_uniform((4, 4), dtype=dtype, seed=10) + assert a.dtype == dtype + + @pytest.mark.parametrize("dtype", tf_randn_dtypes) def test_randn_seed(dtype): backend = tensorflow_backend.TensorFlowBackend() @@ -194,6 +215,27 @@ def test_randn_seed(dtype): np.testing.assert_allclose(a, b) +@pytest.mark.parametrize("dtype", tf_dtypes) +def test_random_uniform_seed(dtype): + test = tf.test.TestCase() + backend = tensorflow_backend.TensorFlowBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), seed=10, dtype=dtype) + test.assertAllCloseAccordingToType(a, b) + + +@pytest.mark.parametrize("dtype", tf_randn_dtypes) +def test_random_uniform_boundaries(dtype): + test = tf.test.TestCase() + lb = 1.2 + ub = 4.8 + backend = tensorflow_backend.TensorFlowBackend() + a = backend.random_uniform((4, 4), seed=10, dtype=dtype) + b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype) + test.assertAllInRange(a, 0, 1) + test.assertAllInRange(b, lb, ub) + + def test_conj(): backend = tensorflow_backend.TensorFlowBackend() real = np.random.rand(2, 2, 2)