diff --git a/tensornetwork/block_tensor/block_tensor.py b/tensornetwork/block_tensor/block_tensor.py index cb2976a00..a4388ce79 100644 --- a/tensornetwork/block_tensor/block_tensor.py +++ b/tensornetwork/block_tensor/block_tensor.py @@ -135,7 +135,7 @@ def compute_nonzero_block_shapes(charges: List[np.ndarray], return charge_shape_dict -def retrieve_non_zero_diagonal_blocks_deprecated( +def retrieve_non_zero_diagonal_blocks( data: np.ndarray, charges: List[np.ndarray], flows: List[Union[bool, int]], @@ -143,8 +143,6 @@ def retrieve_non_zero_diagonal_blocks_deprecated( """ Given the meta data and underlying data of a symmetric matrix, compute all diagonal blocks and return them in a dict. - This is a deprecated version which in general performs worse than the - current main implementation. Args: data: An np.ndarray of the data. The number of elements in `data` has to match the number of non-zero elements defined by `charges` @@ -236,7 +234,7 @@ def retrieve_non_zero_diagonal_blocks_deprecated( return blocks -def retrieve_non_zero_diagonal_blocks( +def retrieve_non_zero_diagonal_blocks_deprecated( data: np.ndarray, charges: List[np.ndarray], flows: List[Union[bool, int]], @@ -244,6 +242,9 @@ def retrieve_non_zero_diagonal_blocks( """ Given the meta data and underlying data of a symmetric matrix, compute all diagonal blocks and return them in a dict. + This is a deprecated version which in general performs worse than the + current main implementation. + Args: data: An np.ndarray of the data. The number of elements in `data` has to match the number of non-zero elements defined by `charges` @@ -287,7 +288,6 @@ def retrieve_non_zero_diagonal_blocks( row_charges, return_inverse=True, return_counts=True) unique_column_charges, column_locations, column_dims = np.unique( column_charges, return_inverse=True, return_counts=True) - #convenience container for storing the degeneracies of each #row and column charge row_degeneracies = dict(zip(unique_row_charges, row_dims)) @@ -300,12 +300,11 @@ def retrieve_non_zero_diagonal_blocks( degeneracy_vector = row_dims[column_locations] stop_positions = np.cumsum(degeneracy_vector) - blocks = {} for c in common_charges: #numpy broadcasting is substantially faster than kron! a = np.expand_dims( - stop_positions[column_locations == -c] - row_degeneracies[c], 0) + stop_positions[column_charges == -c] - row_degeneracies[c], 0) b = np.expand_dims(np.arange(row_degeneracies[c]), 1) if not return_data: blocks[c] = [a + b, (row_degeneracies[c], column_degeneracies[-c])] @@ -572,16 +571,7 @@ def raise_error(): if self.shape[n] > dense_shape[n]: raise_error() elif dense_shape[n] < self.shape[n]: - while dense_shape[n] < self.shape[n]: - #split index at n - try: - i1, i2 = split_index(self.indices.pop(n)) - except ValueError: - raise_error() - self.indices.insert(n, i1) - self.indices.insert(n + 1, i2) - if self.shape[n] < dense_shape[n]: - raise_error() + raise_error() def get_diagonal_blocks(self, return_data: Optional[bool] = True) -> Dict: """ diff --git a/tensornetwork/block_tensor/index.py b/tensornetwork/block_tensor/index.py index 1549a422e..fc6b36cd8 100644 --- a/tensornetwork/block_tensor/index.py +++ b/tensornetwork/block_tensor/index.py @@ -19,8 +19,6 @@ from tensornetwork.network_components import Node, contract, contract_between # pylint: disable=line-too-long from tensornetwork.backends import backend_factory - -import numpy as np import copy from typing import List, Union, Any, Optional, Tuple, Text diff --git a/tensornetwork/block_tensor/index_test.py b/tensornetwork/block_tensor/index_test.py new file mode 100644 index 000000000..ff331a36a --- /dev/null +++ b/tensornetwork/block_tensor/index_test.py @@ -0,0 +1,46 @@ +import numpy as np +# pylint: disable=line-too-long +from tensornetwork.block_tensor.index import Index, fuse_index_pair, split_index, fuse_charges, fuse_degeneracies + + +def test_fuse_charges(): + q1 = np.asarray([0, 1]) + q2 = np.asarray([2, 3, 4]) + fused_charges = fuse_charges(q1, 1, q2, 1) + assert np.all(fused_charges == np.asarray([2, 3, 3, 4, 4, 5])) + fused_charges = fuse_charges(q1, 1, q2, -1) + assert np.all(fused_charges == np.asarray([-2, -1, -3, -2, -4, -3])) + + +def test_index_fusion_mul(): + D = 100 + B = 4 + dtype = np.int16 + q1 = np.random.randint(-B // 2, B // 2 + 1, + D).astype(dtype) #quantum numbers on leg 1 + q2 = np.random.randint(-B // 2, B // 2 + 1, + D).astype(dtype) #quantum numbers on leg 2 + i1 = Index(charges=q1, flow=1, name='index1') #index on leg 1 + i2 = Index(charges=q2, flow=1, name='index2') #index on leg 2 + + i12 = i1 * i2 + assert i12.left_child is i1 + assert i12.right_child is i2 + assert np.all(i12.charges == fuse_charges(q1, 1, q2, 1)) + + +def test_index_fusion(): + D = 100 + B = 4 + dtype = np.int16 + q1 = np.random.randint(-B // 2, B // 2 + 1, + D).astype(dtype) #quantum numbers on leg 1 + q2 = np.random.randint(-B // 2, B // 2 + 1, + D).astype(dtype) #quantum numbers on leg 2 + i1 = Index(charges=q1, flow=1, name='index1') #index on leg 1 + i2 = Index(charges=q2, flow=1, name='index2') #index on leg 2 + + i12 = fuse_index_pair(i1, i2) + assert i12.left_child is i1 + assert i12.right_child is i2 + assert np.all(i12.charges == fuse_charges(q1, 1, q2, 1))