Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

finished getitem by Qid for BlockUniTensor #11

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/cytnx_torch/bond.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from dataclasses import dataclass, field
from beartype.typing import List, Tuple
from beartype.typing import List, Tuple, Optional
from enum import Enum
from abc import abstractmethod
from copy import deepcopy
Expand Down Expand Up @@ -113,6 +113,19 @@ class SymBond(AbstractBond):
)
_syms: Tuple[Symmetry] = field(default_factory=tuple)

def slice_by_qindices(
self, qindices: Optional[np.ndarray[int]] = None
) -> "SymBond":

if qindices is None:
return self
else:
return SymBond(
bond_type=self.bond_type,
qnums=[Qs(self._qnums[qn]) >> self._degs[qn] for qn in qindices],
syms=self._syms,
)

def _check_meta_eq(self, other: AbstractBond) -> bool:
if not isinstance(other, SymBond):
return False
Expand Down
1 change: 1 addition & 0 deletions src/cytnx_torch/internal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALL_ELEMENTS = slice(None, None, None)
5 changes: 5 additions & 0 deletions src/cytnx_torch/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from abc import abstractmethod


@dataclass
class Qid:
value: int


@dataclass(frozen=True)
class Symmetry:
label: str = field(default="")
Expand Down
200 changes: 150 additions & 50 deletions src/cytnx_torch/unitensor/block_unitensor.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,78 @@
from dataclasses import dataclass, field
from beartype.typing import List, Optional, Union, Tuple
from dataclasses import dataclass
from beartype.typing import List, Optional, Union, Tuple, Sequence
import numpy as np
import torch
from ..symmetry import Symmetry
from ..symmetry import Symmetry, Qid
from ..bond import BondType, SymBond
from .base import AbstractUniTensor
from functools import cached_property
from cytnx_torch.internal_utils import ALL_ELEMENTS


@dataclass
class BlockUniTensorMeta:
qn_indices_map: np.ndarray[int] # shape = [nblock, rank]

def rank(self) -> int:
return self.qn_indices_map.shape[1]

def permute(self, idx_map: Sequence[int]) -> "BlockUniTensorMeta":
return BlockUniTensorMeta(qn_indices_map=self.qn_indices_map[:, idx_map])

def get_block_idx(self, qn_indices: np.ndarray[int]) -> int | None:
if len(qn_indices) != self.rank():
raise ValueError(
"qn_indices should have the same length as the rank of the tensor"
)

loc = np.where(np.all(self.qn_indices_map == qn_indices, axis=1)).flatten()

if len(loc) == 0:
return None
else:
return loc[0]

def select(
self, key: List[Union[None, List[int]]]
) -> Tuple["BlockUniTensorMeta", List[int]]:
if len(key) != self.rank():
raise ValueError(
"key should have the same length as the rank of the tensor"
)

new_maps = []
old_blk_id = []
for blk_id, map in enumerate(self.qn_indices_map):
# TODO optimize this
new_mp = []
for m, k in zip(map, key):
if k is None:
new_mp.append(m)
else:
if m in k:
new_mp.append(k.index(m))
else:
break
if len(new_mp) != self.rank():
continue
new_maps.append(new_mp)
old_blk_id.append(blk_id)
return BlockUniTensorMeta(qn_indices_map=np.array(new_maps)), old_blk_id


@dataclass
class BlockGenerator:
bonds: List[SymBond]
look_up: np.ndarray[int] = field(init=False) # shape = [prod(shape), rank]
meta: Optional[BlockUniTensorMeta] = None

def __post_init__(self):
if len(self.bonds) > 1 and not self.bonds[0].check_same_symmetry(
*self.bonds[1:]
):
raise ValueError("bonds have different symmetry")

self.look_up = self._generate_look_up()
if self.meta is None:
self.meta = self._generate_meta()

@cached_property
def _get_symmetries(self) -> Tuple[Symmetry]:
Expand All @@ -28,12 +81,21 @@ def _get_symmetries(self) -> Tuple[Symmetry]:
else:
return tuple()

def _generate_look_up(self) -> np.ndarray[int]:
def _generate_meta(self) -> np.ndarray[int]:
qnindices = [np.arange(len(bd._qnums)) for bd in self.bonds]
qn_indices_map = np.meshgrid(*qnindices)
qn_indices_map = np.array([mp.flatten() for mp in qn_indices_map]).T

return qn_indices_map
# filter out the ones that are not allowed:
qn_indices_map = np.array(
[
qn_indices
for qn_indices in qn_indices_map
if self._can_have_block(qn_indices)
]
)

return BlockUniTensorMeta(qn_indices_map=qn_indices_map)

def _can_have_block(self, qn_indices: np.ndarray[int]) -> bool:
# [rank, nsym]
Expand All @@ -43,75 +105,113 @@ def _can_have_block(self, qn_indices: np.ndarray[int]) -> bool:
for bd, qidx in zip(self.bonds, qn_indices)
]
)

net_qns = [
sym.merge_qnums(qns[:, i]) for i, sym in enumerate(self._get_symmetries)
]

return np.all(net_qns == 0)
return np.allclose(net_qns, 0)

def __iter__(self):
self.cntr = 0
return self

