diff --git a/tensornetwork/tests/network_components_free_test.py b/tensornetwork/tests/network_components_free_test.py index 47632debe..a25993c4b 100644 --- a/tensornetwork/tests/network_components_free_test.py +++ b/tensornetwork/tests/network_components_free_test.py @@ -1,12 +1,14 @@ import numpy as np import tensorflow as tf import pytest +from unittest.mock import patch from collections import namedtuple import h5py import re #pylint: disable=line-too-long -from tensornetwork.network_components import Node, CopyNode, Edge, NodeCollection +from tensornetwork.network_components import Node, CopyNode, Edge, NodeCollection, BaseNode, _remove_trace_edge, _remove_edges import tensornetwork as tn +from tensornetwork.backends.base_backend import BaseBackend string_type = h5py.special_dtype(vlen=str) @@ -15,6 +17,34 @@ 'node1 node2 edge1 edge12 tensor') +class TestNode(BaseNode): + + def get_tensor(self): #pylint: disable=useless-super-delegation + return super().get_tensor() + + def set_tensor(self, tensor): #pylint: disable=useless-super-delegation + return super().set_tensor(tensor) + + @property + def shape(self): + return super().shape + + @property + def tensor(self): + return super().tensor + + @tensor.setter + def tensor(self, tensor): + return super(TestNode, type(self)).tensor.fset(self, tensor) + + def _load_node(self, node_data):# pylint: disable=useless-super-delegation + return super()._load_node(node_data) + + def _save_node(self, node_group): #pylint: disable=useless-super-delegation + return super()._save_node(node_group) + + + @pytest.fixture(name='single_node_edge') def fixture_single_node_edge(backend): tensor = np.ones((1, 2, 2)) @@ -249,6 +279,16 @@ def test_node_reorder_edges_raise_error_trace_edge(single_node_edge): assert "Edge reordering does not support trace edges." in str(e.value) +def test_node_reorder_edges_raise_error_no_tensor(single_node_edge): + node = single_node_edge.node + e2 = tn.connect(node[1], node[2]) + e3 = node[0] + del node._tensor + with pytest.raises(AttributeError) as e: + node.reorder_edges([e2, e3]) + assert "Please provide a valid tensor for this Node." in str(e.value) + + def test_node_magic_getitem(single_node_edge): node = single_node_edge.node edge = single_node_edge.edge @@ -279,13 +319,40 @@ def test_node_magic_lt(double_node_edge): def test_node_magic_lt_raises_error_not_node(single_node_edge): node = single_node_edge.node with pytest.raises(ValueError): - assert node < 0 + node < 0 def test_node_magic_matmul_raises_error_not_node(single_node_edge): node = single_node_edge.node with pytest.raises(TypeError): - assert node @ 0 + node @ 0 + + +def test_node_magic_matmul_raises_error_no_tensor(single_node_edge): + node = single_node_edge.node + del node._tensor + with pytest.raises(AttributeError): + node @ node + + +def test_node_magic_matmul_raises_error_disabled_node(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node @ node + + +def test_node_edges_getter_raises_error_disabled_node(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.edges + +def test_node_edges_setter_raises_error_disabled_node(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.edges = [] def test_node_magic_matmul_raises_error_different_network(single_node_edge): @@ -918,4 +985,284 @@ def test_repr_for_Nodes_and_Edges(double_node_edge): assert "[[[1.,1.],[1.,1.]]]" in str(node1) and str(node2) assert "Edge(DanglingEdge)[0]" in str(node1) and str(node2) assert "Edge('test_node1'[1]->'test_node2'[1])" in str(node1) and str(node2) - assert "Edge(DanglingEdge)[2]" in str(node1) and str(node2) \ No newline at end of file + assert "Edge(DanglingEdge)[2]" in str(node1) and str(node2) + + +def test_base_node_name_list_throws_error(): + with pytest.raises(TypeError,): + TestNode(name=["A"], axis_names=['a', 'b']) # pytype: disable=wrong-arg-types + + +def test_base_node_name_int_throws_error(): + with pytest.raises(TypeError): + TestNode(name=1, axis_names=['a', 'b']) # pytype: disable=wrong-arg-types + + +def test_base_node_axis_names_int_throws_error(): + with pytest.raises(TypeError): + TestNode(axis_names=[0, 1]) # pytype: disable=wrong-arg-types + + +def test_base_node_no_axis_names_no_shapes_throws_error(): + with pytest.raises(ValueError): + TestNode(name='a') + + +def test_node_add_axis_names_int_throws_error(): + n1 = Node(np.eye(2), axis_names=['a', 'b']) + with pytest.raises(TypeError): + n1.add_axis_names([0, 1]) # pytype: disable=wrong-arg-types + + +def test_node_axis_names_setter_throws_shape_large_mismatch_error(): + n1 = Node(np.eye(2), axis_names=['a', 'b']) + with pytest.raises(ValueError): + n1.axis_names = ['a', 'b', 'c'] + + +def test_node_axis_names_setter_throws_shape_small_mismatch_error(): + n1 = Node(np.eye(2), axis_names=['a', 'b']) + with pytest.raises(ValueError): + n1.axis_names = ['a'] + + +def test_node_axis_names_setter_throws_value_error(): + n1 = Node(np.eye(2), axis_names=['a', 'b']) + with pytest.raises(TypeError): + n1.axis_names = [0, 1] + + +def test_node_dtype(backend): + n1 = Node(np.random.rand(2), backend=backend) + assert n1.dtype == n1.tensor.dtype + + +@pytest.mark.parametrize("name", [1, ['1']]) +def test_node_set_name_raises_type_error(backend, name): + n1 = Node(np.random.rand(2), backend=backend) + with pytest.raises(TypeError): + n1.set_name(name) + + +@pytest.mark.parametrize("name", [1, ['1']]) +def test_node_name_setter_raises_type_error(backend, name): + n1 = Node(np.random.rand(2), backend=backend) + with pytest.raises(TypeError): + n1.name = name + + +def test_base_node_get_tensor(): + n1 = TestNode(name="n1", axis_names=['a'], shape=(1,)) + assert n1.get_tensor() is None + + +def test_base_node_set_tensor(): + n1 = TestNode(name="n1", axis_names=['a'], shape=(1,)) + assert n1.set_tensor(np.random.rand(2)) is None + assert n1.tensor is None + + +def test_base_node_shape(): + n1 = TestNode(name="n1", axis_names=['a'], shape=(1,)) + n1._shape = None + with pytest.raises(ValueError): + n1.shape + + +def test_base_node_tensor_getter(): + n1 = TestNode(name="n1", axis_names=['a'], shape=(1,)) + assert n1.tensor is None + + +def test_base_node_tensor_setter(): + n1 = TestNode(name="n1", axis_names=['a'], shape=(1,)) + n1.tensor = np.random.rand(2) + assert n1.tensor is None + + +def test_node_has_dangling_edge_false(double_node_edge): + node1 = double_node_edge.node1 + node2 = double_node_edge.node2 + tn.connect(node1["a"], node2["a"]) + tn.connect(node1["c"], node2["c"]) + assert not node1.has_dangling_edge() + + +def test_node_has_dangling_edge_true(single_node_edge): + assert single_node_edge.node.has_dangling_edge() + + +def test_node_get_item(single_node_edge): + node = single_node_edge.node + edge = single_node_edge.edge + node.add_edge(edge, axis=0) + assert node[0] == edge + assert edge in node[0:2] + + +def test_node_signature_getter_disabled_throws_error(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.signature + + +def test_node_signature_setter_disabled_throws_error(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.signature = "signature" + + +def test_node_disabled_disabled_throws_error(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.disable() + + +def test_node_disabled_shape_throws_error(single_node_edge): + node = single_node_edge.node + node.is_disabled = True + with pytest.raises(ValueError): + node.shape + + +def test_copy_node_get_partners_with_trace(backend): + node1 = CopyNode(4, 2, backend=backend) + node2 = Node(np.random.rand(2, 2), backend=backend, name="node2") + tn.connect(node1[0], node1[1]) + tn.connect(node1[2], node2[0]) + tn.connect(node1[3], node2[1]) + assert node1.get_partners() == {node2: {0, 1}} + + +@pytest.mark.parametrize("name", [1, ['1']]) +def test_edge_name_throws_type_error(single_node_edge, name): + with pytest.raises(TypeError): + Edge(node1=single_node_edge.node, axis1=0, name=name) + + +def test_edge_name_setter_disabled_throws_error(single_node_edge): + edge = Edge(node1=single_node_edge.node, axis1=0) + edge.is_disabled = True + with pytest.raises(ValueError): + edge.name = 'edge' + + +def test_edge_name_getter_disabled_throws_error(single_node_edge): + edge = Edge(node1=single_node_edge.node, axis1=0) + edge.is_disabled = True + with pytest.raises(ValueError): + edge.name + + +@pytest.mark.parametrize("name", [1, ['1']]) +def test_edge_name_setter_throws_type_error(single_node_edge, name): + edge = Edge(node1=single_node_edge.node, axis1=0) + with pytest.raises(TypeError): + edge.name = name + + +def test_edge_signature_getter_disabled_throws_error(single_node_edge): + edge = Edge(node1=single_node_edge.node, axis1=0) + edge.is_disabled = True + with pytest.raises(ValueError): + edge.signature + + +def test_edge_signature_setter_disabled_throws_error(single_node_edge): + edge = Edge(node1=single_node_edge.node, axis1=0) + edge.is_disabled = True + with pytest.raises(ValueError): + edge.signature = "signature" + + +def test_edge_node1_throws_value_error(single_node_edge): + edge = Edge(node1=single_node_edge.node, axis1=0, name="edge") + edge._node1 = None + err_msg = "node1 for edge 'edge' no longer exists." + with pytest.raises(ValueError, match=err_msg): + edge.node1 + + + +def test_edge_node2_throws_value_error(single_node_edge): + edge = tn.connect(single_node_edge.node[1], single_node_edge.node[2]) + edge.name = 'edge' + edge._node2 = None + err_msg = "node2 for edge 'edge' no longer exists." + with pytest.raises(ValueError, match=err_msg): + edge.node2 + + +@pytest.mark.parametrize("name", [1, ['1']]) +def test_edge_set_name_throws_type_error(single_node_edge, name): + edge = Edge(node1=single_node_edge.node, axis1=0) + with pytest.raises(TypeError): + edge.set_name(name) + + +@patch.object(Edge, "name", None) +def test_edge_str(single_node_edge): + single_node_edge.edge.name = None + assert str(single_node_edge.edge) == "__unnamed_edge__" + + +def test_get_all_dangling_single_node(single_node_edge): + node = single_node_edge.node + assert set(tn.get_all_dangling({node})) == set(node.edges) + + +def test_get_all_dangling_double_node(double_node_edge): + node1 = double_node_edge.node1 + node2 = double_node_edge.node2 + assert set(tn.get_all_dangling({node1, node2})) == {node1[0], node1[2], + node2[0], node2[2]} + + +def test_flatten_edges_different_backend_raises_value_error(single_node_edge): + node1 = single_node_edge.node + node2 = tn.Node(np.random.rand(2, 2, 2)) + node2.backend = BaseBackend() + with pytest.raises(ValueError): + tn.flatten_edges(node1.get_all_edges()+node2.get_all_edges()) + + +def test_split_edge_trivial(single_node_edge): + edge = single_node_edge.edge + assert tn.split_edge(edge, (1,)) == [edge] + + +def test_split_edge_different_backend_raises_value_error(single_node_edge): + if single_node_edge.node.backend.name == "numpy": + pytest.skip("numpy comparing to all the others") + node1 = single_node_edge.node + node2 = tn.Node(np.random.rand(2, 2, 2), backend="numpy") + edge = tn.connect(node1[1], node2[1]) + with pytest.raises(ValueError, match="Not all backends are the same."): + tn.split_edge(edge, (2, 1)) + + +def test_remove_trace_edge_dangling_edge_raises_value_error(single_node_edge): + node = single_node_edge.node + edge = node[0] + edge.name = "e" + with pytest.raises(ValueError, match="Attempted to remove dangling edge 'e"): + _remove_trace_edge(edge, node) + + +def test_remove_trace_edge_non_trace_raises_value_error(double_node_edge): + node1 = double_node_edge.node1 + node2 = double_node_edge.node2 + edge = tn.connect(node1[0], node2[0]) + edge.name = "e" + with pytest.raises(ValueError, match="Edge 'e' is not a trace edge."): + _remove_trace_edge(edge, node1) + + +def test_remove_edges_trace_raises_value_error(single_node_edge): + node = single_node_edge.node + edge = tn.connect(node[1], node[2]) + with pytest.raises(ValueError): + _remove_edges(edge, node, node, node) # pytype: disable=wrong-arg-types \ No newline at end of file diff --git a/tensornetwork/tests/tensornetwork_test.py b/tensornetwork/tests/tensornetwork_test.py index 318d2053c..259aff3ac 100644 --- a/tensornetwork/tests/tensornetwork_test.py +++ b/tensornetwork/tests/tensornetwork_test.py @@ -345,6 +345,21 @@ def test_reorder_axes(backend): assert a.shape == (4, 2, 3) +def test_reorder_axes_raises_error_no_tensor(backend): + a = tn.Node(np.zeros((2, 3, 4)), backend=backend) + del a._tensor + with pytest.raises(AttributeError) as e: + a.reorder_axes([2, 0, 1]) + assert "Please provide a valid tensor for this Node." in str(e.value) + + +def test_reorder_axes_raises_error_bad_permutation(backend): + a = tn.Node(np.zeros((2, 3, 4)), backend=backend) + with pytest.raises(ValueError) as e: + a.reorder_axes([2, 0]) + assert "A full permutation was not passed." in str(e.value) + + def test_flatten_consistent_result(backend): a_val = np.ones((3, 5, 5, 6)) b_val = np.ones((5, 6, 4, 5))