Skip to content

Commit

Permalink
Improve performance for XQCD and KYU propagator.
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Dec 12, 2024
1 parent 634a8af commit 79c0b74
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 22 deletions.
45 changes: 37 additions & 8 deletions pyquda_io/_field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def gaugeEvenShiftBackward(latt_size: List[int], grid_size: List[int], gauge: nu
# DP for Dirac-Pauli, DR for DeGrand-Rossi
# \psi(DP) = _TO_DIRAC_PAULI \psi(DR)
# \psi(DR) = _FROM_DIRAC_PAULI \psi(DP)
_DP_TO_DR = numpy.array(
_FROM_DIRAC_PAULI = numpy.array(
[
[0, 1, 0, -1],
[-1, 0, 1, 0],
[0, 1, 0, 1],
[-1, 0, -1, 0],
]
)
_DR_TO_DP = numpy.array(
_TO_DIRAC_PAULI = numpy.array(
[
[0, -1, 0, -1],
[1, 0, 1, 0],
Expand All @@ -192,15 +192,44 @@ def gaugeEvenShiftBackward(latt_size: List[int], grid_size: List[int], gauge: nu
)


def propagatorDeGrandRossiToDiracPauli(propagator: numpy.ndarray):
P = _DR_TO_DP
Pinv = _DP_TO_DR / 2
def propagatorFromDiracPauli(propagator: numpy.ndarray):
    """Rotate the spin indices of a propagator out of the Dirac-Pauli basis.

    The einsum treats ``propagator.data`` as an eight-axis array and contracts
    axes 4 and 5 (the two spin indices) with ``_FROM_DIRAC_PAULI`` on the left
    and ``_TO_DIRAC_PAULI / 2`` on the right; every other axis passes through
    unchanged.

    Returns:
        A C-contiguous array with the same axis order as the input.
    """
    rotated = numpy.einsum(
        "ij,tzyxjkab,kl->tzyxilab",
        _FROM_DIRAC_PAULI,
        propagator.data,
        # The halved opposite matrix plays the role of the inverse here.
        _TO_DIRAC_PAULI / 2,
        optimize=True,
    )
    return numpy.ascontiguousarray(rotated)


def propagatorDiracPauliToDeGrandRossi(propagator: numpy.ndarray):
P = _DP_TO_DR
Pinv = _DR_TO_DP / 2
def propagatorToDiracPauli(propagator: numpy.ndarray):
    """Rotate the spin indices of a propagator into the Dirac-Pauli basis.

    The einsum treats ``propagator.data`` as an eight-axis array and contracts
    axes 4 and 5 (the two spin indices) with ``_TO_DIRAC_PAULI`` on the left
    and ``_FROM_DIRAC_PAULI / 2`` on the right; every other axis passes through
    unchanged.

    Returns:
        A C-contiguous array with the same axis order as the input.
    """
    rotated = numpy.einsum(
        "ij,tzyxjkab,kl->tzyxilab",
        _TO_DIRAC_PAULI,
        propagator.data,
        # The halved opposite matrix plays the role of the inverse here.
        _FROM_DIRAC_PAULI / 2,
        optimize=True,
    )
    return numpy.ascontiguousarray(rotated)


def spinMatrixFromDiracPauli(dirac_pauli: numpy.ndarray):
    """Rotate a spin-major propagator out of the Dirac-Pauli basis.

    The input is indexed with its two spin indices first (axes 0 and 1) and
    must have eight axes in total. Each output spin component is built as a
    signed sum of input components, with the sign taken from the product of
    two ``_FROM_DIRAC_PAULI`` entries, so whole sub-arrays are added or
    subtracted instead of running a dense contraction.

    Returns:
        The rotated data divided by 2, as a transposed view-ordering in which
        the original axes 2-5 come first, followed by the two spin axes and
        then the original axes 6-7.
    """
    basis = _FROM_DIRAC_PAULI
    rotated = numpy.zeros_like(dirac_pauli)
    for row in range(4):
        # Only every other source index is visited: the opposite parity of
        # the target index (1, 3 for even targets; 0, 2 for odd targets).
        row_sources = range((row + 1) % 2, 4, 2)
        for col in range(4):
            col_sources = range((col + 1) % 2, 4, 2)
            for src_row in row_sources:
                for src_col in col_sources:
                    sign = basis[row, src_row] * basis[col, src_col]
                    if sign == 1:
                        rotated[row, col] += dirac_pauli[src_row, src_col]
                    elif sign == -1:
                        rotated[row, col] -= dirac_pauli[src_row, src_col]
    # Move the spin axes from the front to positions 4 and 5 and normalize.
    return rotated.transpose(2, 3, 4, 5, 0, 1, 6, 7) / 2


def spinMatrixToDiracPauli(degrand_rossi: numpy.ndarray):
    """Rotate a propagator into the Dirac-Pauli basis, returning it spin-major.

    The eight-axis input carries its two spin indices on axes 4 and 5. They
    are first moved to the front and the data is divided by 2; each output
    spin component is then built as a signed sum of input components, with
    the sign taken from the product of two ``_TO_DIRAC_PAULI`` entries.

    Returns:
        The rotated data with the two spin axes first (axes 0 and 1),
        followed by the remaining axes in their original relative order.
    """
    basis = _TO_DIRAC_PAULI
    # Bring the spin axes to the front and apply the normalization up front.
    spin_major = degrand_rossi.transpose(4, 5, 0, 1, 2, 3, 6, 7) / 2
    rotated = numpy.zeros_like(spin_major)
    for row in range(4):
        # Only every other source index is visited: the opposite parity of
        # the target index (1, 3 for even targets; 0, 2 for odd targets).
        row_sources = range((row + 1) % 2, 4, 2)
        for col in range(4):
            col_sources = range((col + 1) % 2, 4, 2)
            for src_row in row_sources:
                for src_col in col_sources:
                    sign = basis[row, src_row] * basis[col, src_col]
                    if sign == 1:
                        rotated[row, col] += spin_major[src_row, src_col]
                    elif sign == -1:
                        rotated[row, col] -= spin_major[src_row, src_col]
    return rotated
14 changes: 7 additions & 7 deletions pyquda_io/kyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy

from ._mpi_file import getSublatticeSize, readMPIFile, writeMPIFile
from ._field_utils import propagatorDiracPauliToDeGrandRossi, propagatorDeGrandRossiToDiracPauli
from ._field_utils import spinMatrixFromDiracPauli, spinMatrixToDiracPauli

Nd, Ns, Nc = 4, 4, 3

Expand Down Expand Up @@ -43,13 +43,13 @@ def readPropagator(filename: str, latt_size: List[int], grid_size: List[int]):

propagator = readMPIFile(filename, dtype, offset, (Ns, Nc, 2, Ns, Nc, Lt, Lz, Ly, Lx), (8, 7, 6, 5), grid_size)
propagator = (
propagator.transpose(5, 6, 7, 8, 3, 0, 4, 1, 2)
propagator.transpose(3, 0, 5, 6, 7, 8, 4, 1, 2)
.astype("<f8")
.copy()
.reshape(Lt, Lz, Ly, Lx, Ns, Ns, Nc, Nc * 2)
.reshape(Ns, Ns, Lt, Lz, Ly, Lx, Nc, Nc * 2)
.view("<c16")
)
propagator = propagatorDiracPauliToDeGrandRossi(propagator)
propagator = spinMatrixFromDiracPauli(propagator)
return propagator


Expand All @@ -58,12 +58,12 @@ def writePropagator(filename: str, latt_size: List[int], grid_size: List[int], p
Lx, Ly, Lz, Lt = getSublatticeSize(latt_size, grid_size)
dtype, offset = ">f8", 0

propagator = propagatorDeGrandRossiToDiracPauli(propagator)
propagator = spinMatrixToDiracPauli(propagator)
propagator = (
propagator.view("<f8")
.reshape(Lt, Lz, Ly, Lx, Ns, Ns, Nc, Nc, 2)
.reshape(Ns, Ns, Lt, Lz, Ly, Lx, Nc, Nc, 2)
.astype(dtype)
.transpose(5, 7, 8, 4, 6, 0, 1, 2, 3)
.transpose(1, 7, 8, 0, 6, 2, 3, 4, 5)
.copy()
)
writeMPIFile(filename, dtype, offset, (Ns, Nc, 2, Ns, Nc, Lt, Lz, Ly, Lx), (8, 7, 6, 5), grid_size, propagator)
10 changes: 5 additions & 5 deletions pyquda_io/xqcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy

from ._mpi_file import getSublatticeSize, readMPIFile, writeMPIFile
from ._field_utils import propagatorDiracPauliToDeGrandRossi, propagatorDeGrandRossiToDiracPauli
from ._field_utils import spinMatrixFromDiracPauli, spinMatrixToDiracPauli

Ns, Nc = 4, 3

Expand All @@ -16,8 +16,8 @@ def readPropagator(filename: str, latt_size: List[int], grid_size: List[int], st

if not staggered:
propagator = readMPIFile(filename, dtype, offset, (Ns, Nc, Lt, Lz, Ly, Lx, Ns, Nc), (5, 4, 3, 2), grid_size)
propagator = propagator.transpose(2, 3, 4, 5, 6, 0, 7, 1).astype("<c16")
propagator = propagatorDiracPauliToDeGrandRossi(propagator)
propagator = propagator.transpose(6, 0, 2, 3, 4, 5, 7, 1).astype("<c16")
propagator = spinMatrixFromDiracPauli(propagator)
else:
# QDP_ALIGN16 makes the last Nc to be aligned with 16 Bytes.
propagator_align16 = readMPIFile(filename, dtype, offset, (Nc, Lt, Lz, Ly, Lx, 4), (4, 3, 2, 1), grid_size)
Expand All @@ -34,8 +34,8 @@ def writePropagator(
dtype, offset = "<c8", 0

if not staggered:
propagator = propagatorDeGrandRossiToDiracPauli(propagator)
propagator = propagator.astype(dtype).transpose(5, 7, 0, 1, 2, 3, 4, 6).copy()
propagator = spinMatrixToDiracPauli(propagator)
propagator = propagator.astype(dtype).transpose(1, 7, 2, 3, 4, 5, 0, 6).copy()
writeMPIFile(filename, dtype, offset, (Ns, Nc, Lt, Lz, Ly, Lx, Ns, Nc), (5, 4, 3, 2), grid_size, propagator)
else:
# QDP_ALIGN16 makes the last Nc to be aligned with 16 Bytes.
Expand Down
7 changes: 5 additions & 2 deletions pyquda_utils/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def readXQCDPropagator(filename: str, latt_size: List[int], staggered: bool):
from pyquda_io.xqcd import readPropagator as read

propagator_raw = read(filename, latt_size, getGridSize(), staggered)
return LatticeStaggeredPropagator(LatticeInfo(latt_size), evenodd(propagator_raw, [0, 1, 2, 3]))
if not staggered:
return LatticePropagator(LatticeInfo(latt_size), evenodd(propagator_raw, [0, 1, 2, 3]))
else:
return LatticeStaggeredPropagator(LatticeInfo(latt_size), evenodd(propagator_raw, [0, 1, 2, 3]))


def writeXQCDPropagator(filename: str, propagator: Union[LatticePropagator, LatticeStaggeredPropagator]):
Expand All @@ -170,7 +173,7 @@ def readXQCDPropagatorFast(filename: str, latt_size: List[int]):

latt_info = LatticeInfo(latt_size)
Lx, Ly, Lz, Lt = latt_info.size
propagator_raw = read(filename, getGridSize(), latt_size)
propagator_raw = read(filename, latt_size, getGridSize())
propagator = LatticePropagator(latt_info, evenodd(propagator_raw, [2, 3, 4, 5]))
propagator.data = propagator.data.reshape(Ns, Nc, 2, Lt, Lz, Ly, Lx // 2, Ns, Nc)
propagator.toDevice()
Expand Down

0 comments on commit 79c0b74

Please sign in to comment.