Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
355 changes: 351 additions & 4 deletions tensornetwork/tests/network_components_free_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import numpy as np
import tensorflow as tf
import pytest
from unittest.mock import patch
from collections import namedtuple
import h5py
import re
#pylint: disable=line-too-long
from tensornetwork.network_components import Node, CopyNode, Edge, NodeCollection
from tensornetwork.network_components import Node, CopyNode, Edge, NodeCollection, BaseNode, _remove_trace_edge, _remove_edges
import tensornetwork as tn
from tensornetwork.backends.base_backend import BaseBackend

string_type = h5py.special_dtype(vlen=str)

Expand All @@ -15,6 +17,34 @@
'node1 node2 edge1 edge12 tensor')


class TestNode(BaseNode):

def get_tensor(self): #pylint: disable=useless-super-delegation
return super().get_tensor()

def set_tensor(self, tensor): #pylint: disable=useless-super-delegation
return super().set_tensor(tensor)

@property
def shape(self):
return super().shape

@property
def tensor(self):
return super().tensor

@tensor.setter
def tensor(self, tensor):
return super(TestNode, type(self)).tensor.fset(self, tensor)

def _load_node(self, node_data):# pylint: disable=useless-super-delegation
return super()._load_node(node_data)

def _save_node(self, node_group): #pylint: disable=useless-super-delegation
return super()._save_node(node_group)



@pytest.fixture(name='single_node_edge')
def fixture_single_node_edge(backend):
tensor = np.ones((1, 2, 2))
Expand Down Expand Up @@ -249,6 +279,16 @@ def test_node_reorder_edges_raise_error_trace_edge(single_node_edge):
assert "Edge reordering does not support trace edges." in str(e.value)


def test_node_reorder_edges_raise_error_no_tensor(single_node_edge):
node = single_node_edge.node
e2 = tn.connect(node[1], node[2])
e3 = node[0]
del node._tensor
with pytest.raises(AttributeError) as e:
node.reorder_edges([e2, e3])
assert "Please provide a valid tensor for this Node." in str(e.value)


def test_node_magic_getitem(single_node_edge):
node = single_node_edge.node
edge = single_node_edge.edge
Expand Down Expand Up @@ -279,13 +319,40 @@ def test_node_magic_lt(double_node_edge):
def test_node_magic_lt_raises_error_not_node(single_node_edge):
node = single_node_edge.node
with pytest.raises(ValueError):
assert node < 0
node < 0


def test_node_magic_matmul_raises_error_not_node(single_node_edge):
node = single_node_edge.node
with pytest.raises(TypeError):
assert node @ 0
node @ 0


def test_node_magic_matmul_raises_error_no_tensor(single_node_edge):
node = single_node_edge.node
del node._tensor
with pytest.raises(AttributeError):
node @ node


