Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensornetwork/backends/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def rq_decomposition(
raise NotImplementedError(
"Backend '{}' has not implemented rq_decomposition.".format(self.name))

def concat(self, values: Sequence[Tensor], axis) -> Tensor:
def shape_concat(self, values: Sequence[Tensor], axis) -> Tensor:
"""Concatenate a sequence of tensors together about the given axis."""
raise NotImplementedError("Backend '{}' has not implemented concat.".format(
self.name))

def shape(self, tensor: Tensor) -> Tensor:
def shape_tensor(self, tensor: Tensor) -> Tensor:
"""Get the shape of a tensor.

Args:
Expand All @@ -163,7 +163,7 @@ def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
raise NotImplementedError(
"Backend '{}' has not implemented shape_tuple.".format(self.name))

def prod(self, values: Tensor) -> Tensor:
def shape_prod(self, values: Tensor) -> Tensor:
"""Take the product of all of the elements in values"""
raise NotImplementedError("Backend '{}' has not implemented prod.".format(
self.name))
Expand Down
2 changes: 1 addition & 1 deletion tensornetwork/backends/jax/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def convert_to_tensor(self, tensor: Tensor) -> Tensor:
result = self.jax.jit(lambda x: x)(tensor)
return result

def concat(self, values: Tensor, axis: int) -> Tensor:
def shape_concat(self, values: Tensor, axis: int) -> Tensor:
return np.concatenate(values, axis)

def randn(self,
Expand Down
14 changes: 7 additions & 7 deletions tensornetwork/backends/jax/jax_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ def test_transpose():
np.testing.assert_allclose(expected, actual)


def test_concat():
def test_shape_concat():
backend = jax_backend.JaxBackend()
a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
b = backend.convert_to_tensor(np.ones((1, 2, 1)))
expected = backend.concat((a, b), axis=1)
expected = backend.shape_concat((a, b), axis=1)
actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
np.testing.assert_allclose(expected, actual)


def test_shape():
def test_shape_tensor():
backend = jax_backend.JaxBackend()
a = backend.convert_to_tensor(np.ones([2, 3, 4]))
assert isinstance(backend.shape(a), tuple)
actual = backend.shape(a)
assert isinstance(backend.shape_tensor(a), tuple)
actual = backend.shape_tensor(a)
expected = np.array([2, 3, 4])
np.testing.assert_allclose(expected, actual)

Expand All @@ -60,10 +60,10 @@ def test_shape_tuple():
assert actual == (2, 3, 4)


def test_prod():
def test_shape_prod():
backend = jax_backend.JaxBackend()
a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
actual = np.array(backend.prod(a))
actual = np.array(backend.shape_prod(a))
assert actual == 2**24


Expand Down
6 changes: 3 additions & 3 deletions tensornetwork/backends/numpy/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def rq_decomposition(
) -> Tuple[Tensor, Tensor]:
return decompositions.rq_decomposition(self.np, tensor, split_axis)

def concat(self, values: Tensor, axis: int) -> Tensor:
def shape_concat(self, values: Tensor, axis: int) -> Tensor:
return self.np.concatenate(values, axis)

def shape(self, tensor: Tensor) -> Tensor:
def shape_tensor(self, tensor: Tensor) -> Tensor:
return tensor.shape

def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return tensor.shape

def prod(self, values: Tensor) -> Tensor:
def shape_prod(self, values: Tensor) -> Tensor:
return self.np.prod(values)

def sqrt(self, tensor: Tensor) -> Tensor:
Expand Down
14 changes: 7 additions & 7 deletions tensornetwork/backends/numpy/numpy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ def test_transpose():
np.testing.assert_allclose(expected, actual)


def test_concat():
def test_shape_concat():
backend = numpy_backend.NumPyBackend()
a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
b = backend.convert_to_tensor(np.ones((1, 2, 1)))
expected = backend.concat((a, b), axis=1)
expected = backend.shape_concat((a, b), axis=1)
actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
np.testing.assert_allclose(expected, actual)


def test_shape():
def test_shape_tensor():
backend = numpy_backend.NumPyBackend()
a = backend.convert_to_tensor(np.ones([2, 3, 4]))
assert isinstance(backend.shape(a), tuple)
actual = backend.shape(a)
assert isinstance(backend.shape_tensor(a), tuple)
actual = backend.shape_tensor(a)
expected = np.array([2, 3, 4])
np.testing.assert_allclose(expected, actual)

Expand All @@ -58,10 +58,10 @@ def test_shape_tuple():
assert actual == (2, 3, 4)


def test_prod():
def test_shape_prod():
backend = numpy_backend.NumPyBackend()
a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
actual = np.array(backend.prod(a))
actual = np.array(backend.shape_prod(a))
assert actual == 2**24


Expand Down
6 changes: 3 additions & 3 deletions tensornetwork/backends/pytorch/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ def rq_decomposition(
) -> Tuple[Tensor, Tensor]:
return decompositions.rq_decomposition(self.torch, tensor, split_axis)

def concat(self, values: Tensor, axis: int) -> Tensor:
def shape_concat(self, values: Tensor, axis: int) -> Tensor:
return np.concatenate(values, axis)

def shape(self, tensor: Tensor) -> Tensor:
def shape_tensor(self, tensor: Tensor) -> Tensor:
return self.torch.tensor(list(tensor.shape))

def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return tuple(tensor.shape)

def prod(self, values: Tensor) -> int:
def shape_prod(self, values: Tensor) -> int:
return np.prod(np.array(values))

def sqrt(self, tensor: Tensor) -> Tensor:
Expand Down
14 changes: 7 additions & 7 deletions tensornetwork/backends/pytorch/pytorch_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ def test_transpose():
np.testing.assert_allclose(expected, actual)


def test_concat():
def test_shape_concat():
backend = pytorch_backend.PyTorchBackend()
a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
b = backend.convert_to_tensor(np.ones((1, 2, 1)))
expected = backend.concat((a, b), axis=1)
expected = backend.shape_concat((a, b), axis=1)
actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
np.testing.assert_allclose(expected, actual)


def test_shape():
def test_shape_tensor():
backend = pytorch_backend.PyTorchBackend()
a = backend.convert_to_tensor(np.ones([2, 3, 4]))
assert isinstance(backend.shape(a), torch.Tensor)
actual = backend.shape(a)
assert isinstance(backend.shape_tensor(a), torch.Tensor)
actual = backend.shape_tensor(a)
expected = np.array([2, 3, 4])
np.testing.assert_allclose(expected, actual)

Expand All @@ -59,10 +59,10 @@ def test_shape_tuple():
assert actual == (2, 3, 4)


def test_prod():
def test_shape_prod():
backend = pytorch_backend.PyTorchBackend()
a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
actual = np.array(backend.prod(a))
actual = np.array(backend.shape_prod(a))
assert actual == 2**24


Expand Down
10 changes: 5 additions & 5 deletions tensornetwork/backends/shell/shell_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def rq_decomposition(self, tensor: Tensor,
r = ShellTensor((center_dim,) + right_dims)
return q, r

def concat(self, values: Sequence[Tensor], axis: int) -> Tensor:
def shape_concat(self, values: Sequence[Tensor], axis: int) -> Tensor:
shape = values[0].shape
if axis < 0:
axis += len(shape)
Expand All @@ -119,20 +119,20 @@ def concat_shape(self, values) -> Sequence:
tuple_values = (tuple(v) for v in values)
return functools.reduce(operator.concat, tuple_values)

def shape(self, tensor: Tensor) -> Tuple:
def shape_tensor(self, tensor: Tensor) -> Tuple:
return tensor.shape

def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return tensor.shape

def prod(self, values: Tensor) -> int:
def shape_prod(self, values: Tensor) -> int:
# This is different from the BaseBackend prod!
# prod calculates the product of tensor elements and cannot implemented
# for shell tensors
# This returns the product of sizes instead
return self.shape_prod(values.shape)
return self.shape_product(values.shape)

def shape_prod(self, shape: Sequence[int]) -> int:
def shape_product(self, shape: Sequence[int]) -> int:
return functools.reduce(operator.mul, shape)

def sqrt(self, tensor: Tensor) -> Tensor:
Expand Down
16 changes: 8 additions & 8 deletions tensornetwork/backends/shell/shell_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,16 @@ def test_svd_decomposition_with_max_values():
assert x.shape == y.shape


def test_concat():
def test_shape_concat():
args = {
"values": [np.ones([3, 2, 5]),
np.zeros([3, 2, 5]),
np.ones([3, 3, 5])]
}
args["axis"] = 1
assertBackendsAgree("concat", args)
assertBackendsAgree("shape_concat", args)
args["axis"] = -2
assertBackendsAgree("concat", args)
assertBackendsAgree("shape_concat", args)


def test_concat_shape():
Expand All @@ -80,10 +80,10 @@ def test_concat_shape():
assert result == (5, 2, 3, 4, 6)


def test_shape():
def test_shape_tensor():
tensor = np.ones([3, 5, 2])
np_result = numpy_backend.NumPyBackend().shape(tensor)
sh_result = shell_backend.ShellBackend().shape(tensor)
np_result = numpy_backend.NumPyBackend().shape_tensor(tensor)
sh_result = shell_backend.ShellBackend().shape_tensor(tensor)
assert np_result == sh_result


Expand All @@ -94,8 +94,8 @@ def test_shape_tuple():
assert np_result == sh_result


def test_prod():
result = shell_backend.ShellBackend().prod(np.ones([3, 5, 2]))
def test_shape_prod():
result = shell_backend.ShellBackend().shape_prod(np.ones([3, 5, 2]))
assert result == 30


Expand Down
6 changes: 3 additions & 3 deletions tensornetwork/backends/tensorflow/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ def rq_decomposition(self, tensor: Tensor,
split_axis: int) -> Tuple[Tensor, Tensor]:
return decompositions.rq_decomposition(self.tf, tensor, split_axis)

def concat(self, values: Tensor, axis: int) -> Tensor:
def shape_concat(self, values: Tensor, axis: int) -> Tensor:
return self.tf.concat(values, axis)

def shape(self, tensor: Tensor) -> Tensor:
def shape_tensor(self, tensor: Tensor) -> Tensor:
return self.tf.shape(tensor)

def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return tuple(tensor.shape.as_list())

def prod(self, values: Tensor) -> Tensor:
def shape_prod(self, values: Tensor) -> Tensor:
return self.tf.reduce_prod(values)

def sqrt(self, tensor: Tensor) -> Tensor:
Expand Down
14 changes: 7 additions & 7 deletions tensornetwork/backends/tensorflow/tensorflow_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ def test_transpose():
np.testing.assert_allclose(expected, actual)


def test_concat():
def test_shape_concat():
backend = tensorflow_backend.TensorFlowBackend()
a = backend.convert_to_tensor(2 * np.ones((1, 3, 1)))
b = backend.convert_to_tensor(np.ones((1, 2, 1)))
expected = backend.concat((a, b), axis=1)
expected = backend.shape_concat((a, b), axis=1)
actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]])
np.testing.assert_allclose(expected, actual)


def test_shape():
def test_shape_tensor():
backend = tensorflow_backend.TensorFlowBackend()
a = backend.convert_to_tensor(np.ones([2, 3, 4]))
assert isinstance(backend.shape(a), type(a))
actual = backend.shape(a)
assert isinstance(backend.shape_tensor(a), type(a))
actual = backend.shape_tensor(a)
expected = np.array([2, 3, 4])
np.testing.assert_allclose(expected, actual)

Expand All @@ -59,10 +59,10 @@ def test_shape_tuple():
assert actual == (2, 3, 4)


def test_prod():
def test_shape_prod():
backend = tensorflow_backend.TensorFlowBackend()
a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
actual = np.array(backend.prod(a))
actual = np.array(backend.shape_prod(a))
assert actual == 2**24


Expand Down
22 changes: 12 additions & 10 deletions tensornetwork/network_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,10 +1196,12 @@ def _flatten_trace_edges(edges: List[Edge],
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
perm = perm_front + perm_back
new_dim = backend.prod([backend.shape(node.tensor)[e.axis1] for e in edges])
new_dim = backend.shape_prod(
[backend.shape_tensor(node.tensor)[e.axis1] for e in edges])
node.reorder_axes(perm)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, [new_dim, new_dim]], axis=-1)
unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
new_shape = backend.shape_concat(
[unaffected_shape, [new_dim, new_dim]], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape)
edge1 = Edge(node1=node, axis1=len(perm_front), name="TraceFront")
edge2 = Edge(node1=node, axis1=len(perm_front) + 1, name="TraceBack")
Expand Down Expand Up @@ -1271,11 +1273,11 @@ def flatten_edges(edges: List[Edge],
perm_back.append(node.edges.index(edge))
perm_front = sorted(set(range(len(node.edges))) - set(perm_back))
node.reorder_axes(perm_front + perm_back)
old_tensor_shape = backend.shape(node.tensor)
old_tensor_shape = backend.shape_tensor(node.tensor)
# Calculate the new axis dimension as a product of the other
# axes dimensions.
flattened_axis_dim = backend.prod(old_tensor_shape[len(perm_front):])
new_tensor_shape = backend.concat(
flattened_axis_dim = backend.shape_prod(old_tensor_shape[len(perm_front):])
new_tensor_shape = backend.shape_concat(
[old_tensor_shape[:len(perm_front)], [flattened_axis_dim]], axis=-1)
new_tensor = backend.reshape(node.tensor, new_tensor_shape)
# Modify the node in place. Currently, this is they only method that
Expand Down Expand Up @@ -1363,8 +1365,8 @@ def _split_trace_edge(
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
node.reorder_axes(perm_front + perm_back)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, shape, shape], axis=-1)
unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
new_shape = backend.shape_concat([unaffected_shape, shape, shape], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape)
# Trim edges and add placeholder edges for new axes.
node.edges = node.edges[:len(perm_front)] + 2 * len(shape) * [None]
Expand Down Expand Up @@ -1438,8 +1440,8 @@ def split_edge(edge: Edge,
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
node.reorder_axes(perm_front + perm_back)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, shape], axis=-1)
unaffected_shape = backend.shape_tensor(node.tensor)[:len(perm_front)]
new_shape = backend.shape_concat([unaffected_shape, shape], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape) # in-place update
# Trim edges.
node.edges = node.edges[:len(perm_front)]
Expand Down
4 changes: 2 additions & 2 deletions tensornetwork/network_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def split_node(
# the first axis of vh. If we don't, it's possible one of the other axes of
# vh will be the same size as sqrt_s and would multiply across that axis
# instead, which is bad.
sqrt_s_broadcast_shape = backend.concat(
[backend.shape(sqrt_s), [1] * (len(vh.shape) - 1)], axis=-1)
sqrt_s_broadcast_shape = backend.shape_concat(
[backend.shape_tensor(sqrt_s), [1] * (len(vh.shape) - 1)], axis=-1)
vh_s = vh * backend.reshape(sqrt_s, sqrt_s_broadcast_shape)
left_node = Node(
u_s, name=left_name, axis_names=left_axis_names, backend=backend)
Expand Down