Skip to content

Commit f1150f0

Browse files
fjoswmatteo-dc
andauthored
Matmul overloaded for correlator class. (#199)
* feat: matmul method added to correlator class. * feat: corr, corr matmul and correlator matrix trace added. * tests: tests for matmul and trace added. * tests: slightly reduced tolerance and good guess bad guess test. * feat: rmatmul added and __array_priority__ set. * tests: additional tests for rmatmul added. * tests: one more tests for rmatmul added. * docs: docstring added to Corr.trace. * tests: associative property test added for complex Corr matmul. * fix: Corr.roll method now also works for correlator matrices by explicitly specifying the axis. Co-authored-by: Matteo Di Carlo <matteo.dicarlo93@gmail.com> * feat: exception type for correlator trace of 1dim correlator changed. * tests: trace N=1 exception tested. --------- Co-authored-by: Matteo Di Carlo <matteo.dicarlo93@gmail.com>
1 parent 7d1858f commit f1150f0

File tree

3 files changed

+185
-6
lines changed

3 files changed

+185
-6
lines changed

pyerrors/correlators.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def symmetric(self):
220220
def anti_symmetric(self):
221221
"""Anti-symmetrize the correlator around x0=0."""
222222
if self.N != 1:
223-
raise Exception('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
223+
raise TypeError('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
224224
if self.T % 2 != 0:
225225
raise Exception("Can not symmetrize odd T")
226226

@@ -242,7 +242,7 @@ def anti_symmetric(self):
242242
def is_matrix_symmetric(self):
243243
"""Checks whether a correlator matrices is symmetric on every timeslice."""
244244
if self.N == 1:
245-
raise Exception("Only works for correlator matrices.")
245+
raise TypeError("Only works for correlator matrices.")
246246
for t in range(self.T):
247247
if self[t] is None:
248248
continue
@@ -254,6 +254,18 @@ def is_matrix_symmetric(self):
254254
return False
255255
return True
256256

257+
def trace(self):
258+
"""Calculates the per-timeslice trace of a correlator matrix."""
259+
if self.N == 1:
260+
raise ValueError("Only works for correlator matrices.")
261+
newcontent = []
262+
for t in range(self.T):
263+
if _check_for_none(self, self.content[t]):
264+
newcontent.append(None)
265+
else:
266+
newcontent.append(np.trace(self.content[t]))
267+
return Corr(newcontent)
268+
257269
def matrix_symmetric(self):
258270
"""Symmetrizes the correlator matrices on every timeslice."""
259271
if self.N == 1:
@@ -405,7 +417,7 @@ def roll(self, dt):
405417
dt : int
406418
number of timeslices
407419
"""
408-
return Corr(list(np.roll(np.array(self.content, dtype=object), dt)))
420+
return Corr(list(np.roll(np.array(self.content, dtype=object), dt, axis=0)))
409421

410422
def reverse(self):
411423
"""Reverse the time ordering of the Corr"""
@@ -1020,6 +1032,8 @@ def __str__(self):
10201032
# This is because Obs*Corr checks Obs.__mul__ first and does not catch an exception.
10211033
# One could try and tell Obs to check if the y in __mul__ is a Corr and
10221034

1035+
__array_priority__ = 10000
1036+
10231037
def __add__(self, y):
10241038
if isinstance(y, Corr):
10251039
if ((self.N != y.N) or (self.T != y.T)):
@@ -1076,6 +1090,49 @@ def __mul__(self, y):
10761090
else:
10771091
raise TypeError("Corr * wrong type")
10781092

1093+
def __matmul__(self, y):
1094+
if isinstance(y, np.ndarray):
1095+
if y.ndim != 2 or y.shape[0] != y.shape[1]:
1096+
raise ValueError("Can only multiply correlators by square matrices.")
1097+
if not self.N == y.shape[0]:
1098+
raise ValueError("matmul: mismatch of matrix dimensions")
1099+
newcontent = []
1100+
for t in range(self.T):
1101+
if _check_for_none(self, self.content[t]):
1102+
newcontent.append(None)
1103+
else:
1104+
newcontent.append(self.content[t] @ y)
1105+
return Corr(newcontent)
1106+
elif isinstance(y, Corr):
1107+
if not self.N == y.N:
1108+
raise ValueError("matmul: mismatch of matrix dimensions")
1109+
newcontent = []
1110+
for t in range(self.T):
1111+
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
1112+
newcontent.append(None)
1113+
else:
1114+
newcontent.append(self.content[t] @ y.content[t])
1115+
return Corr(newcontent)
1116+
1117+
else:
1118+
return NotImplemented
1119+
1120+
def __rmatmul__(self, y):
1121+
if isinstance(y, np.ndarray):
1122+
if y.ndim != 2 or y.shape[0] != y.shape[1]:
1123+
raise ValueError("Can only multiply correlators by square matrices.")
1124+
if not self.N == y.shape[0]:
1125+
raise ValueError("matmul: mismatch of matrix dimensions")
1126+
newcontent = []
1127+
for t in range(self.T):
1128+
if _check_for_none(self, self.content[t]):
1129+
newcontent.append(None)
1130+
else:
1131+
newcontent.append(y @ self.content[t])
1132+
return Corr(newcontent)
1133+
else:
1134+
return NotImplemented
1135+
10791136
def __truediv__(self, y):
10801137
if isinstance(y, Corr):
10811138
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):

tests/correlators_test.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,125 @@ def test_corr_symmetric():
570570
assert scorr[1] == scorr[3]
571571
assert scorr[2] == corr[2]
572572
assert scorr[0] == corr[0]
573+
574+
575+
def test_two_matrix_corr_inits():
576+
T = 4
577+
rn = lambda : np.random.normal(0.5, 0.1)
578+
579+
# Generate T random CObs in a list
580+
list_of_timeslices =[]
581+
for i in range(T):
582+
re = pe.pseudo_Obs(rn(), rn(), "test")
583+
im = pe.pseudo_Obs(rn(), rn(), "test")
584+
list_of_timeslices.append(pe.CObs(re, im))
585+
586+
# First option: Correlator of matrix of correlators
587+
corr = pe.Corr(list_of_timeslices)
588+
mat_corr1 = pe.Corr(np.array([[corr, corr], [corr, corr]]))
589+
590+
# Second option: Correlator of list of arrays per timeslice
591+
list_of_arrays = [np.array([[elem, elem], [elem, elem]]) for elem in list_of_timeslices]
592+
mat_corr2 = pe.Corr(list_of_arrays)
593+
594+
for el in mat_corr1 - mat_corr2:
595+
assert np.all(el == 0)
596+
597+
598+
def test_matmul_overloading():
599+
N = 4
600+
rn = lambda : np.random.normal(0.5, 0.1)
601+
602+
# Generate N^2 random CObs and assemble them in an array
603+
ll =[]
604+
for i in range(N ** 2):
605+
re = pe.pseudo_Obs(rn(), rn(), "test")
606+
im = pe.pseudo_Obs(rn(), rn(), "test")
607+
ll.append(pe.CObs(re, im))
608+
mat = np.array(ll).reshape(N, N)
609+
610+
# Multiply with gamma matrix
611+
corr = pe.Corr([mat] * 4, padding=[0, 1])
612+
613+
# __matmul__
614+
mcorr = corr @ pe.dirac.gammaX
615+
comp = mat @ pe.dirac.gammaX
616+
for i in range(4):
617+
assert np.all(mcorr[i] == comp)
618+
619+
# __rmatmul__
620+
mcorr = pe.dirac.gammaX @ corr
621+
comp = pe.dirac.gammaX @ mat
622+
for i in range(4):
623+
assert np.all(mcorr[i] == comp)
624+
625+
test_mat = pe.dirac.gamma5 + pe.dirac.gammaX
626+
icorr = corr @ test_mat @ np.linalg.inv(test_mat)
627+
tt = corr - icorr
628+
for i in range(4):
629+
assert np.all(tt[i] == 0)
630+
631+
# associative property
632+
tt = (corr.real @ pe.dirac.gammaX + corr.imag @ (pe.dirac.gammaX * 1j)) - corr @ pe.dirac.gammaX
633+
for el in tt:
634+
if el is not None:
635+
assert np.all(el == 0)
636+
637+
corr2 = corr @ corr
638+
for i in range(4):
639+
np.all(corr2[i] == corr[i] @ corr[i])
640+
641+
642+
def test_matrix_trace():
643+
N = 4
644+
rn = lambda : np.random.normal(0.5, 0.1)
645+
646+
# Generate N^2 random CObs and assemble them in an array
647+
ll =[]
648+
for i in range(N ** 2):
649+
re = pe.pseudo_Obs(rn(), rn(), "test")
650+
im = pe.pseudo_Obs(rn(), rn(), "test")
651+
ll.append(pe.CObs(re, im))
652+
mat = np.array(ll).reshape(N, N)
653+
654+
corr = pe.Corr([mat] * 4)
655+
656+
# Explicitly check trace
657+
for el in corr.trace():
658+
el == np.sum(np.diag(mat))
659+
660+
# Trace is cyclic
661+
for one, two in zip((pe.dirac.gammaX @ corr).trace(), (corr @ pe.dirac.gammaX).trace()):
662+
assert np.all(one == two)
663+
664+
# Antisymmetric matrices are traceless.
665+
mat = (mat - mat.T) / 2
666+
corr = pe.Corr([mat] * 4)
667+
for el in corr.trace():
668+
assert el == 0
669+
670+
671+
with pytest.raises(ValueError):
672+
corr.item(0, 0).trace()
673+
674+
675+
def test_corr_roll():
676+
T = 4
677+
rn = lambda : np.random.normal(0.5, 0.1)
678+
679+
ll = []
680+
for i in range(T):
681+
re = pe.pseudo_Obs(rn(), rn(), "test")
682+
im = pe.pseudo_Obs(rn(), rn(), "test")
683+
ll.append(pe.CObs(re, im))
684+
685+
# Rolling by T should produce the same correlator
686+
corr = pe.Corr(ll)
687+
tt = corr - corr.roll(T)
688+
for el in tt:
689+
assert np.all(el == 0)
690+
691+
mcorr = pe.Corr(np.array([[corr, corr + 0.1], [corr - 0.1, 2 * corr]]))
692+
tt = mcorr.roll(T) - mcorr
693+
for el in tt:
694+
assert np.all(el == 0)

tests/fits_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,12 @@ def func(a, x):
235235

236236

237237
def test_linear_fit_guesses():
238-
for err in [10, 0.1, 0.001]:
238+
for err in [1.2, 0.1, 0.001]:
239239
xvals = []
240240
yvals = []
241241
for x in range(1, 8, 2):
242242
xvals.append(x)
243-
yvals.append(pe.pseudo_Obs(x + np.random.normal(0.0, err), err, 'test1') + pe.pseudo_Obs(0, err / 100, 'test2', samples=87))
243+
yvals.append(pe.pseudo_Obs(x + np.random.normal(0.0, err), err, 'test1') + pe.pseudo_Obs(0, err / 97, 'test2', samples=87))
244244
lin_func = lambda a, x: a[0] + a[1] * x
245245
with pytest.raises(Exception):
246246
pe.least_squares(xvals, yvals, lin_func)
@@ -251,7 +251,7 @@ def test_linear_fit_guesses():
251251
bad_guess = pe.least_squares(xvals, yvals, lin_func, initial_guess=[999, 999])
252252
good_guess = pe.least_squares(xvals, yvals, lin_func, initial_guess=[0, 1])
253253
assert np.isclose(bad_guess.chisquare, good_guess.chisquare, atol=1e-8)
254-
assert np.all([(go - ba).is_zero(atol=1e-6) for (go, ba) in zip(good_guess, bad_guess)])
254+
assert np.all([(go - ba).is_zero(atol=5e-5) for (go, ba) in zip(good_guess, bad_guess)])
255255

256256

257257
def test_total_least_squares():

0 commit comments

Comments
 (0)