diff --git a/tensornetwork/backends/base_backend.py b/tensornetwork/backends/base_backend.py
index 1d9d249b7..2d58e2e84 100644
--- a/tensornetwork/backends/base_backend.py
+++ b/tensornetwork/backends/base_backend.py
@@ -135,7 +135,13 @@ def rq_decomposition(
     raise NotImplementedError(
         "Backend '{}' has not implemented rq_decomposition.".format(self.name))
 
-  def concat(self, values: Sequence[Tensor], axis) -> Tensor:
+  def shape_concat(self, values: Sequence[Tensor]) -> Tensor:
+    """Concatenate a sequence of tensors together along the last axis;
+      intended only for use in shape calculations."""
+    raise NotImplementedError("Backend '{}' has not implemented shape_concat.".format(
+        self.name))
+
+  def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor:
     """Concatenate a sequence of tensors together about the given axis."""
     raise NotImplementedError("Backend '{}' has not implemented concat.".format(
         self.name))
diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py
index 9773f6026..90c750aff 100644
--- a/tensornetwork/backends/jax/jax_backend.py
+++ b/tensornetwork/backends/jax/jax_backend.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional, Tuple, Callable, List, Text, Type
+from typing import Any, Sequence, Optional, Tuple, Callable, List, Text, Type
 from tensornetwork.backends.numpy import numpy_backend
 import numpy as np
 
@@ -39,8 +39,11 @@ def convert_to_tensor(self, tensor: Tensor) -> Tensor:
     result = self.jax.jit(lambda x: x)(tensor)
     return result
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
-    return np.concatenate(values, axis)
+  def shape_concat(self, values: Tensor) -> Tensor:
+    return np.concatenate(values, -1)
+
+  def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor:
+    return np.stack(values, axis)
 
   def randn(self,
             shape: Tuple[int, ...],
diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py
index 08b21059d..76bc934fa 100644
--- a/tensornetwork/backends/jax/jax_backend_test.py
+++ b/tensornetwork/backends/jax/jax_backend_test.py
@@ -35,12 +35,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = jax_backend.JaxBackend()
-  a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
+  a = backend.convert_to_tensor(2 * np.ones((1, 2, 2)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
-  actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
+  expected = backend.shape_concat((a, b))
+  actual = np.array([[[2.0, 2.0, 1.0], [2.0, 2.0, 1.0]]])
+  np.testing.assert_allclose(expected, actual)
+
+
+def test_concat():
+  backend = jax_backend.JaxBackend()
+  scalars = [backend.convert_to_tensor(1.0), backend.convert_to_tensor(2.0)]
+  actual = backend.concat(scalars, 0)
+  expected = np.array([1.0, 2.0])
   np.testing.assert_allclose(expected, actual)
 
 
diff --git a/tensornetwork/backends/numpy/numpy_backend.py b/tensornetwork/backends/numpy/numpy_backend.py
index e7c4e8ecd..0f59e25ca 100644
--- a/tensornetwork/backends/numpy/numpy_backend.py
+++ b/tensornetwork/backends/numpy/numpy_backend.py
@@ -61,8 +61,11 @@ def rq_decomposition(
   ) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.np, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
-    return self.np.concatenate(values, axis)
+  def shape_concat(self, values: Tensor) -> Tensor:
+    return self.np.concatenate(values, -1)
+
+  def concat(self, values: Tensor, axis: int = 0) -> Tensor:
+    return self.np.stack(values, axis)
 
   def shape(self, tensor: Tensor) -> Tensor:
     return tensor.shape
diff --git a/tensornetwork/backends/numpy/numpy_backend_test.py b/tensornetwork/backends/numpy/numpy_backend_test.py
index 49645d876..c537116f3 100644
--- a/tensornetwork/backends/numpy/numpy_backend_test.py
+++ b/tensornetwork/backends/numpy/numpy_backend_test.py
@@ -33,12 +33,20 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = numpy_backend.NumPyBackend()
-  a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
+  a = backend.convert_to_tensor(2 * np.ones((1, 2, 2)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
-  actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
+  expected = backend.shape_concat((a, b))
+  actual = np.array([[[2.0, 2.0, 1.0], [2.0, 2.0, 1.0]]])
+  np.testing.assert_allclose(expected, actual)
+
+
+def test_concat():
+  backend = numpy_backend.NumPyBackend()
+  scalars = [backend.convert_to_tensor(1.0), backend.convert_to_tensor(2.0)]
+  actual = backend.concat(scalars, 0)
+  expected = np.array([1.0, 2.0])
   np.testing.assert_allclose(expected, actual)
 
 
diff --git a/tensornetwork/backends/pytorch/pytorch_backend.py b/tensornetwork/backends/pytorch/pytorch_backend.py
index 0caba598a..7bc325e37 100644
--- a/tensornetwork/backends/pytorch/pytorch_backend.py
+++ b/tensornetwork/backends/pytorch/pytorch_backend.py
@@ -69,8 +69,11 @@ def rq_decomposition(
   ) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.torch, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
-    return np.concatenate(values, axis)
+  def shape_concat(self, values: Tensor) -> Tensor:
+    return np.concatenate(values, -1)
+
+  def concat(self, values: Tensor, axis: int = 0) -> Tensor:
+    return self.torch.stack(values, axis)
 
   def shape(self, tensor: Tensor) -> Tensor:
     return self.torch.tensor(list(tensor.shape))
diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py
index ca0cd92f3..3c8aef2ca 100644
--- a/tensornetwork/backends/pytorch/pytorch_backend_test.py
+++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py
@@ -34,15 +34,23 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
   backend = pytorch_backend.PyTorchBackend()
-  a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
+  a = backend.convert_to_tensor(2 * np.ones((1, 2, 2)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
-  actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
+  expected = backend.shape_concat((a, b))
+  actual = np.array([[[2.0, 2.0, 1.0], [2.0, 2.0, 1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
+def test_concat():
+  backend = pytorch_backend.PyTorchBackend()
+  scalars = [backend.convert_to_tensor(1.0), backend.convert_to_tensor(2.0)]
+  actual = backend.concat(scalars, 0)
+  expected = torch.Tensor([1.0, 2.0])
+  assert torch.equal(expected, actual)
+
+
 def test_shape():
   backend = pytorch_backend.PyTorchBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
diff --git a/tensornetwork/backends/shell/shell_backend.py b/tensornetwork/backends/shell/shell_backend.py
index 3365fae5e..87de76177 100644
--- a/tensornetwork/backends/shell/shell_backend.py
+++ b/tensornetwork/backends/shell/shell_backend.py
@@ -107,14 +107,17 @@ def rq_decomposition(self, tensor: Tensor,
     r = ShellTensor((center_dim,) + right_dims)
     return q, r
 
-  def concat(self, values: Sequence[Tensor], axis: int) -> Tensor:
+  def shape_concat(self, values: Sequence[Tensor]) -> Tensor:
     shape = values[0].shape
-    if axis < 0:
-      axis += len(shape)
+    axis = len(shape) - 1
     concat_size = sum(v.shape[axis] for v in values)
     new_shape = shape[:axis] + (concat_size,) + shape[axis + 1:]
     return ShellTensor(new_shape)
 
+  def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor:
+    raise NotImplementedError("Backend '{}' has not implemented concat.".format(
+        self.name))
+
   def concat_shape(self, values) -> Sequence:
     tuple_values = (tuple(v) for v in values)
     return functools.reduce(operator.concat, tuple_values)
diff --git a/tensornetwork/backends/shell/shell_backend_test.py b/tensornetwork/backends/shell/shell_backend_test.py
index 3974dc1f7..5ffeb9c05 100644
--- a/tensornetwork/backends/shell/shell_backend_test.py
+++ b/tensornetwork/backends/shell/shell_backend_test.py
@@ -62,16 +62,13 @@ def test_svd_decomposition_with_max_values():
   assert x.shape == y.shape
 
 
-def test_concat():
+def test_shape_concat():
   args = {
       "values": [np.ones([3, 2, 5]),
-                 np.zeros([3, 2, 5]),
-                 np.ones([3, 3, 5])]
+                 np.zeros([3, 2, 6]),
+                 np.ones([3, 2, 4])]
   }
-  args["axis"] = 1
-  assertBackendsAgree("concat", args)
-  args["axis"] = -2
-  assertBackendsAgree("concat", args)
+  assertBackendsAgree("shape_concat", args)
 
 
 def test_concat_shape():
diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend.py b/tensornetwork/backends/tensorflow/tensorflow_backend.py
index 5f7cd1201..827b4c16c 100644
--- a/tensornetwork/backends/tensorflow/tensorflow_backend.py
+++ b/tensornetwork/backends/tensorflow/tensorflow_backend.py
@@ -64,8 +64,11 @@ def rq_decomposition(self, tensor: Tensor,
                        split_axis: int) -> Tuple[Tensor, Tensor]:
     return decompositions.rq_decomposition(self.tf, tensor, split_axis)
 
-  def concat(self, values: Tensor, axis: int) -> Tensor:
-    return self.tf.concat(values, axis)
+  def shape_concat(self, values: Tensor) -> Tensor:
+    return self.tf.concat(values, -1)
+
+  def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor:
+    return self.tf.stack(values, axis)
 
   def shape(self, tensor: Tensor) -> Tensor:
     return self.tf.shape(tensor)
diff --git a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
index 25110d66c..f33653085 100644
--- a/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
+++ b/tensornetwork/backends/tensorflow/tensorflow_backend_test.py
@@ -34,15 +34,23 @@ def test_transpose():
   np.testing.assert_allclose(expected, actual)
 
 
-def test_concat():
+def test_shape_concat():
  backend = tensorflow_backend.TensorFlowBackend()
-  a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
+  a = backend.convert_to_tensor(2 * np.ones((1, 2, 2)))
   b = backend.convert_to_tensor(np.ones((1, 2, 1)))
-  expected = backend.concat((a, b), axis=1)
-  actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
+  expected = backend.shape_concat((a, b))
+  actual = np.array([[[2.0, 2.0, 1.0], [2.0, 2.0, 1.0]]])
   np.testing.assert_allclose(expected, actual)
 
 
+def test_concat():
+  backend = tensorflow_backend.TensorFlowBackend()
+  scalars = [backend.convert_to_tensor(1.0), backend.convert_to_tensor(2.0)]
+  actual = backend.concat(scalars, 0)
+  expected = tf.Variable([1.0, 2.0])
+  assert tf.reduce_all(tf.math.equal(expected, actual))
+
+
 def test_shape():
   backend = tensorflow_backend.TensorFlowBackend()
   a = backend.convert_to_tensor(np.ones([2, 3, 4]))
diff --git a/tensornetwork/network_components.py b/tensornetwork/network_components.py
index b6569872f..65b6acbc3 100644
--- a/tensornetwork/network_components.py
+++ b/tensornetwork/network_components.py
@@ -1199,7 +1199,8 @@ def _flatten_trace_edges(edges: List[Edge],
   new_dim = backend.prod([backend.shape(node.tensor)[e.axis1] for e in edges])
   node.reorder_axes(perm)
   unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, [new_dim, new_dim]], axis=-1)
+  old_shape = [unaffected_shape, [new_dim, new_dim]]
+  new_shape = backend.shape_concat(old_shape)
   node.tensor = backend.reshape(node.tensor, new_shape)
   edge1 = Edge(node1=node, axis1=len(perm_front), name="TraceFront")
   edge2 = Edge(node1=node, axis1=len(perm_front) + 1, name="TraceBack")
@@ -1275,8 +1276,8 @@ def flatten_edges(edges: List[Edge],
   # Calculate the new axis dimension as a product of the other
   # axes dimensions.
   flattened_axis_dim = backend.prod(old_tensor_shape[len(perm_front):])
-  new_tensor_shape = backend.concat(
-      [old_tensor_shape[:len(perm_front)], [flattened_axis_dim]], axis=-1)
+  new_tensor_shape = backend.shape_concat(
+      [old_tensor_shape[:len(perm_front)], [flattened_axis_dim]])
   new_tensor = backend.reshape(node.tensor, new_tensor_shape)
   # Modify the node in place. Currently, this is the only method that
   # modifies a node's tensor.
@@ -1364,7 +1365,7 @@ def _split_trace_edge(
   perm_front = sorted(perm_front)
   node.reorder_axes(perm_front + perm_back)
   unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, shape, shape], axis=-1)
+  new_shape = backend.shape_concat([unaffected_shape, shape, shape])
   node.tensor = backend.reshape(node.tensor, new_shape)
   # Trim edges and add placeholder edges for new axes.
   node.edges = node.edges[:len(perm_front)] + 2 * len(shape) * [None]
@@ -1439,7 +1440,7 @@ def split_edge(edge: Edge,
   perm_front = sorted(perm_front)
   node.reorder_axes(perm_front + perm_back)
   unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
-  new_shape = backend.concat([unaffected_shape, shape], axis=-1)
+  new_shape = backend.shape_concat([unaffected_shape, shape])
   node.tensor = backend.reshape(node.tensor, new_shape)  # in-place update
   # Trim edges.
   node.edges = node.edges[:len(perm_front)]
diff --git a/tensornetwork/network_operations.py b/tensornetwork/network_operations.py
index fa718a430..5ee5bc512 100644
--- a/tensornetwork/network_operations.py
+++ b/tensornetwork/network_operations.py
@@ -295,8 +295,8 @@ def split_node(
   # the first axis of vh. If we don't, it's possible one of the other axes of
   # vh will be the same size as sqrt_s and would multiply across that axis
   # instead, which is bad.
-  sqrt_s_broadcast_shape = backend.concat(
-      [backend.shape(sqrt_s), [1] * (len(vh.shape) - 1)], axis=-1)
+  sqrt_s_broadcast_shape = backend.shape_concat(
+      [backend.shape(sqrt_s), [1] * (len(vh.shape) - 1)])
   vh_s = vh * backend.reshape(sqrt_s, sqrt_s_broadcast_shape)
   left_node = Node(
       u_s, name=left_name, axis_names=left_axis_names, backend=backend)
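
Usage sketch (editor's illustration, not part of the patch): the split gives each backend two operations with distinct jobs. shape_concat joins shape fragments along the last axis, which is all the reshape bookkeeping in network_components.py and network_operations.py needs, while the new concat stacks whole tensors along a given axis (default 0). The example values below are assumed for illustration; the API names and import path come from the diff itself.

import numpy as np
from tensornetwork.backends.numpy import numpy_backend

backend = numpy_backend.NumPyBackend()

# shape_concat: join shape pieces along the last axis, as the
# network_components.py call sites do when building reshape targets.
new_shape = backend.shape_concat([(2, 3), (4,)])
assert np.array_equal(new_shape, [2, 3, 4])

# concat: stack a sequence of tensors along axis 0, e.g. collecting
# scalars into a rank-1 tensor, as the new test_concat tests exercise.
vec = backend.concat([backend.convert_to_tensor(1.0),
                      backend.convert_to_tensor(2.0)], 0)
assert np.array_equal(vec, [1.0, 2.0])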