diff --git a/tensornetwork/backends/base_backend.py b/tensornetwork/backends/base_backend.py
index 1d9d249b7..0368e6d6b 100644
--- a/tensornetwork/backends/base_backend.py
+++ b/tensornetwork/backends/base_backend.py
@@ -135,12 +135,12 @@ def rq_decomposition(
     raise NotImplementedError(
         "Backend '{}' has not implemented rq_decomposition.".format(self.name))
 
-  def concat(self, values: Sequence[Tensor], axis) -> Tensor:
+  def shape_concat(self, values: Sequence[Tensor], axis) -> Tensor:
     """Concatenate a sequence of tensors together about the given axis."""
     raise NotImplementedError("Backend '{}' has not implemented concat.".format(
         self.name))
 
-  def shape(self, tensor: Tensor) -> Tensor:
+  def shape_tensor(self, tensor: Tensor) -> Tensor:
     """Get the shape of a tensor.
 
     Args:
@@ -163,7 +163,7 @@ def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
     raise NotImplementedError(
         "Backend '{}' has not implemented shape_tuple.".format(self.name))
 
-  def prod(self, values: Tensor) -> Tensor:
+  def shape_prod(self, values: Tensor) -> Tensor:
     """Take the product of all of the elements in values"""
     raise NotImplementedError("Backend '{}' has not implemented prod.".format(
         self.name))
diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py
index 9773f6026..a3064912f 100644
--- a/tensornetwork/backends/jax/jax_backend.py
+++ b/tensornetwork/backends/jax/jax_backend.py
@@ -39,7 +39,7 @@ def convert_to_tensor(self, tensor: Tensor) -> Tensor:
     result = self.jax.jit(lambda x: x)(tensor)
     return result
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
+  def shape_concat(self, values: Tensor, axis: int) -> Tensor:
     return np.concatenate(values, axis)
 
   def randn(self,
diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py
index 08b21059d..330dbe224 100644
--- a/tensornetwork/backends/jax/jax_backend_test.py
+++ b/tensornetwork/backends/jax/jax_backend_test.py
@@ -35,20 +35,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = jax_backend.JaxBackend()
   a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
+  expected = backend.shape_concat((a, b), axis=1)
   actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
-def test_shape():
+def test_shape_tensor():
   backend = jax_backend.JaxBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
-  assert isinstance(backend.shape(a), tuple)
-  actual = backend.shape(a)
+  assert isinstance(backend.shape_tensor(a), tuple)
+  actual = backend.shape_tensor(a)
   expected = np.array([2, 3, 4])
   np.testing.assert_allclose(expected, actual)
 
@@ -60,10 +60,10 @@ def test_shape_tuple():
   assert actual == (2, 3, 4)
 
 
-def test_prod():
+def test_shape_prod():
   backend = jax_backend.JaxBackend()
   a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
-  actual = np.array(backend.prod(a))
+  actual = np.array(backend.shape_prod(a))
   assert actual == 2**24
 
 
diff --git a/tensornetwork/backends/numpy/numpy_backend.py b/tensornetwork/backends/numpy/numpy_backend.py
index e7c4e8ecd..41a0061c7 100644
--- a/tensornetwork/backends/numpy/numpy_backend.py
+++ b/tensornetwork/backends/numpy/numpy_backend.py
@@ -61,16 +61,16 @@ def rq_decomposition(
   ) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.np, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
+  def shape_concat(self, values: Tensor, axis: int) -> Tensor:
     return self.np.concatenate(values, axis)
 
-  def shape(self, tensor: Tensor) -> Tensor:
+  def shape_tensor(self, tensor: Tensor) -> Tensor:
     return tensor.shape
 
   def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
     return tensor.shape
 
-  def prod(self, values: Tensor) -> Tensor:
+  def shape_prod(self, values: Tensor) -> Tensor:
     return self.np.prod(values)
 
   def sqrt(self, tensor: Tensor) -> Tensor:
diff --git a/tensornetwork/backends/numpy/numpy_backend_test.py b/tensornetwork/backends/numpy/numpy_backend_test.py
index 49645d876..e8688ce62 100644
--- a/tensornetwork/backends/numpy/numpy_backend_test.py
+++ b/tensornetwork/backends/numpy/numpy_backend_test.py
@@ -33,20 +33,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = numpy_backend.NumPyBackend()
   a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
+  expected = backend.shape_concat((a, b), axis=1)
   actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
-def test_shape():
+def test_shape_tensor():
   backend = numpy_backend.NumPyBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
-  assert isinstance(backend.shape(a), tuple)
-  actual = backend.shape(a)
+  assert isinstance(backend.shape_tensor(a), tuple)
+  actual = backend.shape_tensor(a)
   expected = np.array([2, 3, 4])
   np.testing.assert_allclose(expected, actual)
 
@@ -58,10 +58,10 @@ def test_shape_tuple():
   assert actual == (2, 3, 4)
 
 
-def test_prod():
+def test_shape_prod():
   backend = numpy_backend.NumPyBackend()
   a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
-  actual = np.array(backend.prod(a))
+  actual = np.array(backend.shape_prod(a))
   assert actual == 2**24
 
 
diff --git a/tensornetwork/backends/pytorch/pytorch_backend.py b/tensornetwork/backends/pytorch/pytorch_backend.py
index 0caba598a..b8e9a1ed7 100644
--- a/tensornetwork/backends/pytorch/pytorch_backend.py
+++ b/tensornetwork/backends/pytorch/pytorch_backend.py
@@ -69,16 +69,16 @@ def rq_decomposition(
   ) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.torch, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
+  def shape_concat(self, values: Tensor, axis: int) -> Tensor:
     return np.concatenate(values, axis)
 
-  def shape(self, tensor: Tensor) -> Tensor:
+  def shape_tensor(self, tensor: Tensor) -> Tensor:
     return self.torch.tensor(list(tensor.shape))
 
   def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
     return tuple(tensor.shape)
 
-  def prod(self, values: Tensor) -> int:
+  def shape_prod(self, values: Tensor) -> int:
     return np.prod(np.array(values))
 
   def sqrt(self, tensor: Tensor) -> Tensor:
diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py
index ca0cd92f3..e55d71b4b 100644
--- a/tensornetwork/backends/pytorch/pytorch_backend_test.py
+++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py
@@ -34,20 +34,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = pytorch_backend.PyTorchBackend()
   a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
+  expected = backend.shape_concat((a, b), axis=1)
   actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
-def test_shape():
+def test_shape_tensor():
   backend = pytorch_backend.PyTorchBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
-  assert isinstance(backend.shape(a), torch.Tensor)
-  actual = backend.shape(a)
+  assert isinstance(backend.shape_tensor(a), torch.Tensor)
+  actual = backend.shape_tensor(a)
   expected = np.array([2, 3, 4])
   np.testing.assert_allclose(expected, actual)
 
@@ -59,10 +59,10 @@ def test_shape_tuple():
   assert actual == (2, 3, 4)
 
 
-def test_prod():
+def test_shape_prod():
   backend = pytorch_backend.PyTorchBackend()
   a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
-  actual = np.array(backend.prod(a))
+  actual = np.array(backend.shape_prod(a))
   assert actual == 2**24
 
 
diff --git a/tensornetwork/backends/shell/shell_backend.py b/tensornetwork/backends/shell/shell_backend.py
index 3365fae5e..33b30a99c 100644
--- a/tensornetwork/backends/shell/shell_backend.py
+++ b/tensornetwork/backends/shell/shell_backend.py
@@ -107,7 +107,7 @@ def rq_decomposition(self, tensor: Tensor,
     r = ShellTensor((center_dim,) + right_dims)
     return q, r
 
-  def concat(self, values: Sequence[Tensor], axis: int) -> Tensor:
+  def shape_concat(self, values: Sequence[Tensor], axis: int) -> Tensor:
     shape = values[0].shape
     if axis < 0:
       axis += len(shape)
@@ -119,20 +119,20 @@ def concat_shape(self, values) -> Sequence:
     tuple_values = (tuple(v) for v in values)
     return functools.reduce(operator.concat, tuple_values)
 
-  def shape(self, tensor: Tensor) -> Tuple:
+  def shape_tensor(self, tensor: Tensor) -> Tuple:
     return tensor.shape
 
   def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
     return tensor.shape
 
-  def prod(self, values: Tensor) -> int:
+  def shape_prod(self, values: Tensor) -> int:
     # This is different from the BaseBackend prod!
     # prod calculates the product of tensor elements and cannot implemented
     # for shell tensors
     # This returns the product of sizes instead
-    return self.shape_prod(values.shape)
+    return self.shape_product(values.shape)
 
-  def shape_prod(self, shape: Sequence[int]) -> int:
+  def shape_product(self, shape: Sequence[int]) -> int:
     return functools.reduce(operator.mul, shape)
 
   def sqrt(self, tensor: Tensor) -> Tensor:
diff --git a/tensornetwork/backends/shell/shell_backend_test.py b/tensornetwork/backends/shell/shell_backend_test.py
index 3974dc1f7..af17c3354 100644
--- a/tensornetwork/backends/shell/shell_backend_test.py
+++ b/tensornetwork/backends/shell/shell_backend_test.py
@@ -62,16 +62,16 @@ def test_svd_decomposition_with_max_values():
     assert x.shape == y.shape
 
 
-def test_concat():
+def test_shape_concat():
   args = {
       "values": [np.ones([3, 2, 5]),
                  np.zeros([3, 2, 5]),
                  np.ones([3, 3, 5])]
   }
   args["axis"] = 1
-  assertBackendsAgree("concat", args)
+  assertBackendsAgree("shape_concat", args)
   args["axis"] = -2
-  assertBackendsAgree("concat", args)
+  assertBackendsAgree("shape_concat", args)
 
 
 def test_concat_shape():
@@ -80,10 +80,10 @@ def test_concat_shape():
   assert result == (5, 2, 3, 4, 6)
 
 
-def test_shape():
+def test_shape_tensor():
   tensor = np.ones([3, 5, 2])
-  np_result = numpy_backend.NumPyBackend().shape(tensor)
-  sh_result = shell_backend.ShellBackend().shape(tensor)
+  np_result = numpy_backend.NumPyBackend().shape_tensor(tensor)
+  sh_result = shell_backend.ShellBackend().shape_tensor(tensor)
   assert np_result == sh_result
 
 
@@ -94,8 +94,8 @@ def test_shape_tuple():
   assert np_result == sh_result
 
 
-def test_prod():
-  result = shell_backend.ShellBackend().prod(np.ones([3, 5, 2]))
+def test_shape_prod():
+  result = shell_backend.ShellBackend().shape_prod(np.ones([3, 5, 2]))
   assert result == 30
 
 
diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend.py b/tensornetwork/backends/tensorflow/tensorflow_backend.py
index 5f7cd1201..c87d464c5 100644
--- a/tensornetwork/backends/tensorflow/tensorflow_backend.py
+++ b/tensornetwork/backends/tensorflow/tensorflow_backend.py
@@ -64,16 +64,16 @@ def rq_decomposition(self, tensor: Tensor,
                        split_axis: int) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.tf, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
+  def shape_concat(self, values: Tensor, axis: int) -> Tensor:
     return self.tf.concat(values, axis)
 
-  def shape(self, tensor: Tensor) -> Tensor:
+  def shape_tensor(self, tensor: Tensor) -> Tensor:
     return self.tf.shape(tensor)
 
   def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
     return tuple(tensor.shape.as_list())
 
-  def prod(self, values: Tensor) -> Tensor:
+  def shape_prod(self, values: Tensor) -> Tensor:
     return self.tf.reduce_prod(values)
 
   def sqrt(self, tensor: Tensor) -> Tensor:
diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
index 25110d66c..838771598 100644
--- a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
+++ b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
@@ -34,20 +34,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = tensorflow_backend.TensorFlowBackend()
   a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
+  expected = backend.shape_concat((a, b), axis=1)
   actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
-def test_shape():
+def test_shape_tensor():
   backend = tensorflow_backend.TensorFlowBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
-  assert isinstance(backend.shape(a), type(a))
-  actual = backend.shape(a)
+  assert isinstance(backend.shape_tensor(a), type(a))
+  actual = backend.shape_tensor(a)
   expected = np.array([2, 3, 4])
   np.testing.assert_allclose(expected, actual)
 
@@ -59,10 +59,10 @@ def test_shape_tuple():
   assert actual == (2, 3, 4)
 
 
-def test_prod():
+def test_shape_prod():
   backend = tensorflow_backend.TensorFlowBackend()
   a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
-  actual = np.array(backend.prod(a))
+  actual = np.array(backend.shape_prod(a))
   assert actual == 2**24
 
 
diff --git a/tensornetwork/network_components.py b/tensornetwork/network_components.py
index b6569872f..1f6de8917 100644
--- a/tensornetwork/network_components.py
+++ b/tensornetwork/network_components.py
@@ -1196,10 +1196,12 @@ def _flatten_trace_edges(edges: List[Edge],
   perm_front = set(range(len(node.edges))) - set(perm_back)
   perm_front = sorted(perm_front)
   perm = perm_front + perm_back
-  new_dim = backend.prod([backend.shape(node.tensor)[e.axis1] for e in edges])
+  new_dim = backend.shape_prod(
+      [backend.shape_tensor(node.tensor)[e.axis1] for e in edges])
   node.reorder_axes(perm)
-  unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, [new_dim, new_dim]], axis=-1)
+  unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
+  new_shape = backend.shape_concat(
+      [unaffected_shape, [new_dim, new_dim]], axis=-1)
   node.tensor = backend.reshape(node.tensor, new_shape)
   edge1 = Edge(node1=node, axis1=len(perm_front), name="TraceFront")
   edge2 = Edge(node1=node, axis1=len(perm_front) + 1, name="TraceBack")
@@ -1271,11 +1273,11 @@ def flatten_edges(edges: List[Edge],
     perm_back.append(node.edges.index(edge))
   perm_front = sorted(set(range(len(node.edges))) - set(perm_back))
   node.reorder_axes(perm_front + perm_back)
-  old_tensor_shape = backend.shape(node.tensor)
+  old_tensor_shape = backend.shape_tensor(node.tensor)
   # Calculate the new axis dimension as a product of the other
   # axes dimensions.
-  flattened_axis_dim = backend.prod(old_tensor_shape[len(perm_front):])
-  new_tensor_shape = backend.concat(
+  flattened_axis_dim = backend.shape_prod(old_tensor_shape[len(perm_front):])
+  new_tensor_shape = backend.shape_concat(
       [old_tensor_shape[:len(perm_front)], [flattened_axis_dim]], axis=-1)
   new_tensor = backend.reshape(node.tensor, new_tensor_shape)
   # Modify the node in place. Currently, this is they only method that
@@ -1363,8 +1365,8 @@ def _split_trace_edge(
   perm_front = set(range(len(node.edges))) - set(perm_back)
   perm_front = sorted(perm_front)
   node.reorder_axes(perm_front + perm_back)
-  unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, shape, shape], axis=-1)
+  unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
+  new_shape = backend.shape_concat([unaffected_shape, shape, shape], axis=-1)
   node.tensor = backend.reshape(node.tensor, new_shape)
   # Trim edges and add placeholder edges for new axes.
   node.edges = node.edges[:len(perm_front)] + 2 * len(shape) * [None]
@@ -1438,8 +1440,8 @@ def split_edge(edge: Edge,
   perm_front = set(range(len(node.edges))) - set(perm_back)
   perm_front = sorted(perm_front)
   node.reorder_axes(perm_front + perm_back)
-  unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, shape], axis=-1)
+  unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
+  new_shape = backend.shape_concat([unaffected_shape, shape], axis=-1)
   node.tensor = backend.reshape(node.tensor, new_shape)  # in-place update
   # Trim edges.
   node.edges = node.edges[:len(perm_front)]
diff --git a/tensornetwork/network_operations.py b/tensornetwork/network_operations.py
index fa718a430..5aea214db 100644
--- a/tensornetwork/network_operations.py
+++ b/tensornetwork/network_operations.py
@@ -295,8 +295,8 @@ def split_node(
   # the first axis of vh. If we don't, it's possible one of the other axes of
   # vh will be the same size as sqrt_s and would multiply across that axis
   # instead, which is bad.
-  sqrt_s_broadcast_shape = backend.concat(
-      [backend.shape(sqrt_s), [1] * (len(vh.shape) - 1)], axis=-1)
+  sqrt_s_broadcast_shape = backend.shape_concat(
+      [backend.shape_tensor(sqrt_s), [1] * (len(vh.shape) - 1)], axis=-1)
   vh_s = vh * backend.reshape(sqrt_s, sqrt_s_broadcast_shape)
   left_node = Node(
       u_s, name=left_name, axis_names=left_axis_names, backend=backend)