From 2065d4b118f26d7308adc1afabe3e1ca6d96e42b Mon Sep 17 00:00:00 2001 From: kaihsin Date: Mon, 15 Apr 2024 13:07:34 -0400 Subject: [PATCH] add tests for sym --- pyproject.toml | 16 +++++--- requirements.lock | 4 ++ src/cytnx_torch/symmetry.py | 77 +++++++++++++++++++++++++++++++++++++ test/test_sym.py | 28 ++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 src/cytnx_torch/symmetry.py create mode 100644 test/test_sym.py diff --git a/pyproject.toml b/pyproject.toml index 50e5af1..cdac6ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,20 @@ dependencies = [ "numpy>=1.26.4", "scipy>=1.13.0", "beartype>=0.18.2", + "pytest>=8.1.1", ] readme = "README.md" requires-python = ">= 3.10" -[project.scripts] -hello = "cytnx_torch:hello" +[tool.rye.scripts] +clean = "git clean -f" +"coverage:run" = "coverage run -m pytest test" +"coverage:xml" = "coverage xml" +"coverage:html" = "coverage html" +"coverage:report" = "coverage report" +"coverage:open" = "open htmlcov/index.html" +"coverage:github" = { chain = ["coverage:run", "coverage:xml", "coverage:report"]} +"test:all" = "pytest test" [build-system] requires = ["hatchling"] @@ -43,10 +51,6 @@ dev-dependencies = [ [tool.hatch.metadata] allow-direct-references = true -[tool.hatch.build.targets.wheel] -packages = ["src/cytnx_torch"] - - [tool.black] line-length = 88 diff --git a/requirements.lock b/requirements.lock index 2912208..e8576d2 100644 --- a/requirements.lock +++ b/requirements.lock @@ -11,6 +11,7 @@ beartype==0.18.2 filelock==3.13.4 fsspec==2024.3.1 +iniconfig==2.0.0 jinja2==3.1.3 markupsafe==2.1.5 mpmath==1.3.0 @@ -28,7 +29,10 @@ nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.19.3 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 +packaging==24.0 pillow==10.3.0 +pluggy==1.4.0 +pytest==8.1.1 scipy==1.13.0 sympy==1.12 torch==2.2.2 diff --git a/src/cytnx_torch/symmetry.py b/src/cytnx_torch/symmetry.py new file mode 100644 index 0000000..24bb65e --- /dev/null +++ b/src/cytnx_torch/symmetry.py @@ -0,0 +1,77 @@ +import numpy as np +from dataclasses import dataclass, field +from beartype.typing import List +from abc import abstractmethod + + +@dataclass(frozen=True) +class Symmetry: + label: str = field(default="") + + +# trait +@dataclass(frozen=True) +class AbelianSym(Symmetry): + + @abstractmethod + def combine_rule(self, A: int, B: int) -> int: + raise NotImplementedError("not implement for abstract type trait.") + + @abstractmethod + def check_qnums(self, qnums: List[int]) -> bool: + raise NotImplementedError("not implement for abstract type trait.") + + def combine_qnums(self, qnums_a: List[int], qnums_b: List[int]) -> List[int]: + mesh_b, mesh_a = np.meshgrid(qnums_b, qnums_a) + return self.combine_rule(mesh_a.flatten(), mesh_b.flatten()) + + +@dataclass(frozen=True) +class U1(AbelianSym): + """ + Unitary Symmetry class. + The U1 symmetry can have quantum number represent as arbitrary unsigned integer. + + Fusion rule for combine two quantum number: + + q1 + q2 + + """ + + def combine_rule(self, A: int, B: int) -> int: + return A + B + + def __str__(self) -> str: + return f"U1 label={self.label}" + + def check_qnums(self, qnums: List[int]) -> bool: + return True + + +@dataclass(frozen=True) +class Zn(AbelianSym): + """ + Z(n) Symmetry class. + The Z(n) symmetry can have integer quantum number, with n > 1. + + Fusion rule for combine two quantum number: + + (q1 + q2)%n + """ + + n: int = field(default=2) + + def __post_init__(self): + if self.n < 2: + raise ValueError( + "Symmetry.Zn", "[ERROR] discrete symmetry Zn must have n >= 2." + ) + + def combine_rule(self, A: int, B: int) -> int: + return (A + B) % self.n + + def __str__(self): + return f"Z{self.n} label={self.label}" + + def check_qnums(self, qnums: List[int]) -> bool: + return np.all([(q >= 0) and (q < self.n) for q in qnums]) diff --git a/test/test_sym.py b/test/test_sym.py new file mode 100644 index 0000000..848d359 --- /dev/null +++ b/test/test_sym.py @@ -0,0 +1,28 @@ +import numpy as np +from cytnx_torch.symmetry import U1, Zn + + +def test_U1(): + + s1 = U1(label="a") + + q1 = [0, 1, 2] + q2 = [0, -1, 3] + + assert s1.check_qnums(q1) + assert s1.check_qnums(q2) + + assert np.all(s1.combine_qnums(q1, q2) == [0, -1, 3, 1, 0, 4, 2, 1, 5]) + + +def test_Zn(): + + s1 = Zn(label="x", n=3) + + q1 = [0, 1, 2] + q2 = [2, 0] + + assert s1.check_qnums(q1) + assert s1.check_qnums(q2) + + assert np.all(s1.combine_qnums(q1, q2) == [2, 0, 0, 1, 1, 2])