From 7d064fcb0b990f24c09f923ab5596a366c74a646 Mon Sep 17 00:00:00 2001 From: Kai Date: Tue, 30 Jul 2024 05:09:24 -0400 Subject: [PATCH 1/2] finished getitem for Qid --- src/cytnx_torch/bond.py | 15 +- src/cytnx_torch/internal_utils.py | 1 + src/cytnx_torch/symmetry.py | 5 + src/cytnx_torch/unitensor/block_unitensor.py | 189 +++++++++++++----- .../unitensor/regular_unitensor.py | 2 +- test/test_bd.py | 19 ++ test/test_blk_generator.py | 56 ++---- test/test_utn_blk.py | 27 ++- 8 files changed, 227 insertions(+), 87 deletions(-) create mode 100644 src/cytnx_torch/internal_utils.py diff --git a/src/cytnx_torch/bond.py b/src/cytnx_torch/bond.py index 630da5a..df219f3 100644 --- a/src/cytnx_torch/bond.py +++ b/src/cytnx_torch/bond.py @@ -1,6 +1,6 @@ import numpy as np from dataclasses import dataclass, field -from beartype.typing import List, Tuple +from beartype.typing import List, Tuple, Optional from enum import Enum from abc import abstractmethod from copy import deepcopy @@ -113,6 +113,19 @@ class SymBond(AbstractBond): ) _syms: Tuple[Symmetry] = field(default_factory=tuple) + def slice_by_qindices( + self, qindices: Optional[np.ndarray[int]] = None + ) -> "SymBond": + + if qindices is None: + return self + else: + return SymBond( + bond_type=self.bond_type, + qnums=[Qs(self._qnums[qn]) >> self._degs[qn] for qn in qindices], + syms=self._syms, + ) + def _check_meta_eq(self, other: AbstractBond) -> bool: if not isinstance(other, SymBond): return False diff --git a/src/cytnx_torch/internal_utils.py b/src/cytnx_torch/internal_utils.py new file mode 100644 index 0000000..ecf648b --- /dev/null +++ b/src/cytnx_torch/internal_utils.py @@ -0,0 +1 @@ +ALL_ELEMENTS = slice(None, None, None) diff --git a/src/cytnx_torch/symmetry.py b/src/cytnx_torch/symmetry.py index fdcac74..5d32c9e 100644 --- a/src/cytnx_torch/symmetry.py +++ b/src/cytnx_torch/symmetry.py @@ -4,6 +4,11 @@ from abc import abstractmethod +@dataclass +class Qid: + value: int + + @dataclass(frozen=True) class Symmetry: label: str = field(default="") diff --git a/src/cytnx_torch/unitensor/block_unitensor.py b/src/cytnx_torch/unitensor/block_unitensor.py index e467b4f..4fc0af0 100644 --- a/src/cytnx_torch/unitensor/block_unitensor.py +++ b/src/cytnx_torch/unitensor/block_unitensor.py @@ -1,17 +1,58 @@ -from dataclasses import dataclass, field -from beartype.typing import List, Optional, Union, Tuple +from dataclasses import dataclass +from beartype.typing import List, Optional, Union, Tuple, Sequence import numpy as np import torch -from ..symmetry import Symmetry +from ..symmetry import Symmetry, Qid from ..bond import BondType, SymBond from .base import AbstractUniTensor from functools import cached_property +from cytnx_torch.internal_utils import ALL_ELEMENTS + + +@dataclass +class BlockUniTensorMeta: + qn_indices_map: np.ndarray[int] # shape = [nblock, rank] + + def rank(self) -> int: + return self.qn_indices_map.shape[1] + + def permute(self, idx_map: Sequence[int]) -> "BlockUniTensorMeta": + return BlockUniTensorMeta(qn_indices_map=self.qn_indices_map[:, idx_map]) + + def get_block_idx(self, qn_indices: np.ndarray[int]) -> int | None: + if len(qn_indices) != self.rank(): + raise ValueError( + "qn_indices should have the same length as the rank of the tensor" + ) + + loc = np.where(np.all(self.qn_indices_map == qn_indices, axis=1)).flatten() + + if len(loc) == 0: + return None + else: + return loc[0] + + def select( + self, key: List[Union[None, List[int]]] + ) -> Tuple["BlockUniTensorMeta", List[int]]: + if len(key) != self.rank(): + raise ValueError( + "key should have the same length as the rank of the tensor" + ) + + new_map = [] + old_blk_id = [] + for blk_id, map in enumerate(self.qn_indices_map): + if all([k is None or m in k for m, k in zip(map, key)]): + new_map.append(map) + old_blk_id.append(blk_id) + return BlockUniTensorMeta(qn_indices_map=new_map), old_blk_id @dataclass class BlockGenerator: bonds: List[SymBond] - look_up: np.ndarray[int] = field(init=False) # shape = [prod(shape), rank] + meta: Optional[BlockUniTensorMeta] = None def __post_init__(self): if len(self.bonds) > 1 and not self.bonds[0].check_same_symmetry( @@ -19,7 +60,8 @@ def __post_init__(self): ): raise ValueError("bonds have different symmetry") - self.look_up = self._generate_look_up() + if self.meta is None: + self.meta = self._generate_meta() @cached_property def _get_symmetries(self) -> Tuple[Symmetry]: @@ -28,12 +70,21 @@ def _get_symmetries(self) -> Tuple[Symmetry]: else: return tuple() - def _generate_look_up(self) -> np.ndarray[int]: + def _generate_meta(self) -> np.ndarray[int]: qnindices = [np.arange(len(bd._qnums)) for bd in self.bonds] qn_indices_map = np.meshgrid(*qnindices) qn_indices_map = np.array([mp.flatten() for mp in qn_indices_map]).T - return qn_indices_map + # filter out the ones that are not allowed: + qn_indices_map = np.array( + [ + qn_indices + for qn_indices in qn_indices_map + if self._can_have_block(qn_indices) + ] + ) + + return BlockUniTensorMeta(qn_indices_map=qn_indices_map) def _can_have_block(self, qn_indices: np.ndarray[int]) -> bool: # [rank, nsym] @@ -43,75 +94,113 @@ def _can_have_block(self, qn_indices: np.ndarray[int]) -> bool: for bd, qidx in zip(self.bonds, qn_indices) ] ) - net_qns = [ sym.merge_qnums(qns[:, i]) for i, sym in enumerate(self._get_symmetries) ] - - return np.all(net_qns == 0) + return np.allclose(net_qns, 0) def __iter__(self): self.cntr = 0 return self def __next__(self): - if self.cntr < len(self.look_up): - qn_indices = self.look_up[self.cntr] + if self.cntr < len(self.meta.qn_indices_map): + qn_indices = self.meta.qn_indices_map[self.cntr] self.cntr += 1 - if self._can_have_block(qn_indices): - return qn_indices, torch.zeros( - [bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)] - ) - else: - return None, None + return torch.zeros( + [bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)] + ) + else: raise StopIteration @dataclass -class BlockUniTensorMeta: - qn_indices_map: np.ndarray[int] # shape = [nblock, rank] - - def rank(self) -> int: - return self.qn_indices_map.shape[1] - - def permute(self, idx_map: int) -> "BlockUniTensorMeta": - return BlockUniTensorMeta(qn_indices_map=self.qn_indices_map[:, idx_map]) +class BlockUniTensor(AbstractUniTensor): - def get_block_idx(self, qn_indices: np.ndarray[int]) -> int | None: - if len(qn_indices) != self.rank(): - raise ValueError( - "qn_indices should have the same length as the rank of the tensor" - ) + blocks: List[torch.Tensor] = None + meta: BlockUniTensorMeta = None - loc = np.where(np.all(self.qn_indices_map == qn_indices, axis=1)).flatten() + def __post_init__(self): - if len(loc) == 0: - return None + if self.meta is None: + bg = BlockGenerator(bonds=self.bonds) + self.meta = bg.meta else: - return loc[0] + bg = BlockGenerator(bonds=self.bonds, meta=self.meta) + # if blocks is not None, we don't generate from bg, and should be done carefully by internal! + if self.blocks is None: + self.blocks = [block for block in bg] -@dataclass -class BlockUniTensor(AbstractUniTensor): + def __getitem__( + self, key: Union[Tuple, List[Qid], int, Qid, slice] + ) -> "BlockUniTensor": + """ + if element in qid_accessor is None, then it means all elements + + """ + + def collect_qids_per_rank(item): + qidx = None + match item: + case int(): + raise NotImplementedError("int key is not supported yet") + case slice(): + if not item == ALL_ELEMENTS: + raise NotImplementedError( + "slice key currently only support all-elements, i.e. ':'" + ) + case Qid(value): + qidx = [value] + case list(): + # check instance: + if not all([isinstance(x, Qid) for x in item]): + raise ValueError("list should contain only Qid for now") + + qidx = [x.value for x in item] + + case _: + raise ValueError( + "key should be either int, slice, Qid, or list of Qid" + ) + + return qidx + + qid_accessor = [] # [naxis, list of qid] + if isinstance(key, tuple): + for item in key: + qid_accessor.append(collect_qids_per_rank(item)) + else: + qid_accessor.append(collect_qids_per_rank(key)) - blocks: List[torch.Tensor] = None - meta: Optional[BlockUniTensorMeta] = None + # pad the rest with None + qid_accessor += [None] * (self.rank - len(qid_accessor)) - def __post_init__(self): + assert ( + len(qid_accessor) == self.rank + ), "key should have the same length as the rank of the tensor" - if self.meta is None and self.blocks is None: # recalculate meta - bg = BlockGenerator(bonds=self.bonds) + # TODO create new metas: + new_labels = self.labels + new_bonds = [ + bd.slice_by_qindices(qids) for qids, bd in zip(qid_accessor, self.bonds) + ] + print(qid_accessor) + # filter out the block and qnindices: + new_meta, selected_blk_ids = self.meta.select(qid_accessor) + new_blocks = [self.blocks[blk_id] for blk_id in selected_blk_ids] - blocks = [] - qn_indices_map = [] - for qn_indices, block in bg: - if qn_indices is not None: - blocks.append(block) - qn_indices_map.append(qn_indices) - self.blocks = blocks - self.meta = BlockUniTensorMeta(qn_indices_map=np.array(qn_indices_map)) + return BlockUniTensor( + labels=new_labels, + bonds=new_bonds, + backend_args=self.backend_args, + name=self.name, + rowrank=self.rowrank, + blocks=new_blocks, + meta=new_meta, + ) def _repr_body_diagram(self) -> str: Nin = self.rowrank diff --git a/src/cytnx_torch/unitensor/regular_unitensor.py b/src/cytnx_torch/unitensor/regular_unitensor.py index 4d1a37b..824ba41 100644 --- a/src/cytnx_torch/unitensor/regular_unitensor.py +++ b/src/cytnx_torch/unitensor/regular_unitensor.py @@ -60,7 +60,7 @@ def _repr_body_diagram(self) -> str: return out def __getitem__(self, key) -> "RegularUniTensor": - + print(key) accessor = key if not isinstance(key, tuple): accessor = (key,) diff --git a/test/test_bd.py b/test/test_bd.py index bfaa5e7..7940de9 100644 --- a/test/test_bd.py +++ b/test/test_bd.py @@ -55,6 +55,25 @@ def test_eq_sym_bond(): assert b1 != b3 +def test_slice_by_qnid(): + b2 = SymBond( + bond_type=BondType.IN, + qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4, Qs([3, 0]) >> 4, Qs([4, 1]) >> 5], + syms=[U1(), Zn(n=2)], + ) + + b1 = SymBond( + bond_type=BondType.IN, + qnums=[Qs([-1, 0]) >> 3, Qs([3, 0]) >> 4], + syms=[U1(), Zn(n=2)], + ) + + b3 = b2.slice_by_qindices([0, 2]) + + assert b3 == b1 + assert b3.dim == b1.dim + + def test_contractible(): b2 = SymBond( bond_type=BondType.IN, diff --git a/test/test_blk_generator.py b/test/test_blk_generator.py index 6068842..6173bec 100644 --- a/test/test_blk_generator.py +++ b/test/test_blk_generator.py @@ -8,7 +8,7 @@ def test_generator_init(): b1 = SymBond( bond_type=BondType.IN, - qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4, Qs([1, 1]) >> 5], + qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], syms=[U1(), Zn(n=2)], ) b2 = SymBond( @@ -18,30 +18,21 @@ def test_generator_init(): ) b3 = SymBond( bond_type=BondType.OUT, - qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], + qnums=[Qs([-1, 0]) >> 3, Qs([0, 0]) >> 4, Qs([-1, 1]) >> 4], syms=[U1(), Zn(n=2)], ) bg = BlockGenerator(bonds=[b1, b2, b3]) expected_look_up = [ - [0, 0, 0], - [0, 0, 1], - [0, 1, 0], - [0, 1, 1], - [1, 0, 0], - [1, 0, 1], - [1, 1, 0], - [1, 1, 1], - [2, 0, 0], - [2, 0, 1], - [2, 1, 0], - [2, 1, 1], + [0, 1, 2], # o + [1, 0, 2], # o + [1, 1, 1], # o ] - assert bg.look_up.shape == (3 * 2 * 2, 3) + assert bg.meta.qn_indices_map.shape == (3, 3) - assert set(tuple(x) for x in bg.look_up.tolist()) == set( + assert set(tuple(x) for x in bg.meta.qn_indices_map.tolist()) == set( tuple(x) for x in expected_look_up ) @@ -72,7 +63,7 @@ def test_generator_interator(): b1 = SymBond( bond_type=BondType.IN, - qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4, Qs([1, 1]) >> 5], + qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], syms=[U1(), Zn(n=2)], ) b2 = SymBond( @@ -82,28 +73,25 @@ def test_generator_interator(): ) b3 = SymBond( bond_type=BondType.OUT, - qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], + qnums=[Qs([-1, 0]) >> 3, Qs([0, 0]) >> 4, Qs([-1, 1]) >> 4], syms=[U1(), Zn(n=2)], ) bg = BlockGenerator(bonds=[b1, b2, b3]) expected_look_up = [ - [0, 0, 0], - [0, 0, 1], - [0, 1, 0], - [0, 1, 1], - [1, 0, 0], - [1, 0, 1], - [1, 1, 0], - [1, 1, 1], - [2, 0, 0], - [2, 0, 1], - [2, 1, 0], - [2, 1, 1], + [0, 1, 2], # o + [1, 0, 2], # o + [1, 1, 1], # o ] - expected_look_up = set(tuple(x) for x in expected_look_up) - for qn_indices, blk in bg: - if qn_indices is not None: - assert tuple(qn_indices) in expected_look_up + assert bg.meta.qn_indices_map.shape == (3, 3) + + assert set(tuple(x) for x in bg.meta.qn_indices_map.tolist()) == set( + tuple(x) for x in expected_look_up + ) + + for i, blk in enumerate(bg): + assert blk.shape == tuple( + bg.bonds[j]._degs[qid] for j, qid in enumerate(bg.meta.qn_indices_map[i]) + ) diff --git a/test/test_utn_blk.py b/test/test_utn_blk.py index 31b6256..9b49b10 100644 --- a/test/test_utn_blk.py +++ b/test/test_utn_blk.py @@ -1,7 +1,7 @@ from cytnx_torch.bond import Qs, SymBond, BondType from cytnx_torch.symmetry import U1, Zn from cytnx_torch.unitensor import UniTensor -from cytnx_torch.unitensor.block_unitensor import BlockUniTensor +from cytnx_torch.unitensor.block_unitensor import BlockUniTensor, Qid import numpy as np @@ -50,3 +50,28 @@ def test_relabel_blk(): assert np.all(ut2.labels == ["a", "c", "x"]) assert isinstance(ut2, BlockUniTensor) + + +def test_getitem(): + b1 = SymBond( + bond_type=BondType.IN, + qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], + syms=[U1(), Zn(n=2)], + ) + b2 = SymBond( + bond_type=BondType.IN, + qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4], + syms=[U1(), Zn(n=2)], + ) + b3 = SymBond( + bond_type=BondType.OUT, + qnums=[Qs([-1, 0]) >> 3, Qs([0, 0]) >> 4, Qs([-1, 1]) >> 5], + syms=[U1(), Zn(n=2)], + ) + + ut = UniTensor(labels=["a", "b", "c"], bonds=[b1, b2, b3], dtype=float) + + x = ut[Qid(0), :, [Qid(0), Qid(2)]] + + assert x.shape == (3, 7, 8) + assert x.labels == ["a", "b", "c"] From 929cae0afb1b31c3fe4b3cf49cf24e4f320f82c1 Mon Sep 17 00:00:00 2001 From: Kai Date: Tue, 30 Jul 2024 05:21:47 -0400 Subject: [PATCH 2/2] fix bug in remapping indices --- src/cytnx_torch/unitensor/block_unitensor.py | 21 +++++++++++++++----- test/test_utn_blk.py | 3 +++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/cytnx_torch/unitensor/block_unitensor.py b/src/cytnx_torch/unitensor/block_unitensor.py index 4fc0af0..68ccbca 100644 --- a/src/cytnx_torch/unitensor/block_unitensor.py +++ b/src/cytnx_torch/unitensor/block_unitensor.py @@ -40,13 +40,24 @@ def select( "key should have the same length as the rank of the tensor" ) - new_map = [] + new_maps = [] old_blk_id = [] for blk_id, map in enumerate(self.qn_indices_map): - if all([k is None or m in k for m, k in zip(map, key)]): - new_map.append(map) - old_blk_id.append(blk_id) - return BlockUniTensorMeta(qn_indices_map=new_map), old_blk_id + # TODO optimize this + new_mp = [] + for m, k in zip(map, key): + if k is None: + new_mp.append(m) + else: + if m in k: + new_mp.append(k.index(m)) + else: + break + if len(new_mp) != self.rank(): + continue + new_maps.append(new_mp) + old_blk_id.append(blk_id) + return BlockUniTensorMeta(qn_indices_map=np.array(new_maps)), old_blk_id @dataclass diff --git a/test/test_utn_blk.py b/test/test_utn_blk.py index 9b49b10..966644d 100644 --- a/test/test_utn_blk.py +++ b/test/test_utn_blk.py @@ -75,3 +75,6 @@ def test_getitem(): assert x.shape == (3, 7, 8) assert x.labels == ["a", "b", "c"] + assert len(x.blocks) == 1 + assert x.meta.qn_indices_map.shape == (1, 3) + assert np.all(x.meta.qn_indices_map[0] == [0, 1, 1])