diff --git a/tensornetwork/network_components.py b/tensornetwork/network_components.py index f8d713f58..38f51a995 100644 --- a/tensornetwork/network_components.py +++ b/tensornetwork/network_components.py @@ -1753,6 +1753,10 @@ def contract_between( ) -> BaseNode: """Contract all of the edges between the two given nodes. + If `output_edge_order` is not set, the output axes will be ordered as: + [...free axes of `node1`..., ...free axes of `node2`...]. Within the axes + of each node, the input order is preserved. + Args: node1: The first node. node2: The second node. @@ -1764,7 +1768,8 @@ def contract_between( contain all edges belonging to, but not shared by `node1` and `node2`. The axes of the new node will be permuted (if necessary) to match this ordering of Edges. - axis_names: An optional list of names for the axis of the new node + axis_names: An optional list of names for the axis of the new node in order + of the output axes. Returns: The new node created. @@ -1784,64 +1789,68 @@ def contract_between( node2.backend.name)) backend = node1.backend + shared_edges = get_shared_edges(node1, node2) # Trace edges cannot be contracted using tensordot. if node1 is node2: flat_edge = flatten_edges_between(node1, node2) if not flat_edge: raise ValueError("No trace edges found on contraction of edges between " "node '{}' and itself.".format(node1)) - return contract(flat_edge, name) - - shared_edges = get_shared_edges(node1, node2) - if not shared_edges: - if allow_outer_product: - return outer_product(node1, node2, name=name, axis_names=axis_names) - raise ValueError("No edges found between nodes '{}' and '{}' " - "and allow_outer_product=False.".format(node1, node2)) - - # Collect the axis of each node corresponding to each edge, in order. - # This specifies the contraction for tensordot. - # NOTE: The ordering of node references in each contraction edge is ignored. - axes1 = [] - axes2 = [] - for edge in shared_edges: - if edge.node1 is node1: - axes1.append(edge.axis1) - axes2.append(edge.axis2) - else: - axes1.append(edge.axis2) - axes2.append(edge.axis1) - - if output_edge_order: - # Determine heuristically if output transposition can be minimized by - # flipping the arguments to tensordot. - node1_output_axes = [] - node2_output_axes = [] - for (i, edge) in enumerate(output_edge_order): - if edge in shared_edges: - raise ValueError( - "Edge '{}' in output_edge_order is shared by the nodes to be " - "contracted: '{}' and '{}'.".format(edge, node1, node2)) - edge_nodes = set(edge.get_nodes()) - if node1 in edge_nodes: - node1_output_axes.append(i) - elif node2 in edge_nodes: - node2_output_axes.append(i) + new_node = contract(flat_edge, name) + elif not shared_edges: + if not allow_outer_product: + raise ValueError("No edges found between nodes '{}' and '{}' " + "and allow_outer_product=False.".format(node1, node2)) + new_node = outer_product(node1, node2, name=name) + else: + # Collect the axis of each node corresponding to each edge, in order. + # This specifies the contraction for tensordot. + # NOTE: The ordering of node references in each contraction edge is ignored. + axes1 = [] + axes2 = [] + for edge in shared_edges: + if edge.node1 is node1: + axes1.append(edge.axis1) + axes2.append(edge.axis2) else: - raise ValueError( - "Edge '{}' in output_edge_order is not connected to node '{}' or " - "node '{}'".format(edge, node1, node2)) - if np.mean(node1_output_axes) > np.mean(node2_output_axes): - node1, node2 = node2, node1 - axes1, axes2 = axes2, axes1 - - new_tensor = backend.tensordot(node1.tensor, node2.tensor, [axes1, axes2]) - new_node = Node( - tensor=new_tensor, name=name, axis_names=axis_names, backend=backend) - # node1 and node2 get new edges in _remove_edges - _remove_edges(shared_edges, node1, node2, new_node) + axes1.append(edge.axis2) + axes2.append(edge.axis1) + + if output_edge_order: + # Determine heuristically if output transposition can be minimized by + # flipping the arguments to tensordot. + node1_output_axes = [] + node2_output_axes = [] + for (i, edge) in enumerate(output_edge_order): + if edge in shared_edges: + raise ValueError( + "Edge '{}' in output_edge_order is shared by the nodes to be " + "contracted: '{}' and '{}'.".format(edge, node1, node2)) + edge_nodes = set(edge.get_nodes()) + if node1 in edge_nodes: + node1_output_axes.append(i) + elif node2 in edge_nodes: + node2_output_axes.append(i) + else: + raise ValueError( + "Edge '{}' in output_edge_order is not connected to node '{}' or " + "node '{}'".format(edge, node1, node2)) + if node1_output_axes and node2_output_axes and ( + np.mean(node1_output_axes) > np.mean(node2_output_axes)): + node1, node2 = node2, node1 + axes1, axes2 = axes2, axes1 + + new_tensor = backend.tensordot(node1.tensor, node2.tensor, [axes1, axes2]) + new_node = Node( + tensor=new_tensor, name=name, backend=backend) + # node1 and node2 get new edges in _remove_edges + _remove_edges(shared_edges, node1, node2, new_node) + if output_edge_order: new_node = new_node.reorder_edges(list(output_edge_order)) + if axis_names: + new_node.add_axis_names(axis_names) + return new_node diff --git a/tensornetwork/tests/network_test.py b/tensornetwork/tests/network_test.py index 8efcd9d4d..f4e83a892 100644 --- a/tensornetwork/tests/network_test.py +++ b/tensornetwork/tests/network_test.py @@ -483,26 +483,45 @@ def test_flatten_all_edges(backend): def test_contract_between(backend): - a_val = np.ones((2, 3, 4, 5)) - b_val = np.ones((3, 5, 4, 2)) + a_val = np.random.rand(2, 3, 4, 5) + b_val = np.random.rand(3, 5, 6, 2) a = tn.Node(a_val, backend=backend) b = tn.Node(b_val, backend=backend) tn.connect(a[0], b[3]) tn.connect(b[1], a[3]) tn.connect(a[1], b[0]) - edge_a = a[2] - edge_b = b[2] - c = tn.contract_between(a, b, name="New Node") - c.reorder_edges([edge_a, edge_b]) + output_axis_names = ["a2", "b2"] + c = tn.contract_between(a, b, name="New Node", axis_names=output_axis_names) tn.check_correct({c}) # Check expected values. a_flat = np.reshape(np.transpose(a_val, (2, 1, 0, 3)), (4, 30)) - b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (4, 30)) + b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (6, 30)) final_val = np.matmul(a_flat, b_flat.T) assert c.name == "New Node" + assert c.axis_names == output_axis_names np.testing.assert_allclose(c.tensor, final_val) +def test_contract_between_output_edge_order(backend): + a_val = np.random.rand(2, 3, 4, 5) + b_val = np.random.rand(3, 5, 6, 2) + a = tn.Node(a_val, backend=backend) + b = tn.Node(b_val, backend=backend) + tn.connect(a[0], b[3]) + tn.connect(b[1], a[3]) + tn.connect(a[1], b[0]) + output_axis_names = ["b2", "a2"] + c = tn.contract_between(a, b, name="New Node", axis_names=output_axis_names, + output_edge_order=[b[2], a[2]]) + # Check expected values. + a_flat = np.reshape(np.transpose(a_val, (2, 1, 0, 3)), (4, 30)) + b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (6, 30)) + final_val = np.matmul(a_flat, b_flat.T) + assert c.name == "New Node" + assert c.axis_names == output_axis_names + np.testing.assert_allclose(c.tensor, final_val.T) + + def test_contract_between_no_outer_product_value_error(backend): a_val = np.ones((2, 3, 4)) b_val = np.ones((5, 6, 7)) @@ -517,8 +536,45 @@ def test_contract_between_outer_product_no_value_error(backend): b_val = np.ones((5, 6, 7)) a = tn.Node(a_val, backend=backend) b = tn.Node(b_val, backend=backend) - c = tn.contract_between(a, b, allow_outer_product=True) + output_axis_names = ["a0", "a1", "a2", "b0", "b1", "b2"] + c = tn.contract_between(a, b, allow_outer_product=True, + axis_names=output_axis_names) assert c.shape == (2, 3, 4, 5, 6, 7) + assert c.axis_names == output_axis_names + + +def test_contract_between_outer_product_output_edge_order(backend): + a_val = np.ones((2, 3, 4)) + b_val = np.ones((5, 6, 7)) + a = tn.Node(a_val, backend=backend) + b = tn.Node(b_val, backend=backend) + output_axis_names = ["b0", "b1", "a0", "b2", "a1", "a2"] + c = tn.contract_between( + a, b, + allow_outer_product=True, + output_edge_order=[b[0], b[1], a[0], b[2], a[1], a[2]], + axis_names=output_axis_names) + assert c.shape == (5, 6, 2, 7, 3, 4) + assert c.axis_names == output_axis_names + + +def test_contract_between_trace(backend): + a_val = np.ones((2, 3, 2, 4)) + a = tn.Node(a_val, backend=backend) + tn.connect(a[0], a[2]) + c = tn.contract_between(a, a, axis_names=["1", "3"]) + assert c.shape == (3, 4) + assert c.axis_names == ["1", "3"] + + +def test_contract_between_trace_output_edge_order(backend): + a_val = np.ones((2, 3, 2, 4)) + a = tn.Node(a_val, backend=backend) + tn.connect(a[0], a[2]) + c = tn.contract_between(a, a, output_edge_order=[a[3], a[1]], + axis_names=["3", "1"]) + assert c.shape == (4, 3) + assert c.axis_names == ["3", "1"] def test_contract_parallel(backend):