Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
87 changes: 87 additions & 0 deletions tensornetwork/tests/network_operations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tensornetwork as tn
import pytest
import numpy as np
from tensornetwork.backends.base_backend import BaseBackend


def test_split_node_full_svd_names(backend):
Expand Down Expand Up @@ -334,3 +335,89 @@ def test_switch_backend(backend):
nodes = [a, b, c]
tn.switch_backend(nodes, backend)
assert nodes[0].backend.name == backend


def test_norm_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.norm(node)


def test_conj_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.conj(node)


def test_transpose_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.transpose(node, permutation=[])


def test_split_node_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.split_node(node, left_edges=[], right_edges=[])


def test_split_node_qr_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.split_node_qr(node, left_edges=[], right_edges=[])


def test_split_node_rq_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.split_node_rq(node, left_edges=[], right_edges=[])


def test_split_node_full_svd_of_node_without_backend_raises_error():
node = np.random.rand(3, 3, 3)
with pytest.raises(AttributeError):
tn.split_node_full_svd(node, left_edges=[], right_edges=[])


def test_reachable_raises_value_error():
with pytest.raises(ValueError):
tn.reachable({})


def test_check_correct_raises_value_error_1(backend):
a = tn.Node(np.random.rand(3, 3, 3), backend=backend)
b = tn.Node(np.random.rand(3, 3, 3), backend=backend)
edge = a.edges[0]
edge.node1 = b
edge.node2 = b
with pytest.raises(ValueError):
tn.check_correct({a, b})


def test_check_correct_raises_value_error_2(backend):
a = tn.Node(np.random.rand(3, 3, 3), backend=backend)
b = tn.Node(np.random.rand(3, 3, 3), backend=backend)
edge = a.edges[0]
edge.axis1 = -1
with pytest.raises(ValueError):
tn.check_correct({a, b})


def test_get_all_nodes(backend):
a = tn.Node(np.random.rand(3, 3, 3), backend=backend)
b = tn.Node(np.random.rand(3, 3, 3), backend=backend)
edge = tn.connect(a[0], b[0])
assert tn.get_all_nodes({edge}) == {a, b}


def test_contract_trace_edges(backend):
a = tn.Node(np.random.rand(3, 3, 3), backend=backend)
with pytest.raises(ValueError):
tn.contract_trace_edges(a)


def test_switch_backend_raises_error(backend):
a = tn.Node(np.random.rand(3, 3, 3))
a.backend = BaseBackend()
with pytest.raises(NotImplementedError):
tn.switch_backend({a}, backend)