def test_node_magic_matmul_raises_error_disabled_node(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node @ node


def test_node_edges_getter_raises_error_disabled_node(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.edges

def test_node_edges_setter_raises_error_disabled_node(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.edges = []


def test_node_magic_matmul_raises_error_different_network(single_node_edge):
Expand Down Expand Up @@ -918,4 +985,284 @@ def test_repr_for_Nodes_and_Edges(double_node_edge):
assert "[[[1.,1.],[1.,1.]]]" in str(node1) and str(node2)
assert "Edge(DanglingEdge)[0]" in str(node1) and str(node2)
assert "Edge('test_node1'[1]->'test_node2'[1])" in str(node1) and str(node2)
assert "Edge(DanglingEdge)[2]" in str(node1) and str(node2)
assert "Edge(DanglingEdge)[2]" in str(node1) and str(node2)


def test_base_node_name_list_throws_error():
with pytest.raises(TypeError,):
TestNode(name=["A"], axis_names=['a', 'b']) # pytype: disable=wrong-arg-types


def test_base_node_name_int_throws_error():
with pytest.raises(TypeError):
TestNode(name=1, axis_names=['a', 'b']) # pytype: disable=wrong-arg-types


def test_base_node_axis_names_int_throws_error():
with pytest.raises(TypeError):
TestNode(axis_names=[0, 1]) # pytype: disable=wrong-arg-types


def test_base_node_no_axis_names_no_shapes_throws_error():
with pytest.raises(ValueError):
TestNode(name='a')


def test_node_add_axis_names_int_throws_error():
n1 = Node(np.eye(2), axis_names=['a', 'b'])
with pytest.raises(TypeError):
n1.add_axis_names([0, 1]) # pytype: disable=wrong-arg-types


def test_node_axis_names_setter_throws_shape_large_mismatch_error():
n1 = Node(np.eye(2), axis_names=['a', 'b'])
with pytest.raises(ValueError):
n1.axis_names = ['a', 'b', 'c']


def test_node_axis_names_setter_throws_shape_small_mismatch_error():
n1 = Node(np.eye(2), axis_names=['a', 'b'])
with pytest.raises(ValueError):
n1.axis_names = ['a']


def test_node_axis_names_setter_throws_value_error():
n1 = Node(np.eye(2), axis_names=['a', 'b'])
with pytest.raises(TypeError):
n1.axis_names = [0, 1]


def test_node_dtype(backend):
n1 = Node(np.random.rand(2), backend=backend)
assert n1.dtype == n1.tensor.dtype


@pytest.mark.parametrize("name", [1, ['1']])
def test_node_set_name_raises_type_error(backend, name):
n1 = Node(np.random.rand(2), backend=backend)
with pytest.raises(TypeError):
n1.set_name(name)


@pytest.mark.parametrize("name", [1, ['1']])
def test_node_name_setter_raises_type_error(backend, name):
n1 = Node(np.random.rand(2), backend=backend)
with pytest.raises(TypeError):
n1.name = name


def test_base_node_get_tensor():
n1 = TestNode(name="n1", axis_names=['a'], shape=(1,))
assert n1.get_tensor() is None


def test_base_node_set_tensor():
n1 = TestNode(name="n1", axis_names=['a'], shape=(1,))
assert n1.set_tensor(np.random.rand(2)) is None
assert n1.tensor is None


def test_base_node_shape():
n1 = TestNode(name="n1", axis_names=['a'], shape=(1,))
n1._shape = None
with pytest.raises(ValueError):
n1.shape


def test_base_node_tensor_getter():
n1 = TestNode(name="n1", axis_names=['a'], shape=(1,))
assert n1.tensor is None


def test_base_node_tensor_setter():
n1 = TestNode(name="n1", axis_names=['a'], shape=(1,))
n1.tensor = np.random.rand(2)
assert n1.tensor is None


def test_node_has_dangling_edge_false(double_node_edge):
node1 = double_node_edge.node1
node2 = double_node_edge.node2
tn.connect(node1["a"], node2["a"])
tn.connect(node1["c"], node2["c"])
assert not node1.has_dangling_edge()


def test_node_has_dangling_edge_true(single_node_edge):
assert single_node_edge.node.has_dangling_edge()


def test_node_get_item(single_node_edge):
node = single_node_edge.node
edge = single_node_edge.edge
node.add_edge(edge, axis=0)
assert node[0] == edge
assert edge in node[0:2]


def test_node_signature_getter_disabled_throws_error(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.signature


def test_node_signature_setter_disabled_throws_error(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.signature = "signature"


def test_node_disabled_disabled_throws_error(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.disable()


def test_node_disabled_shape_throws_error(single_node_edge):
node = single_node_edge.node
node.is_disabled = True
with pytest.raises(ValueError):
node.shape


def test_copy_node_get_partners_with_trace(backend):
node1 = CopyNode(4, 2, backend=backend)
node2 = Node(np.random.rand(2, 2), backend=backend, name="node2")
tn.connect(node1[0], node1[1])
tn.connect(node1[2], node2[0])
tn.connect(node1[3], node2[1])
assert node1.get_partners() == {node2: {0, 1}}


@pytest.mark.parametrize("name", [1, ['1']])
def test_edge_name_throws_type_error(single_node_edge, name):
with pytest.raises(TypeError):
Edge(node1=single_node_edge.node, axis1=0, name=name)


def test_edge_name_setter_disabled_throws_error(single_node_edge):
edge = Edge(node1=single_node_edge.node, axis1=0)
edge.is_disabled = True
with pytest.raises(ValueError):
edge.name = 'edge'


def test_edge_name_getter_disabled_throws_error(single_node_edge):
edge = Edge(node1=single_node_edge.node, axis1=0)
edge.is_disabled = True
with pytest.raises(ValueError):
edge.name


@pytest.mark.parametrize("name", [1, ['1']])
def test_edge_name_setter_throws_type_error(single_node_edge, name):
edge = Edge(node1=single_node_edge.node, axis1=0)
with pytest.raises(TypeError):
edge.name = name


def test_edge_signature_getter_disabled_throws_error(single_node_edge):
edge = Edge(node1=single_node_edge.node, axis1=0)
edge.is_disabled = True
with pytest.raises(ValueError):
edge.signature


def test_edge_signature_setter_disabled_throws_error(single_node_edge):
edge = Edge(node1=single_node_edge.node, axis1=0)
edge.is_disabled = True
with pytest.raises(ValueError):
edge.signature = "signature"


def test_edge_node1_throws_value_error(single_node_edge):
edge = Edge(node1=single_node_edge.node, axis1=0, name="edge")
edge._node1 = None
err_msg = "node1 for edge 'edge' no longer exists."
with pytest.raises(ValueError, match=err_msg):
edge.node1



def test_edge_node2_throws_value_error(single_node_edge):
edge = tn.connect(single_node_edge.node[1], single_node_edge.node[2])
edge.name = 'edge'
edge._node2 = None
err_msg = "node2 for edge 'edge' no longer exists."
with pytest.raises(ValueError, match=err_msg):
edge.node2


@pytest.mark.parametrize("name", [1, ['1']])
def test_edge_set_name_throws_type_error(single_node_edge, name):
edge = Edge(node1=single_node_edge.node, axis1=0)
with pytest.raises(TypeError):
edge.set_name(name)


@patch.object(Edge, "name", None)
def test_edge_str(single_node_edge):
single_node_edge.edge.name = None
assert str(single_node_edge.edge) == "__unnamed_edge__"


def test_get_all_dangling_single_node(single_node_edge):
node = single_node_edge.node
assert set(tn.get_all_dangling({node})) == set(node.edges)


def test_get_all_dangling_double_node(double_node_edge):
node1 = double_node_edge.node1
node2 = double_node_edge.node2
assert set(tn.get_all_dangling({node1, node2})) == {node1[0], node1[2],
node2[0], node2[2]}


def test_flatten_edges_different_backend_raises_value_error(single_node_edge):
node1 = single_node_edge.node
node2 = tn.Node(np.random.rand(2, 2, 2))
node2.backend = BaseBackend()
with pytest.raises(ValueError):
tn.flatten_edges(node1.get_all_edges()+node2.get_all_edges())


def test_split_edge_trivial(single_node_edge):
edge = single_node_edge.edge
assert tn.split_edge(edge, (1,)) == [edge]


def test_split_edge_different_backend_raises_value_error(single_node_edge):
if single_node_edge.node.backend.name == "numpy":
pytest.skip("numpy comparing to all the others")
node1 = single_node_edge.node
node2 = tn.Node(np.random.rand(2, 2, 2), backend="numpy")
edge = tn.connect(node1[1], node2[1])
with pytest.raises(ValueError, match="Not all backends are the same."):
tn.split_edge(edge, (2, 1))


def test_remove_trace_edge_dangling_edge_raises_value_error(single_node_edge):
node = single_node_edge.node
edge = node[0]
edge.name = "e"
with pytest.raises(ValueError, match="Attempted to remove dangling edge 'e"):
_remove_trace_edge(edge, node)


def test_remove_trace_edge_non_trace_raises_value_error(double_node_edge):
node1 = double_node_edge.node1
node2 = double_node_edge.node2
edge = tn.connect(node1[0], node2[0])
edge.name = "e"
with pytest.raises(ValueError, match="Edge 'e' is not a trace edge."):
_remove_trace_edge(edge, node1)


def test_remove_edges_trace_raises_value_error(single_node_edge):
node = single_node_edge.node
edge = tn.connect(node[1], node[2])
with pytest.raises(ValueError):
_remove_edges(edge, node, node, node) # pytype: disable=wrong-arg-types
Loading