Skip to content

Commit

Permalink
add tests for sym
Browse files Browse the repository at this point in the history
  • Loading branch information
kaihsin committed Apr 15, 2024
1 parent 8a318fc commit 2065d4b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ dependencies = [
"numpy>=1.26.4",
"scipy>=1.13.0",
"beartype>=0.18.2",
"pytest>=8.1.1",
]
readme = "README.md"
requires-python = ">= 3.10"

[project.scripts]
hello = "cytnx_torch:hello"
[tool.rye.scripts]
clean = "git clean -f"
"coverage:run" = "coverage run -m pytest test"
"coverage:xml" = "coverage xml"
"coverage:html" = "coverage html"
"coverage:report" = "coverage report"
"coverage:open" = "open htmlcov/index.html"
"coverage:github" = { chain = ["coverage:run", "coverage:xml", "coverage:report"]}
"test:all" = "pytest test"

[build-system]
requires = ["hatchling"]
Expand All @@ -43,10 +51,6 @@ dev-dependencies = [
[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/cytnx_torch"]


[tool.black]
line-length = 88

Expand Down
4 changes: 4 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
beartype==0.18.2
filelock==3.13.4
fsspec==2024.3.1
iniconfig==2.0.0
jinja2==3.1.3
markupsafe==2.1.5
mpmath==1.3.0
Expand All @@ -28,7 +29,10 @@ nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pillow==10.3.0
pluggy==1.4.0
pytest==8.1.1
scipy==1.13.0
sympy==1.12
torch==2.2.2
Expand Down
77 changes: 77 additions & 0 deletions src/cytnx_torch/symmetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
from dataclasses import dataclass, field
from beartype.typing import List
from abc import abstractmethod


@dataclass(frozen=True)
class Symmetry:
label: str = field(default="")


# trait
@dataclass(frozen=True)
class AbelianSym(Symmetry):

@abstractmethod
def combine_rule(self, A: int, B: int) -> int:
raise NotImplementedError("not implement for abstract type trait.")

@abstractmethod
def check_qnums(self, qnums: List[int]) -> bool:
raise NotImplementedError("not implement for abstract type trait.")

def combine_qnums(self, qnums_a: List[int], qnums_b: List[int]) -> List[int]:
mesh_b, mesh_a = np.meshgrid(qnums_b, qnums_a)
return self.combine_rule(mesh_a.flatten(), mesh_b.flatten())


@dataclass(frozen=True)
class U1(AbelianSym):
"""
Unitary Symmetry class.
The U1 symmetry can have quantum number represent as arbitrary unsigned integer.
Fusion rule for combine two quantum number:
q1 + q2
"""

def combine_rule(self, A: int, B: int) -> int:
return A + B

def __str__(self) -> str:
return f"U1 label={self.label}"

def check_qnums(self, qnums: List[int]) -> bool:
return True


@dataclass(frozen=True)
class Zn(AbelianSym):
"""
Z(n) Symmetry class.
The Z(n) symmetry can have integer quantum number, with n > 1.
Fusion rule for combine two quantum number:
(q1 + q2)%n
"""

n: int = field(default=2)

def __post_init__(self):
if self.n < 2:
raise ValueError(
"Symmetry.Zn", "[ERROR] discrete symmetry Zn must have n >= 2."
)

def combine_rule(self, A: int, B: int) -> int:
return (A + B) % self.n

def __str__(self):
return f"Z{self.n} label={self.label}"

def check_qnums(self, qnums: List[int]) -> bool:
return np.all([(q >= 0) and (q < self.n) for q in qnums])
28 changes: 28 additions & 0 deletions test/test_sym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np
from cytnx_torch.symmetry import U1, Zn


def test_U1():

s1 = U1(label="a")

q1 = [0, 1, 2]
q2 = [0, -1, 3]

assert s1.check_qnums(q1)
assert s1.check_qnums(q2)

assert np.all(s1.combine_qnums(q1, q2) == [0, -1, 3, 1, 0, 4, 2, 1, 5])


def test_Zn():

s1 = Zn(label="x", n=3)

q1 = [0, 1, 2]
q2 = [2, 0]

assert s1.check_qnums(q1)
assert s1.check_qnums(q2)

assert np.all(s1.combine_qnums(q1, q2) == [2, 0, 0, 1, 1, 2])

0 comments on commit 2065d4b

Please sign in to comment.