def __next__(self):
if self.cntr < len(self.look_up):
qn_indices = self.look_up[self.cntr]
if self.cntr < len(self.meta.qn_indices_map):
qn_indices = self.meta.qn_indices_map[self.cntr]
self.cntr += 1

if self._can_have_block(qn_indices):
return qn_indices, torch.zeros(
[bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)]
)
else:
return None, None
return torch.zeros(
[bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)]
)

else:
raise StopIteration


@dataclass
class BlockUniTensorMeta:
qn_indices_map: np.ndarray[int] # shape = [nblock, rank]

def rank(self) -> int:
return self.qn_indices_map.shape[1]

def permute(self, idx_map: int) -> "BlockUniTensorMeta":
return BlockUniTensorMeta(qn_indices_map=self.qn_indices_map[:, idx_map])
class BlockUniTensor(AbstractUniTensor):

def get_block_idx(self, qn_indices: np.ndarray[int]) -> int | None:
if len(qn_indices) != self.rank():
raise ValueError(
"qn_indices should have the same length as the rank of the tensor"
)
blocks: List[torch.Tensor] = None
meta: BlockUniTensorMeta = None

loc = np.where(np.all(self.qn_indices_map == qn_indices, axis=1)).flatten()
def __post_init__(self):

if len(loc) == 0:
return None
if self.meta is None:
bg = BlockGenerator(bonds=self.bonds)
self.meta = bg.meta
else:
return loc[0]
bg = BlockGenerator(bonds=self.bonds, meta=self.meta)

# if blocks is not None, we don't generate from bg, and should be done carefully by internal!
if self.blocks is None:
self.blocks = [block for block in bg]

@dataclass
class BlockUniTensor(AbstractUniTensor):
def __getitem__(
self, key: Union[Tuple, List[Qid], int, Qid, slice]
) -> "BlockUniTensor":
"""
if element in qid_accessor is None, then it means all elements

"""

def collect_qids_per_rank(item):
qidx = None
match item:
case int():
raise NotImplementedError("int key is not supported yet")
case slice():
if not item == ALL_ELEMENTS:
raise NotImplementedError(
"slice key currently only support all-elements, i.e. ':'"
)
case Qid(value):
qidx = [value]
case list():
# check instance:
if not all([isinstance(x, Qid) for x in item]):
raise ValueError("list should contain only Qid for now")

qidx = [x.value for x in item]

case _:
raise ValueError(
"key should be either int, slice, Qid, or list of Qid"
)

return qidx

qid_accessor = [] # [naxis, list of qid]
if isinstance(key, tuple):
for item in key:
qid_accessor.append(collect_qids_per_rank(item))
else:
qid_accessor.append(collect_qids_per_rank(key))

blocks: List[torch.Tensor] = None
meta: Optional[BlockUniTensorMeta] = None
# pad the rest with None
qid_accessor += [None] * (self.rank - len(qid_accessor))

def __post_init__(self):
assert (
len(qid_accessor) == self.rank
), "key should have the same length as the rank of the tensor"

if self.meta is None and self.blocks is None: # recalculate meta
bg = BlockGenerator(bonds=self.bonds)
# TODO create new metas:
new_labels = self.labels
new_bonds = [
bd.slice_by_qindices(qids) for qids, bd in zip(qid_accessor, self.bonds)
]
print(qid_accessor)
# filter out the block and qnindices:
new_meta, selected_blk_ids = self.meta.select(qid_accessor)
new_blocks = [self.blocks[blk_id] for blk_id in selected_blk_ids]

blocks = []
qn_indices_map = []
for qn_indices, block in bg:
if qn_indices is not None:
blocks.append(block)
qn_indices_map.append(qn_indices)
self.blocks = blocks
self.meta = BlockUniTensorMeta(qn_indices_map=np.array(qn_indices_map))
return BlockUniTensor(
labels=new_labels,
bonds=new_bonds,
backend_args=self.backend_args,
name=self.name,
rowrank=self.rowrank,
blocks=new_blocks,
meta=new_meta,
)

def _repr_body_diagram(self) -> str:
Nin = self.rowrank
Expand Down
2 changes: 1 addition & 1 deletion src/cytnx_torch/unitensor/regular_unitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _repr_body_diagram(self) -> str:
return out

def __getitem__(self, key) -> "RegularUniTensor":

print(key)
accessor = key
if not isinstance(key, tuple):
accessor = (key,)
Expand Down
19 changes: 19 additions & 0 deletions test/test_bd.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def test_eq_sym_bond():
assert b1 != b3


def test_slice_by_qnid():
b2 = SymBond(
bond_type=BondType.IN,
qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4, Qs([3, 0]) >> 4, Qs([4, 1]) >> 5],
syms=[U1(), Zn(n=2)],
)

b1 = SymBond(
bond_type=BondType.IN,
qnums=[Qs([-1, 0]) >> 3, Qs([3, 0]) >> 4],
syms=[U1(), Zn(n=2)],
)

b3 = b2.slice_by_qindices([0, 2])

assert b3 == b1
assert b3.dim == b1.dim


def test_contractible():
b2 = SymBond(
bond_type=BondType.IN,
Expand Down
Loading