Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix backend_args does not propogate when generate BlockUT #16

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
29 changes: 29 additions & 0 deletions src/cytnx_torch/linalg/svd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from cytnx_torch.bond import Bond, BondType # noqa F401
from cytnx_torch.unitensor.regular_unitensor import RegularUniTensor
from torch import linalg as torch_linalg # noqa F401


def _svd_regular_tn(A: RegularUniTensor):
"""
svd(A, **kwargs):
Singular Value Decomposition of a matrix A.
[U,S,V] = svd(A) returns the singular value decomposition of A such that A = U*S*V^T.
The input tensor can be complex.

Args:
A (cytnx.UniTensor): the input tensor

Returns:
U (cytnx.UniTensor): the left singular vectors
S (cytnx.UniTensor): the singular values
V (cytnx.UniTensor): the right singular vectors
"""

mat_A, cL, cR = A.as_matrix(left_bond_label="_tmp_L_", right_bond_label="_tmp_R_")

# get the data:
# u,s,v = torch_linalg.svd(mat_A)

# create new bonds:
# new_bond_dim = len(s)
# internal_bond = Bond(new_bond_dim,bond_type=BondType.OUT)
26 changes: 22 additions & 4 deletions src/cytnx_torch/unitensor/block_unitensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from beartype.typing import List, Optional, Union, Tuple, Sequence
import numpy as np
import torch
Expand Down Expand Up @@ -35,6 +35,12 @@ def get_block_idx(self, qn_indices: np.ndarray[int]) -> int | None:
def select(
self, key: List[Union[None, List[int]]]
) -> Tuple["BlockUniTensorMeta", List[int]]:
"""
select subset of the blocks based on the key.

If element is None, it is considered as all elements.

"""
if len(key) != self.rank():
raise ValueError(
"key should have the same length as the rank of the tensor"
Expand Down Expand Up @@ -62,8 +68,17 @@ def select(

@dataclass
class BlockGenerator:
"""
Generate blocks base on the input bonds

if meta is not provided, it will get generated by the generator itself.


"""

bonds: List[SymBond]
meta: Optional[BlockUniTensorMeta] = None
backend_args: dict = field(default_factory=dict)

def __post_init__(self):
if len(self.bonds) > 1 and not self.bonds[0].check_same_symmetry(
Expand Down Expand Up @@ -120,7 +135,8 @@ def __next__(self):
self.cntr += 1

return torch.zeros(
[bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)]
[bd._degs[qidx] for bd, qidx in zip(self.bonds, qn_indices)],
**self.backend_args,
)

else:
Expand All @@ -136,10 +152,12 @@ class BlockUniTensor(AbstractUniTensor):
def __post_init__(self):

if self.meta is None:
bg = BlockGenerator(bonds=self.bonds)
bg = BlockGenerator(bonds=self.bonds, backend_args=self.backend_args)
self.meta = bg.meta
else:
bg = BlockGenerator(bonds=self.bonds, meta=self.meta)
bg = BlockGenerator(
bonds=self.bonds, meta=self.meta, backend_args=self.backend_args
)

# if blocks is not None, we don't generate from bg, and should be done carefully by internal!
if self.blocks is None:
Expand Down
6 changes: 3 additions & 3 deletions src/cytnx_torch/unitensor/regular_unitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def permute(
)

def as_matrix(
self,
self, left_bond_label: str = "_cb_L_", right_bond_label: str = "_cb_R_"
) -> Tuple[
"RegularUniTensor", RegularUniTensorConverter, RegularUniTensorConverter
]:
Expand All @@ -169,7 +169,7 @@ def as_matrix(
dim=np.prod([b.dim for b in self.bonds[: self.rowrank]]),
bond_type=BondType.OUT if not is_directional_bonds else BondType.NONE,
)
new_label_L = "_aux_L_"
new_label_L = left_bond_label
converter_L = RegularUniTensorConverter(
output_bonds=self.bonds[: self.rowrank],
output_labels=self.labels[: self.rowrank],
Expand All @@ -181,7 +181,7 @@ def as_matrix(
dim=np.prod([b.dim for b in self.bonds[self.rowrank :]]),
bond_type=BondType.IN if not is_directional_bonds else BondType.NONE,
)
new_label_R = "_aux_R_"
new_label_R = right_bond_label
converter_R = RegularUniTensorConverter(
output_bonds=self.bonds[self.rowrank :],
output_labels=self.labels[self.rowrank :],
Expand Down
23 changes: 23 additions & 0 deletions test/test_utn_blk.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,26 @@ def test_getitem():
assert len(x.blocks) == 1
assert x.meta.qn_indices_map.shape == (1, 3)
assert np.all(x.meta.qn_indices_map[0] == [0, 1, 1])


def test_backend_arg_propogate():
b1 = SymBond(
bond_type=BondType.IN,
qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4],
syms=[U1(), Zn(n=2)],
)
b2 = SymBond(
bond_type=BondType.IN,
qnums=[Qs([-1, 0]) >> 3, Qs([0, 1]) >> 4],
syms=[U1(), Zn(n=2)],
)
b3 = SymBond(
bond_type=BondType.OUT,
qnums=[Qs([-1, 0]) >> 3, Qs([0, 0]) >> 4, Qs([-1, 1]) >> 5],
syms=[U1(), Zn(n=2)],
)

ut = UniTensor(labels=["a", "b", "c"], bonds=[b1, b2, b3], dtype=int)

for blk in ut.blocks:
blk.dtype == int
4 changes: 2 additions & 2 deletions test/test_utn_reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def test_as_matrix():

ut.rowrank = 2

mat, cl, cr = ut.as_matrix()
mat, cl, cr = ut.as_matrix(left_bond_label="x", right_bond_label="y")

assert mat.shape == (6, 20)
assert mat.labels == ["_aux_L_", "_aux_R_"]
assert mat.labels == ["x", "y"]

reconstructed_ut = cl.contract(mat).contract(cr)

Expand Down