Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
035b168
added test for mps switch backend
MichaelMarien Nov 23, 2019
cd6558d
added switch backend method to MPS
MichaelMarien Nov 23, 2019
304a0e5
added test for network operations switch backend
MichaelMarien Nov 23, 2019
7a11d58
make sure switch_backend not only fixes tensor but also node property
MichaelMarien Nov 23, 2019
fcfb3ce
added switch_backend to init
MichaelMarien Nov 23, 2019
c36433c
Merge branch 'master' of https://github.com/google/TensorNetwork into…
MichaelMarien Dec 4, 2019
182ff83
Merge branch 'master' of https://github.com/google/TensorNetwork into…
MichaelMarien Jan 20, 2020
2cd10ee
Merge branch 'master' of https://github.com/google/TensorNetwork into…
MichaelMarien Jan 21, 2020
ddbb090
Merge branch 'master' of https://github.com/google/TensorNetwork into…
MichaelMarien Jan 25, 2020
a1c527b
missing test for backend contextmanager
MichaelMarien Jan 25, 2020
631477f
notimplemented tests for base backend
MichaelMarien Jan 25, 2020
a0a9423
added subtraction test notimplemented
MichaelMarien Jan 25, 2020
45381be
added jax backend index_update test
MichaelMarien Jan 25, 2020
90356c0
first missing tests for numpy
MichaelMarien Jan 25, 2020
6676520
actually catched an error in numpy_backend eigs method!
MichaelMarien Jan 25, 2020
1501202
more eigs tests
MichaelMarien Jan 25, 2020
2e8b86b
didnt catch an error, unexpected convention
MichaelMarien Jan 25, 2020
942a1d3
more tests for eigsh_lancszos
MichaelMarien Jan 25, 2020
8473aa4
added missing pytorch backend tests
MichaelMarien Jan 25, 2020
3b7b3ce
added missing tf backend tests
MichaelMarien Jan 25, 2020
f5b20b8
pytype
MichaelMarien Jan 25, 2020
2a33a4d
suppress pytype
MichaelMarien Jan 25, 2020
bba6b68
Merge branch 'master' into backend-test
Jan 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions tensornetwork/backends/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import numpy as np
from tensornetwork import connect, contract, Node
from tensornetwork.backends.base_backend import BaseBackend


def clean_tensornetwork_modules():
Expand Down Expand Up @@ -146,3 +147,200 @@ def test_basic_network_without_backends_raises_error():
Node(np.ones((2, 2)), backend="tensorflow")
with pytest.raises(ImportError):
Node(np.ones((2, 2)), backend="pytorch")
[]

def test_base_backend_name():
backend = BaseBackend()
assert backend.name == "base backend"


def test_base_backend_tensordot_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.tensordot(np.ones((2, 2)), np.ones((2, 2)), axes=[[0], [0]])


def test_base_backend_reshape_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.reshape(np.ones((2, 2)), (4, 1))


def test_base_backend_transpose_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.transpose(np.ones((2, 2)), [0, 1])


def test_base_backend_svd_decompositon_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.svd_decomposition(np.ones((2, 2)), 0)


def test_base_backend_qr_decompositon_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.qr_decomposition(np.ones((2, 2)), 0)


def test_base_backend_rq_decompositon_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.rq_decomposition(np.ones((2, 2)), 0)


def test_base_backend_shape_concat_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.shape_concat([np.ones((2, 2)), np.ones((2, 2))], 0)


def test_base_backend_shape_tensor_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.shape_tensor(np.ones((2, 2)))


def test_base_backend_shape_tuple_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.shape_tuple(np.ones((2, 2)))


def test_base_backend_shape_prod_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.shape_prod(np.ones((2, 2)))


def test_base_backend_sqrt_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.sqrt(np.ones((2, 2)))


def test_base_backend_diag_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.diag(np.ones((2, 2)))


def test_base_backend_convert_to_tensor_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.convert_to_tensor(np.ones((2, 2)))


def test_base_backend_trace_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.trace(np.ones((2, 2)))


def test_base_backend_outer_product_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.outer_product(np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_einsul_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.einsum("ii", np.ones((2, 2)))


def test_base_backend_norm_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.norm(np.ones((2, 2)))


def test_base_backend_eye_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.eye(2, dtype=np.float64)


def test_base_backend_ones_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.ones((2, 2), dtype=np.float64)


def test_base_backend_zeros_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.zeros((2, 2), dtype=np.float64)


def test_base_backend_randn_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.randn((2, 2))


def test_base_backend_random_uniforl_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.random_uniform((2, 2))


def test_base_backend_conj_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.conj(np.ones((2, 2)))


def test_base_backend_eigh_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.eigh(np.ones((2, 2)))


def test_base_backend_eigs_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.eigs(np.ones((2, 2)))


def test_base_backend_eigs_lanczos_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.eigsh_lanczos(np.ones((2, 2)))


def test_base_backend_addition_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.addition(np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_subtraction_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.subtraction(np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_multiply_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.multiply(np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_divide_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.divide(np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_index_update_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.index_update(np.ones((2, 2)), np.ones((2, 2)), np.ones((2, 2)))


def test_base_backend_inv_not_implemented():
backend = BaseBackend()
with pytest.raises(NotImplementedError):
backend.inv(np.ones((2, 2)))
24 changes: 24 additions & 0 deletions tensornetwork/backends/jax/jax_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,27 @@ def index_update(dtype):
tensor = np.array(tensor)
tensor[tensor > 0.1] = 0.0
np.testing.assert_allclose(tensor, out)


def test_base_backend_eigs_not_implemented():
backend = jax_backend.JaxBackend()
tensor = backend.randn((4, 2, 3), dtype=np.float64)
with pytest.raises(NotImplementedError):
backend.eigs(tensor)


def test_base_backend_eigsh_lanczos_not_implemented():
backend = jax_backend.JaxBackend()
tensor = backend.randn((4, 2, 3), dtype=np.float64)
with pytest.raises(NotImplementedError):
backend.eigsh_lanczos(tensor)


@pytest.mark.parametrize("dtype", np_dtypes)
def test_index_update(dtype):
backend = jax_backend.JaxBackend()
tensor = backend.randn((4, 2, 3), dtype=dtype, seed=10)
out = backend.index_update(tensor, tensor > 0.1, 0.0)
np_tensor = np.array(tensor)
np_tensor[np_tensor > 0.1] = 0.0
np.testing.assert_allclose(out, np_tensor)
Loading