diff --git a/pycparser b/pycparser index f740995..28c9658 160000 --- a/pycparser +++ b/pycparser @@ -1 +1 @@ -Subproject commit f7409953060f1f4d0f8988f1e131a49f84c95eba +Subproject commit 28c96587c848378f4707af203eef6acb3866dcd9 diff --git a/pyquda/field.py b/pyquda/field.py index 82872a3..5117ae3 100644 --- a/pyquda/field.py +++ b/pyquda/field.py @@ -96,112 +96,100 @@ def cb2(data: numpy.ndarray, axes: List[int], dtype=None): return data_cb2.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :]) -def newLatticeFieldData(latt_info: LatticeInfo, dtype: str): +def newLatticeFieldData(latt_info: LatticeInfo, field: str): from . import getCUDABackend backend = getCUDABackend() Lx, Ly, Lz, Lt = latt_info.size if backend == "numpy": - if dtype == "Gauge": + if field == "Gauge": ret = numpy.zeros((Nd, 2, Lt, Lz, Ly, Lx // 2, Nc, Nc), " None: + def __init__(self, latt_info: LatticeInfo, L5: int) -> None: super().__init__(latt_info) - self.num_field = num_field + self.L5 = L5 class LatticeGauge(LatticeField): @@ -528,15 +516,15 @@ def __init__(self, latt_info: LatticeInfo, L5: int, value=None) -> None: @property def data_ptrs(self) -> Pointers: - return ndarrayPointer(self.data.reshape(self.num_field, -1), True) + return ndarrayPointer(self.data.reshape(self.L5, -1), True) @property def even_ptrs(self) -> Pointers: - return ndarrayPointer(self.data.reshape(self.num_field, 2, -1)[:, 0], True) + return ndarrayPointer(self.data.reshape(self.L5, 2, -1)[:, 0], True) @property def odd_ptrs(self) -> Pointers: - return ndarrayPointer(self.data.reshape(self.num_field, 2, -1)[:, 1], True) + return ndarrayPointer(self.data.reshape(self.L5, 2, -1)[:, 1], True) class LatticePropagator(LatticeField): diff --git a/pyquda/hmc_clover.py b/pyquda/hmc_clover.py index 7bcf9ab..ef5aa23 100644 --- a/pyquda/hmc_clover.py +++ b/pyquda/hmc_clover.py @@ -93,7 +93,7 @@ def updateGaugeField(self, dt: float): def computeCloverForce(self, dt, x: LatticeFermion, kappa2, ck): self.updateClover() - if self.offset_inv_square_root is None: + if self.num_flavor == 2: invertQuda(x.even_ptr, x.odd_ptr, self.invert_param) # Some conventions force the dagger to be YES here self.invert_param.dagger = QudaDagType.QUDA_DAG_YES @@ -193,7 +193,7 @@ def actionFermion(self, x: LatticeFermion) -> float: self.updateClover() self.invert_param.compute_clover_trlog = 0 self.invert_param.compute_action = 1 - if self.offset_inv_square_root is None: + if self.num_flavor == 2: invertQuda(x.even_ptr, x.odd_ptr, self.invert_param) else: num_offset = len(self.offset_inv_square_root) @@ -220,7 +220,7 @@ def updateClover(self): def initNoise(self, x: LatticeFermion, seed: int): self.updateClover() - if self.offset_fourth_root is None: + if self.num_flavor == 2: self.invert_param.dagger = QudaDagType.QUDA_DAG_YES MatQuda(x.odd_ptr, x.even_ptr, self.invert_param) self.invert_param.dagger = QudaDagType.QUDA_DAG_NO diff --git a/pyquda/version.py b/pyquda/version.py index c98d6e3..2e77e10 100644 --- a/pyquda/version.py +++ b/pyquda/version.py @@ -1 +1 @@ -__version__ = "0.6.13" +__version__ = "0.6.14" diff --git a/pyquda_pyx.py b/pyquda_pyx.py index aea8521..3f71ae8 100644 --- a/pyquda_pyx.py +++ b/pyquda_pyx.py @@ -1,4 +1,5 @@ import os +import sys from typing import Dict, List, NamedTuple, Union @@ -143,10 +144,12 @@ def build_pyquda_pyx(pyquda_root, quda_path): quda_include = os.path.join(quda_path, "include") assert os.path.exists(fake_libc_include), f"{fake_libc_include} not found" print(f"Building pyquda wrapper from {os.path.join(quda_include, 'quda.h')}") + sys.path.insert(1, os.path.join(pyquda_root, "pycparser")) try: from pycparser import parse_file, c_ast except ImportError or ModuleNotFoundError: - from pycparser.pycparser import parse_file, c_ast + from pycparser.pycparser import parse_file, c_ast # This is for the language server + sys.path.remove(os.path.join(pyquda_root, "pycparser")) def evaluate(node): if node is None: