Skip to content

[Feat] Add type hints #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3db8eb2
[Feat] Add type hints to pyerrors modules
fjosw Dec 25, 2024
9fe375a
[Feat] Added type hints to input modules
fjosw Dec 25, 2024
8d86295
[Feat] Fixed a few type hints manually
fjosw Dec 25, 2024
2d34b35
[Fix] Fix ruff errors and a few type annotations
fjosw Dec 25, 2024
a9e082c
[Fix] Fix type annotations for first part of obs.py
fjosw Jan 3, 2025
0198256
[Fix] Fixed most type annotations in obs.py
fjosw Jan 3, 2025
9389ad6
[CI] Add typing extension package to pytest run for pytthon 3.9
fjosw Jan 3, 2025
23d4f4c
[Fix] Fix type hints in misc.py and remove strict zips for python 3.9
fjosw Jan 3, 2025
1916de1
[Fix] Start fixing remaining type hints
fjosw Jan 3, 2025
3654635
[Fix] Removed unused imports
fjosw Jan 3, 2025
4f1606d
[CI] Add E252 to flake8 exceptions
fjosw Jan 3, 2025
d45b43e
[Fix] Fixed remaining flake8 errors
fjosw Jan 3, 2025
1c6053e
[Fix] Simplify type hints
fjosw Jan 3, 2025
b8700ef
[Fix] Fix type annotations in json.py
fjosw Jan 3, 2025
6d5a9b9
[Fix] Simplify type annotations in input modules
fjosw Jan 3, 2025
6a990c1
[Fix] Fix ruff
fjosw Jan 3, 2025
9c960ae
[Fix] Correct type hints in fits.py
fjosw Jan 5, 2025
5376a8a
[Fix] Further type fixes in fits and sfcf
fjosw Jan 5, 2025
336117c
[Fix] More type hint fixing
fjosw Jan 5, 2025
dba277f
Merge branch 'develop' into feat/typehints
jkuhl-uni Jan 6, 2025
52b91d8
add typehints for other util functions
jkuhl-uni Jan 9, 2025
cb9d942
make deriv structure like second_deriv
jkuhl-uni Jan 9, 2025
2f40ff8
add typing for read_pbp
jkuhl-uni Jan 9, 2025
bbf0b68
[Fix] Additional typehints
fjosw Jan 7, 2025
7a3a28d
add typehints for check_params
jkuhl-uni Jan 13, 2025
4814675
[Fix] more work on typehints
Feb 17, 2025
f44b19c
clean up sfcf input types
jkuhl-uni Mar 29, 2025
96cdec4
annotate read_ms5_xsf
jkuhl-uni Mar 29, 2025
4cfe863
Merge branch 'feat/typehints' of github.com:fjosw/pyerrors into feat/…
jkuhl-uni Mar 29, 2025
b3b4126
some more quick fixes
jkuhl-uni Mar 29, 2025
6d80efd
Merge branch 'develop' into feat/typehints
fjosw May 5, 2025
f594e06
[Fix] Clean up merge conflict
fjosw May 5, 2025
ab8ca3a
[Fix] Clean up whitespaces
fjosw May 5, 2025
f3e7cde
[Fix] Remove Unpack for compatibility with older python versions.
fjosw May 5, 2025
ac7e98d
simple fix for openQCD in test fails
jkuhl-uni May 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/flake8.yml
Original file line number Diff line number Diff line change
@@ -21,6 +21,6 @@ jobs:
- name: flake8 Lint
uses: py-actions/flake8@v2
with:
ignore: "E501,W503"
ignore: "E501,W503,E252"
exclude: "__init__.py, input/__init__.py"
path: "pyerrors"
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ jobs:
run: |
uv pip install wheel --system
uv pip install . --system
uv pip install pytest pytest-cov pytest-benchmark hypothesis --system
uv pip install pytest pytest-cov pytest-benchmark hypothesis typing_extensions --system
uv pip freeze --system
- name: Run tests
194 changes: 94 additions & 100 deletions pyerrors/correlators.py

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions pyerrors/covobs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations
import numpy as np
from numpy import ndarray
from typing import Optional, Union


class Covobs:

def __init__(self, mean, cov, name, pos=None, grad=None):
def __init__(self, mean: Union[float, int], cov: Union[list, ndarray], name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
""" Initialize Covobs object.
Parameters
@@ -39,12 +42,12 @@ def __init__(self, mean, cov, name, pos=None, grad=None):
self._set_grad(grad)
self.value = mean

def errsq(self):
def errsq(self) -> float:
""" Return the variance (= square of the error) of the Covobs
"""
return np.dot(np.transpose(self.grad), np.dot(self.cov, self.grad)).item()

def _set_cov(self, cov):
def _set_cov(self, cov: Union[list, ndarray]):
""" Set the covariance matrix of the covobs
Parameters
@@ -79,7 +82,7 @@ def _set_cov(self, cov):
if ev < 0:
raise Exception('Covariance matrix is not positive-semidefinite!')

def _set_grad(self, grad):
def _set_grad(self, grad: Union[list[float], ndarray]):
""" Set the gradient of the covobs
Parameters
@@ -96,9 +99,9 @@ def _set_grad(self, grad):
raise Exception('Invalid dimension of grad!')

@property
def cov(self):
def cov(self) -> ndarray:
return self._cov

@property
def grad(self):
def grad(self) -> ndarray:
return self._grad
8 changes: 5 additions & 3 deletions pyerrors/dirac.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations
import numpy as np
from numpy import ndarray


gammaX = np.array(
@@ -22,7 +24,7 @@
dtype=complex)


def epsilon_tensor(i, j, k):
def epsilon_tensor(i: int, j: int, k: int) -> float:
"""Rank-3 epsilon tensor
Based on https://codegolf.stackexchange.com/a/160375
@@ -39,7 +41,7 @@ def epsilon_tensor(i, j, k):
return (i - j) * (j - k) * (k - i) / 2


def epsilon_tensor_rank4(i, j, k, o):
def epsilon_tensor_rank4(i: int, j: int, k: int, o: int) -> float:
"""Rank-4 epsilon tensor
Extension of https://codegolf.stackexchange.com/a/160375
@@ -57,7 +59,7 @@ def epsilon_tensor_rank4(i, j, k, o):
return (i - j) * (j - k) * (k - i) * (i - o) * (j - o) * (o - k) / 12


def Grid_gamma(gamma_tag):
def Grid_gamma(gamma_tag: str) -> ndarray:
"""Returns gamma matrix in Grid labeling."""
if gamma_tag == 'Identity':
g = identity
109 changes: 72 additions & 37 deletions pyerrors/fits.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import gc
from collections.abc import Sequence
import warnings
@@ -15,6 +16,8 @@
from numdifftools import Jacobian as num_jacobian
from numdifftools import Hessian as num_hessian
from .obs import Obs, derived_observable, covariance, cov_Obs, invert_corr_cov_cholesky
from numpy import ndarray
from typing import Any, Callable, Optional, Union


class Fit_result(Sequence):
@@ -33,13 +36,31 @@ class Fit_result(Sequence):
Hotelling t-squared p-value for correlated fits.
"""

def __init__(self):
self.fit_parameters = None

def __getitem__(self, idx):
def __init__(self) -> None:
self.fit_parameters: Optional[list] = None
self.fit_function: Optional[Union[Callable, dict[str, Callable]]] = None
self.priors: Optional[Union[list[Obs], dict[int, Obs]]] = None
self.method: Optional[str] = None
self.iterations: Optional[int] = None
self.chisquare: Optional[float] = None
self.odr_chisquare: Optional[float] = None
self.dof: Optional[int] = None
self.p_value: Optional[float] = None
self.message: Optional[str] = None
self.t2_p_value: Optional[float] = None
self.chisquare_by_dof: Optional[float] = None
self.chisquare_by_expected_chisquare: Optional[float] = None
self.residual_variance: Optional[float] = None
self.xplus: Optional[float] = None

def __getitem__(self, idx: int) -> Obs:
if self.fit_parameters is None:
raise TypeError('No fit parameters available.')
return self.fit_parameters[idx]

def __len__(self):
def __len__(self) -> int:
if self.fit_parameters is None:
raise TypeError('No fit parameters available.')
return len(self.fit_parameters)

def gamma_method(self, **kwargs):
@@ -48,29 +69,31 @@ def gamma_method(self, **kwargs):

gm = gamma_method

def __str__(self):
def __str__(self) -> str:
my_str = 'Goodness of fit:\n'
if hasattr(self, 'chisquare_by_dof'):
if self.chisquare_by_dof is not None:
my_str += '\u03C7\u00b2/d.o.f. = ' + f'{self.chisquare_by_dof:2.6f}' + '\n'
elif hasattr(self, 'residual_variance'):
elif self.residual_variance is not None:
my_str += 'residual variance = ' + f'{self.residual_variance:2.6f}' + '\n'
if hasattr(self, 'chisquare_by_expected_chisquare'):
if self.chisquare_by_expected_chisquare is not None:
my_str += '\u03C7\u00b2/\u03C7\u00b2exp = ' + f'{self.chisquare_by_expected_chisquare:2.6f}' + '\n'
if hasattr(self, 'p_value'):
if self.p_value is not None:
my_str += 'p-value = ' + f'{self.p_value:2.4f}' + '\n'
if hasattr(self, 't2_p_value'):
if self.t2_p_value is not None:
my_str += 't\u00B2p-value = ' + f'{self.t2_p_value:2.4f}' + '\n'
my_str += 'Fit parameters:\n'
if self.fit_parameters is None:
raise TypeError('No fit parameters available.')
for i_par, par in enumerate(self.fit_parameters):
my_str += str(i_par) + '\t' + ' ' * int(par >= 0) + str(par).rjust(int(par < 0.0)) + '\n'
return my_str

def __repr__(self):
def __repr__(self) -> str:
m = max(map(len, list(self.__dict__.keys()))) + 1
return '\n'.join([key.rjust(m) + ': ' + repr(value) for key, value in sorted(self.__dict__.items())])


def least_squares(x, y, func, priors=None, silent=False, **kwargs):
def least_squares(x: Any, y: Union[dict[str, ndarray], list[Obs], ndarray, dict[str, list[Obs]]], func: Union[Callable, dict[str, Callable]], priors: Optional[Union[dict[int, str], list[str], list[Obs], dict[int, Obs]]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x).
```
@@ -335,9 +358,8 @@ def func_b(a, x):
p_f = dp_f = np.array([])
prior_mask = []
loc_priors = []

if 'initial_guess' in kwargs:
x0 = kwargs.get('initial_guess')
x0 = kwargs.get('initial_guess')
if x0 is not None:
if len(x0) != n_parms:
raise ValueError('Initial guess does not have the correct length: %d vs. %d' % (len(x0), n_parms))
else:
@@ -356,8 +378,8 @@ def chisqfunc_uncorr(p):
return anp.sum(general_chisqfunc_uncorr(p, y_f, p_f) ** 2)

if kwargs.get('correlated_fit') is True:
if 'inv_chol_cov_matrix' in kwargs:
chol_inv = kwargs.get('inv_chol_cov_matrix')
chol_inv = kwargs.get('inv_chol_cov_matrix')
if chol_inv is not None:
if (chol_inv[0].shape[0] != len(dy_f)):
raise TypeError('The number of columns of the inverse covariance matrix handed over needs to be equal to the number of y errors.')
if (chol_inv[0].shape[0] != chol_inv[0].shape[1]):
@@ -388,17 +410,17 @@ def chisqfunc(p):

if output.method != 'Levenberg-Marquardt':
if output.method == 'migrad':
tolerance = 1e-4 # default value of 1e-1 set by iminuit can be problematic
if 'tol' in kwargs:
tolerance = kwargs.get('tol')
tolerance = kwargs.get('tol')
if tolerance is None:
tolerance = 1e-4 # default value of 1e-1 set by iminuit can be problematic
fit_result = iminuit.minimize(chisqfunc_uncorr, x0, tol=tolerance) # Stopping criterion 0.002 * tol * errordef
if kwargs.get('correlated_fit') is True:
fit_result = iminuit.minimize(chisqfunc, fit_result.x, tol=tolerance)
output.iterations = fit_result.nfev
else:
tolerance = 1e-12
if 'tol' in kwargs:
tolerance = kwargs.get('tol')
tolerance = kwargs.get('tol')
if tolerance is None:
tolerance = 1e-12
fit_result = scipy.optimize.minimize(chisqfunc_uncorr, x0, method=kwargs.get('method'), tol=tolerance)
if kwargs.get('correlated_fit') is True:
fit_result = scipy.optimize.minimize(chisqfunc, fit_result.x, method=kwargs.get('method'), tol=tolerance)
@@ -428,8 +450,8 @@ def chisqfunc_residuals(p):
if not fit_result.success:
raise Exception('The minimization procedure did not converge.')

output.chisquare = chisquare
output.dof = y_all.shape[-1] - n_parms + len(loc_priors)
output.chisquare = float(chisquare)
output.dof = int(y_all.shape[-1] - n_parms + len(loc_priors))
output.p_value = 1 - scipy.stats.chi2.cdf(output.chisquare, output.dof)
if output.dof > 0:
output.chisquare_by_dof = output.chisquare / output.dof
@@ -505,7 +527,7 @@ def chisqfunc_compact(d):
return output


def total_least_squares(x, y, func, silent=False, **kwargs):
def total_least_squares(x: list[Obs], y: list[Obs], func: Callable, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
Parameters
@@ -602,8 +624,8 @@ def func(a, x):
if np.any(np.asarray(dy_f) <= 0.0):
raise Exception('No y errors available, run the gamma method first.')

if 'initial_guess' in kwargs:
x0 = kwargs.get('initial_guess')
x0 = kwargs.get('initial_guess')
if x0 is not None:
if len(x0) != n_parms:
raise Exception('Initial guess does not have the correct length: %d vs. %d' % (len(x0), n_parms))
else:
@@ -709,7 +731,7 @@ def odr_chisquare_compact_y(d):
return output


def fit_lin(x, y, **kwargs):
def fit_lin(x: Sequence[Union[Obs, int, float]], y: Sequence[Obs], **kwargs) -> list[Obs]:
"""Performs a linear fit to y = n + m * x and returns two Obs n, m.
Parameters
@@ -740,7 +762,7 @@ def f(a, x):
raise TypeError('Unsupported types for x')


def qqplot(x, o_y, func, p, title=""):
def qqplot(x: ndarray, o_y: list[Obs], func: Callable, p: list[Obs], title: str=""):
"""Generates a quantile-quantile plot of the fit result which can be used to
check if the residuals of the fit are gaussian distributed.
@@ -770,7 +792,7 @@ def qqplot(x, o_y, func, p, title=""):
plt.draw()


def residual_plot(x, y, func, fit_res, title=""):
def residual_plot(x: ndarray, y: list[Obs], func: Callable, fit_res: list[Obs], title: str=""):
"""Generates a plot which compares the fit to the data and displays the corresponding residuals
For uncorrelated data the residuals are expected to be distributed ~N(0,1).
@@ -807,9 +829,20 @@ def residual_plot(x, y, func, fit_res, title=""):
plt.draw()


def error_band(x, func, beta):
def error_band(x: list[int], func: Callable, beta: Union[Fit_result, list[Obs]]) -> ndarray:
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
Parameters
----------
x : list[int]
A list of sample points where the error band is evaluated.
func : Callable
The function representing the fit model.
beta : Union[Fit_result, list[Obs]]
Optimized fit parameters.
Returns
-------
err : np.array(Obs)
@@ -831,7 +864,7 @@ def error_band(x, func, beta):
return err


def ks_test(objects=None):
def ks_test(objects: Optional[list[Fit_result]]=None):
"""Performs a Kolmogorov–Smirnov test for the p-values of all fit object.
Parameters
@@ -875,7 +908,7 @@ def ks_test(objects=None):
print(scipy.stats.kstest(p_values, 'uniform'))


def _extract_val_and_dval(string):
def _extract_val_and_dval(string: str) -> tuple[float, float]:
split_string = string.split('(')
if '.' in split_string[0] and '.' not in split_string[1][:-1]:
factor = 10 ** -len(split_string[0].partition('.')[2])
@@ -884,11 +917,13 @@ def _extract_val_and_dval(string):
return float(split_string[0]), float(split_string[1][:-1]) * factor


def _construct_prior_obs(i_prior, i_n):
def _construct_prior_obs(i_prior: Union[Obs, str], i_n: int) -> Obs:
if isinstance(i_prior, Obs):
return i_prior
elif isinstance(i_prior, str):
loc_val, loc_dval = _extract_val_and_dval(i_prior)
return cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")
prior_obs = cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")
assert isinstance(prior_obs, Obs)
return prior_obs
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
82 changes: 48 additions & 34 deletions pyerrors/input/dobs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from collections import defaultdict
import gzip
import lxml.etree as et
@@ -11,12 +12,15 @@
from ..obs import _merge_idx
from ..covobs import Covobs
from .. import version as pyerrorsversion
from lxml.etree import _Element
from numpy import ndarray
from typing import Any, Optional, Union


# Based on https://stackoverflow.com/a/10076823
def _etree_to_dict(t):
def _etree_to_dict(t: _Element) -> dict:
""" Convert the content of an XML file to a python dict"""
d = {t.tag: {} if t.attrib else None}
d: dict = {t.tag: {} if t.attrib else None}
children = list(t)
if children:
dd = defaultdict(list)
@@ -38,7 +42,7 @@ def _etree_to_dict(t):
return d


def _dict_to_xmlstring(d):
def _dict_to_xmlstring(d: dict[str, Any]) -> str:
if isinstance(d, dict):
iters = ''
for k in d:
@@ -66,7 +70,7 @@ def _dict_to_xmlstring(d):
return iters


def _dict_to_xmlstring_spaces(d, space=' '):
def _dict_to_xmlstring_spaces(d: dict, space: str=' ') -> str:
s = _dict_to_xmlstring(d)
o = ''
c = 0
@@ -85,7 +89,7 @@ def _dict_to_xmlstring_spaces(d, space=' '):
return o


def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None) -> str:
"""Export a list of Obs or structures containing Obs to an xml string
according to the Zeuthen pobs format.
@@ -113,7 +117,9 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
XML formatted string of the input data
"""

od = {}
if symbol is None:
symbol = []
od: dict[str, Any] = {}
ename = obsl[0].e_names[0]
names = list(obsl[0].deltas.keys())
nr = len(names)
@@ -176,7 +182,7 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
return rs


def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz=True):
def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen pobs format.
@@ -206,6 +212,8 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz
-------
None
"""
if symbol is None:
symbol = []
pobsstring = create_pobs_string(obsl, name, spec, origin, symbol, enstag)

if not fname.endswith('.xml') and not fname.endswith('.gz'):
@@ -215,38 +223,39 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz
if not fname.endswith('.gz'):
fname += '.gz'

fp = gzip.open(fname, 'wb')
fp.write(pobsstring.encode('utf-8'))
gp = gzip.open(fname, 'wb')
gp.write(pobsstring.encode('utf-8'))
gp.close()
else:
fp = open(fname, 'w', encoding='utf-8')
fp.write(pobsstring)
fp.close()
fp.close()


def _import_data(string):
def _import_data(string: str) -> list[Union[int, float]]:
return json.loads("[" + ",".join(string.replace(' +', ' ').split()) + "]")


def _check(condition):
def _check(condition: bool):
if not condition:
raise Exception("XML file format not supported")


class _NoTagInDataError(Exception):
"""Raised when tag is not in data"""
def __init__(self, tag):
def __init__(self, tag: str):
self.tag = tag
super().__init__('Tag %s not in data!' % (self.tag))


def _find_tag(dat, tag):
def _find_tag(dat: _Element, tag: str) -> int:
for i in range(len(dat)):
if dat[i].tag == tag:
return i
raise _NoTagInDataError(tag)


def _import_array(arr):
def _import_array(arr: _Element) -> Union[list[Union[str, list[int], list[ndarray]]], ndarray]:
name = arr[_find_tag(arr, 'id')].text.strip()
index = _find_tag(arr, 'layout')
try:
@@ -284,20 +293,20 @@ def _import_array(arr):
_check(False)


def _import_rdata(rd):
def _import_rdata(rd: _Element) -> tuple[list[ndarray], str, list[int]]:
name, idx, mask, deltas = _import_array(rd)
return deltas, name, idx


def _import_cdata(cd):
def _import_cdata(cd: _Element) -> tuple[str, ndarray, ndarray]:
_check(cd[0].tag == "id")
_check(cd[1][0].text.strip() == "cov")
cov = _import_array(cd[1])
grad = _import_array(cd[2])
return cd[0].text.strip(), cov, grad


def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[dict, list[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen pobs format.
Tags are not written or recovered automatically.
@@ -309,7 +318,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
full_output : bool
If True, a dict containing auxiliary information and the data is returned.
If False, only the data is returned as list.
separatior_insertion: str or int
separator_insertion: str or int
str: replace all occurences of "separator_insertion" within the replica names
by "|%s" % (separator_insertion) when constructing the names of the replica.
int: Insert the separator "|" at the position given by separator_insertion.
@@ -329,8 +338,8 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
if gz:
if not fname.endswith('.gz'):
fname += '.gz'
with gzip.open(fname, 'r') as fin:
content = fin.read()
with gzip.open(fname, 'r') as gin:
content = gin.read()
else:
if fname.endswith('.gz'):
warnings.warn("Trying to read from %s without unzipping!" % fname, UserWarning)
@@ -350,7 +359,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):

deltas = []
names = []
idl = []
idl: list[list[int]] = []
for i in range(5, len(pobs)):
delta, name, idx = _import_rdata(pobs[i])
deltas.append(delta)
@@ -397,7 +406,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):


# this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py
def import_dobs_string(content, full_output=False, separator_insertion=True):
def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[dict, list[Obs]]:
"""Import a list of Obs from a string in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@@ -409,7 +418,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True):
full_output : bool
If True, a dict containing auxiliary information and the data is returned.
If False, only the data is returned as list.
separatior_insertion: str, int or bool
separator_insertion: str, int or bool
str: replace all occurences of "separator_insertion" within the replica names
by "|%s" % (separator_insertion) when constructing the names of the replica.
int: Insert the separator "|" at the position given by separator_insertion.
@@ -572,7 +581,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True):
return res


def read_dobs(fname, full_output=False, gz=True, separator_insertion=True):
def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[dict, list[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@@ -619,7 +628,7 @@ def read_dobs(fname, full_output=False, gz=True, separator_insertion=True):
return import_dobs_string(content, full_output, separator_insertion=separator_insertion)


def _dobsdict_to_xmlstring(d):
def _dobsdict_to_xmlstring(d: dict[str, Any]) -> str:
if isinstance(d, dict):
iters = ''
for k in d:
@@ -659,7 +668,7 @@ def _dobsdict_to_xmlstring(d):
return iters


def _dobsdict_to_xmlstring_spaces(d, space=' '):
def _dobsdict_to_xmlstring_spaces(d: dict, space: str=' ') -> str:
s = _dobsdict_to_xmlstring(d)
o = ''
c = 0
@@ -678,7 +687,7 @@ def _dobsdict_to_xmlstring_spaces(d, space=' '):
return o


def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None):
def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None) -> str:
"""Generate the string for the export of a list of Obs or structures containing Obs
to a .xml.gz file according to the Zeuthen dobs format.
@@ -709,9 +718,11 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
xml_str : str
XML string generated from the data
"""
if symbol is None:
symbol = []
if enstags is None:
enstags = {}
od = {}
od: dict[str, Any] = {}
r_names = []
for o in obsl:
r_names += [name for name in o.names if name.split('|')[0] in o.mc_names]
@@ -811,7 +822,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
ed[''].append(ad)
pd['edata'].append(ed)

allcov = {}
allcov: dict[str, ndarray] = {}
for o in obsl:
for cname in o.cov_names:
if cname in allcov:
@@ -867,7 +878,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
return rs


def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None, gz=True):
def write_dobs(obsl: list[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen dobs format.
@@ -901,6 +912,8 @@ def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=No
-------
None
"""
if symbol is None:
symbol = []
if enstags is None:
enstags = {}

@@ -913,9 +926,10 @@ def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=No
if not fname.endswith('.gz'):
fname += '.gz'

fp = gzip.open(fname, 'wb')
fp.write(dobsstring.encode('utf-8'))
gp = gzip.open(fname, 'wb')
gp.write(dobsstring.encode('utf-8'))
gp.close()
else:
fp = open(fname, 'w', encoding='utf-8')
fp.write(dobsstring)
fp.close()
fp.close()
29 changes: 16 additions & 13 deletions pyerrors/input/json.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import rapidjson as json
import gzip
import getpass
@@ -12,9 +13,10 @@
from ..correlators import Corr
from ..misc import _assert_equal_properties
from .. import version as pyerrorsversion
from typing import Any, Union


def create_json_string(ol, description='', indent=1):
def create_json_string(ol: list, description: Union[str, dict]='', indent: int=1) -> str:
"""Generate the string for the export of a list of Obs or structures containing Obs
to a .json(.gz) file
@@ -166,7 +168,7 @@ def write_Corr_to_dict(my_corr):
if not isinstance(ol, list):
ol = [ol]

d = {}
d: dict[str, Any] = {}
d['program'] = 'pyerrors %s' % (pyerrorsversion.__version__)
d['version'] = '1.1'
d['who'] = getpass.getuser()
@@ -217,7 +219,7 @@ def _jsonifier(obj):
return json.dumps(d, indent=indent, ensure_ascii=False, default=_jsonifier, write_mode=json.WM_COMPACT)


def dump_to_json(ol, fname, description='', indent=1, gz=True):
def dump_to_json(ol: list, fname: str, description: Union[str, dict]='', indent: int=1, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .json(.gz) file.
Dict keys that are not JSON-serializable such as floats are converted to strings.
@@ -251,15 +253,16 @@ def dump_to_json(ol, fname, description='', indent=1, gz=True):
if not fname.endswith('.gz'):
fname += '.gz'

fp = gzip.open(fname, 'wb')
fp.write(jsonstring.encode('utf-8'))
gp = gzip.open(fname, 'wb')
gp.write(jsonstring.encode('utf-8'))
gp.close()
else:
fp = open(fname, 'w', encoding='utf-8')
fp.write(jsonstring)
fp.close()
fp.close()


def _parse_json_dict(json_dict, verbose=True, full_output=False):
def _parse_json_dict(json_dict: dict[str, Any], verbose: bool=True, full_output: bool=False) -> Any:
"""Reconstruct a list of Obs or structures containing Obs from a dict that
was built out of a json string.
@@ -474,7 +477,7 @@ def get_Corr_from_dict(o):
return ol


def import_json_string(json_string, verbose=True, full_output=False):
def import_json_string(json_string: str, verbose: bool=True, full_output: bool=False) -> Union[Obs, list[Obs], Corr]:
"""Reconstruct a list of Obs or structures containing Obs from a json string.
The following structures are supported: Obs, list, numpy.ndarray, Corr
@@ -504,7 +507,7 @@ def import_json_string(json_string, verbose=True, full_output=False):
return _parse_json_dict(json.loads(json_string), verbose, full_output)


def load_json(fname, verbose=True, gz=True, full_output=False):
def load_json(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False) -> Any:
"""Import a list of Obs or structures containing Obs from a .json(.gz) file.
The following structures are supported: Obs, list, numpy.ndarray, Corr
@@ -549,7 +552,7 @@ def load_json(fname, verbose=True, gz=True, full_output=False):
return _parse_json_dict(d, verbose, full_output)


def _ol_from_dict(ind, reps='DICTOBS'):
def _ol_from_dict(ind: dict, reps: str='DICTOBS') -> tuple[list, dict]:
"""Convert a dictionary of Obs objects to a list and a dictionary that contains
placeholders instead of the Obs objects.
@@ -626,7 +629,7 @@ def obslist_replace_obs(li):
return ol, nd


def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=True):
def dump_dict_to_json(od: dict, fname: str, description: Union[str, dict]='', indent: int=1, reps: str='DICTOBS', gz: bool=True):
"""Export a dict of Obs or structures containing Obs to a .json(.gz) file
Parameters
@@ -666,7 +669,7 @@ def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=Tr
dump_to_json(ol, fname, description=desc_dict, indent=indent, gz=gz)


def _od_from_list_and_dict(ol, ind, reps='DICTOBS'):
def _od_from_list_and_dict(ol: list, ind: dict, reps: str='DICTOBS') -> dict[str, dict[str, Any]]:
"""Parse a list of Obs or structures containing Obs and an accompanying
dict, where the structures have been replaced by placeholders to a
dict that contains the structures.
@@ -727,7 +730,7 @@ def list_replace_string(li):
return nd


def load_json_dict(fname, verbose=True, gz=True, full_output=False, reps='DICTOBS'):
def load_json_dict(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False, reps: str='DICTOBS') -> dict:
"""Import a dict of Obs or structures containing Obs from a .json(.gz) file.
The following structures are supported: Obs, list, numpy.ndarray, Corr
10 changes: 8 additions & 2 deletions pyerrors/input/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import os
import fnmatch
import re
@@ -8,9 +9,10 @@
from matplotlib import gridspec
from ..obs import Obs
from ..fits import fit_lin
from typing import Optional


def fit_t0(t2E_dict, fit_range, plot_fit=False, observable='t0'):
def fit_t0(t2E_dict: dict[float, Obs], fit_range: int, plot_fit: Optional[bool]=False, observable: str='t0') -> Obs:
"""Compute the root of (flow-based) data based on a dictionary that contains
the necessary information in key-value pairs a la (flow time: observable at flow time).
@@ -97,11 +99,15 @@ def fit_t0(t2E_dict, fit_range, plot_fit=False, observable='t0'):
return -fit_result[0] / fit_result[1]


def read_pbp(path, prefix, **kwargs):
def read_pbp(path: str, prefix: str, **kwargs):
"""Read pbp format from given folder structure.
Parameters
----------
path : str
Directory to read pbp from
prefix : str
Prefix of the files to be read
r_start : list
list which contains the first config to be read for each replicum
r_stop : list
216 changes: 95 additions & 121 deletions pyerrors/input/openQCD.py

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions pyerrors/input/pandas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import warnings
import gzip
import sqlite3
@@ -7,9 +8,11 @@
from ..correlators import Corr
from .json import create_json_string, import_json_string
import numpy as np
from pandas.core.frame import DataFrame
from pandas.core.series import Series


def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs):
def to_sql(df: DataFrame, table_name: str, db: str, if_exists: str='fail', gz: bool=True, **kwargs):
"""Write DataFrame including Obs or Corr valued columns to sqlite database.
Parameters
@@ -34,7 +37,7 @@ def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs):
se_df.to_sql(table_name, con=con, if_exists=if_exists, index=False, **kwargs)


def read_sql(sql, db, auto_gamma=False, **kwargs):
def read_sql(sql: str, db: str, auto_gamma: bool=False, **kwargs) -> DataFrame:
"""Execute SQL query on sqlite database and obtain DataFrame including Obs or Corr valued columns.
Parameters
@@ -57,7 +60,7 @@ def read_sql(sql, db, auto_gamma=False, **kwargs):
return _deserialize_df(extract_df, auto_gamma=auto_gamma)


def dump_df(df, fname, gz=True):
def dump_df(df: DataFrame, fname: str, gz: bool=True):
"""Exports a pandas DataFrame containing Obs valued columns to a (gzipped) csv file.
Before making use of pandas to_csv functionality Obs objects are serialized via the standardized
@@ -96,7 +99,7 @@ def dump_df(df, fname, gz=True):
out.to_csv(fname, index=False)


def load_df(fname, auto_gamma=False, gz=True):
def load_df(fname: str, auto_gamma: bool=False, gz: bool=True) -> DataFrame:
"""Imports a pandas DataFrame from a csv.(gz) file in which Obs objects are serialized as json strings.
Parameters
@@ -130,7 +133,7 @@ def load_df(fname, auto_gamma=False, gz=True):
return _deserialize_df(re_import, auto_gamma=auto_gamma)


def _serialize_df(df, gz=False):
def _serialize_df(df: DataFrame, gz: bool=False) -> DataFrame:
"""Serializes all Obs or Corr valued columns into json strings according to the pyerrors json specification.
Parameters
@@ -151,7 +154,7 @@ def _serialize_df(df, gz=False):
return out


def _deserialize_df(df, auto_gamma=False):
def _deserialize_df(df: DataFrame, auto_gamma: bool=False) -> DataFrame:
"""Deserializes all pyerrors json strings into Obs or Corr objects according to the pyerrors json specification.
Parameters
@@ -187,7 +190,7 @@ def _deserialize_df(df, auto_gamma=False):
return df


def _need_to_serialize(col):
def _need_to_serialize(col: Series) -> bool:
serialize = False
i = 0
while i < len(col) and col[i] is None:
127 changes: 66 additions & 61 deletions pyerrors/input/sfcf.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyerrors/input/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Utilities for the input"""

from __future__ import annotations
import re
import fnmatch
import os


def sort_names(ll):
def sort_names(ll: list[str]) -> list[str]:
"""Sorts a list of names of replika with searches for `r` and `id` in the replikum string.
If this search fails, a fallback method is used,
where the strings are simply compared and the first diffeing numeral is used for differentiation.
@@ -52,7 +53,7 @@ def sort_names(ll):
return ll


def check_idl(idl, che):
def check_idl(idl: list, che: list) -> str:
"""Checks if list of configurations is contained in an idl
Parameters
@@ -82,7 +83,7 @@ def check_idl(idl, che):
return miss_str


def check_params(path, param_hash, prefix, param_prefix="parameters_"):
def check_params(path: str, param_hash: str, prefix: str, param_prefix: str ="parameters_") -> dict[str, str]:
"""
Check if, for sfcf, the parameter hashes at the end of the parameter files are in fact the expected one.
5 changes: 4 additions & 1 deletion pyerrors/integrate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations
import numpy as np
from .obs import derived_observable, Obs
from autograd import jacobian
from scipy.integrate import quad as squad
from numpy import ndarray
from typing import Callable, Union


def quad(func, p, a, b, **kwargs):
def quad(func: Callable, p: Union[list[Union[float, Obs]], list[float], ndarray], a: Union[Obs, float, int], b: Union[Obs, float, int], **kwargs) -> Union[tuple[Obs, float], tuple[float, float], tuple[Obs, float, dict[str, Union[int, ndarray]]]]:
'''Performs a (one-dimensional) numeric integration of f(p, x) from a to b.
The integration is performed using scipy.integrate.quad().
30 changes: 17 additions & 13 deletions pyerrors/linalg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
from .obs import derived_observable, CObs, Obs, import_jackknife
from numpy import ndarray
from typing import Callable, Union


def matmul(*operands):
def matmul(*operands) -> ndarray:
"""Matrix multiply all operands.
Parameters
@@ -45,6 +48,7 @@ def multi_dot_i(operands):
Nr = derived_observable(multi_dot_r, extended_operands, array_mode=True)
Ni = derived_observable(multi_dot_i, extended_operands, array_mode=True)

assert isinstance(Nr, ndarray) and isinstance(Ni, ndarray)
res = np.empty_like(Nr)
for (n, m), entry in np.ndenumerate(Nr):
res[n, m] = CObs(Nr[n, m], Ni[n, m])
@@ -59,7 +63,7 @@ def multi_dot(operands):
return derived_observable(multi_dot, operands, array_mode=True)


def jack_matmul(*operands):
def jack_matmul(*operands) -> ndarray:
"""Matrix multiply both operands making use of the jackknife approximation.
Parameters
@@ -120,7 +124,7 @@ def _imp_from_jack_c(matrix, name, idl):
return _imp_from_jack(r, name, idl)


def einsum(subscripts, *operands):
def einsum(subscripts: str, *operands) -> Union[CObs, Obs, ndarray]:
"""Wrapper for numpy.einsum
Parameters
@@ -194,24 +198,24 @@ def _imp_from_jack_c(matrix, name, idl):
return result


def inv(x):
def inv(x: ndarray) -> ndarray:
"""Inverse of Obs or CObs valued matrices."""
return _mat_mat_op(anp.linalg.inv, x)


def cholesky(x):
def cholesky(x: ndarray) -> ndarray:
"""Cholesky decomposition of Obs valued matrices."""
if any(isinstance(o, CObs) for o in x.ravel()):
raise Exception("Cholesky decomposition is not implemented for CObs.")
return _mat_mat_op(anp.linalg.cholesky, x)


def det(x):
def det(x: Union[ndarray, int]) -> Obs:
"""Determinant of Obs valued matrices."""
return _scalar_mat_op(anp.linalg.det, x)


def _scalar_mat_op(op, obs, **kwargs):
def _scalar_mat_op(op: Callable, obs: Union[ndarray, int], **kwargs) -> Obs:
"""Computes the matrix to scalar operation op to a given matrix of Obs."""
def _mat(x, **kwargs):
dim = int(np.sqrt(len(x)))
@@ -232,7 +236,7 @@ def _mat(x, **kwargs):
return derived_observable(_mat, raveled_obs, **kwargs)


def _mat_mat_op(op, obs, **kwargs):
def _mat_mat_op(op: Callable, obs: ndarray, **kwargs) -> ndarray:
"""Computes the matrix to matrix operation op to a given matrix of Obs."""
# Use real representation to calculate matrix operations for complex matrices
if any(isinstance(o, CObs) for o in obs.ravel()):
@@ -258,31 +262,31 @@ def _mat_mat_op(op, obs, **kwargs):
return derived_observable(lambda x, **kwargs: op(x), [obs], array_mode=True)[0]


def eigh(obs, **kwargs):
def eigh(obs: ndarray, **kwargs) -> tuple[ndarray, ndarray]:
"""Computes the eigenvalues and eigenvectors of a given hermitian matrix of Obs according to np.linalg.eigh."""
w = derived_observable(lambda x, **kwargs: anp.linalg.eigh(x)[0], obs)
v = derived_observable(lambda x, **kwargs: anp.linalg.eigh(x)[1], obs)
return w, v


def eig(obs, **kwargs):
def eig(obs: ndarray, **kwargs) -> ndarray:
"""Computes the eigenvalues of a given matrix of Obs according to np.linalg.eig."""
w = derived_observable(lambda x, **kwargs: anp.real(anp.linalg.eig(x)[0]), obs)
return w


def eigv(obs, **kwargs):
def eigv(obs: ndarray, **kwargs) -> ndarray:
"""Computes the eigenvectors of a given hermitian matrix of Obs according to np.linalg.eigh."""
v = derived_observable(lambda x, **kwargs: anp.linalg.eigh(x)[1], obs)
return v


def pinv(obs, **kwargs):
def pinv(obs: ndarray, **kwargs) -> ndarray:
"""Computes the Moore-Penrose pseudoinverse of a matrix of Obs."""
return derived_observable(lambda x, **kwargs: anp.linalg.pinv(x), obs)


def svd(obs, **kwargs):
def svd(obs: ndarray, **kwargs) -> tuple[ndarray, ndarray, ndarray]:
"""Computes the singular value decomposition of a matrix of Obs."""
u = derived_observable(lambda x, **kwargs: anp.linalg.svd(x, full_matrices=False)[0], obs)
s = derived_observable(lambda x, **kwargs: anp.linalg.svd(x, full_matrices=False)[1], obs)
31 changes: 20 additions & 11 deletions pyerrors/misc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations
import platform
import numpy as np
import scipy
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import pickle
from .obs import Obs
from .obs import Obs, CObs
from .version import __version__
from numpy import ndarray
from typing import Union, TYPE_CHECKING

if TYPE_CHECKING:
from .correlators import Corr


def print_config():
@@ -23,7 +29,7 @@ def print_config():
print(f"{key: <10}\t {value}")


def errorbar(x, y, axes=plt, **kwargs):
def errorbar(x: Union[ndarray[int, float, Obs], list[int, float, Obs]], y: Union[ndarray[int, float, Obs], list[int, float, Obs]], axes=plt, **kwargs):
"""pyerrors wrapper for the errorbars method of matplotlib
Parameters
@@ -54,7 +60,7 @@ def errorbar(x, y, axes=plt, **kwargs):
axes.errorbar(val["x"], val["y"], xerr=err["x"], yerr=err["y"], **kwargs)


def dump_object(obj, name, **kwargs):
def dump_object(obj: Corr, name: str, **kwargs):
"""Dump object into pickle file.
Parameters
@@ -70,15 +76,18 @@ def dump_object(obj, name, **kwargs):
-------
None
"""
if 'path' in kwargs:
file_name = kwargs.get('path') + '/' + name + '.p'
path = kwargs.get('path')
if path is not None:
if not isinstance(path, str):
raise Exception("Path has to be a string.")
file_name = path + '/' + name + '.p'
else:
file_name = name + '.p'
with open(file_name, 'wb') as fb:
pickle.dump(obj, fb)


def load_object(path):
def load_object(path: str) -> Union[Obs, Corr]:
"""Load object from pickle file.
Parameters
@@ -95,7 +104,7 @@ def load_object(path):
return pickle.load(file)


def pseudo_Obs(value, dvalue, name, samples=1000):
def pseudo_Obs(value: Union[float, int], dvalue: Union[float, int], name: str, samples: int=1000) -> Obs:
"""Generate an Obs object with given value, dvalue and name for test purposes
Parameters
@@ -118,11 +127,11 @@ def pseudo_Obs(value, dvalue, name, samples=1000):
return Obs([np.zeros(samples) + value], [name])
else:
for _ in range(100):
deltas = [np.random.normal(0.0, dvalue * np.sqrt(samples), samples)]
deltas = np.array([np.random.normal(0.0, dvalue * np.sqrt(samples), samples)])
deltas -= np.mean(deltas)
deltas *= dvalue / np.sqrt((np.var(deltas) / samples)) / np.sqrt(1 + 3 / samples)
deltas += value
res = Obs(deltas, [name])
res = Obs(list(deltas), [name])
res.gamma_method(S=2, tau_exp=0)
if abs(res.dvalue - dvalue) < 1e-10 * dvalue:
break
@@ -132,7 +141,7 @@ def pseudo_Obs(value, dvalue, name, samples=1000):
return res


def gen_correlated_data(means, cov, name, tau=0.5, samples=1000):
def gen_correlated_data(means: Union[ndarray, list[float]], cov: ndarray, name: str, tau: Union[float, ndarray]=0.5, samples: int=1000) -> list[Obs]:
""" Generate observables with given covariance and autocorrelation times.
Parameters
@@ -174,7 +183,7 @@ def gen_correlated_data(means, cov, name, tau=0.5, samples=1000):
return [Obs([dat], [name]) for dat in corr_data.T]


def _assert_equal_properties(ol, otype=Obs):
def _assert_equal_properties(ol: Union[list[Obs], list[CObs], ndarray]):
otype = type(ol[0])
for o in ol[1:]:
if not isinstance(o, otype):
4 changes: 3 additions & 1 deletion pyerrors/mpm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations
import numpy as np
import scipy.linalg
from .obs import Obs
from .linalg import svd, eig
from typing import Optional


def matrix_pencil_method(corrs, k=1, p=None, **kwargs):
def matrix_pencil_method(corrs: list[Obs], k: int=1, p: Optional[int]=None, **kwargs) -> list[Obs]:
"""Matrix pencil method to extract k energy levels from data
Implementation of the matrix pencil method based on
306 changes: 167 additions & 139 deletions pyerrors/obs.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyerrors/roots.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations
import numpy as np
import scipy.optimize
from autograd import jacobian
from .obs import derived_observable
from .obs import Obs, derived_observable
from typing import Callable, Union


def find_root(d, func, guess=1.0, **kwargs):
def find_root(d: Union[Obs, list[Obs]], func: Callable, guess: float=1.0, **kwargs) -> Obs:
r'''Finds the root of the function func(x, d) where d is an `Obs`.
Parameters
1 change: 1 addition & 0 deletions pyerrors/special.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import scipy
import numpy as np
from autograd.extend import primitive, defvjp
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -4,3 +4,7 @@ build-backend = "setuptools.build_meta"

[tool.ruff.lint]
ignore = ["F403"]

[tool.mypy]
warn_unused_configs = true
ignore_missing_imports = true
4 changes: 3 additions & 1 deletion tests/correlators_test.py
Original file line number Diff line number Diff line change
@@ -181,6 +181,8 @@ def f(a, x):
with pytest.raises(ValueError):
my_corr.fit(f, [0, 2, 3])

fit_res = my_corr.fit(f, fitrange=[0, 1])


def test_plateau():
my_corr = pe.correlators.Corr([pe.pseudo_Obs(1.01324, 0.05, 't'), pe.pseudo_Obs(1.042345, 0.008, 't')])
@@ -226,7 +228,7 @@ def test_utility():
corr.print()
corr.print([2, 4])
corr.show()
corr.show(comp=corr)
corr.show(comp=corr, x_range=[2, 5.], y_range=[2, 3.], hide_sigma=0.5, references=[.1, .2, .6], title='TEST')

corr.dump('test_dump', datatype="pickle", path='.')
corr.dump('test_dump', datatype="pickle")
1 change: 1 addition & 0 deletions tests/obs_test.py
Original file line number Diff line number Diff line change
@@ -410,6 +410,7 @@ def test_cobs():
obs2 = pe.pseudo_Obs(-0.2, 0.03, 't')

my_cobs = pe.CObs(obs1, obs2)
my_cobs.gm()
assert +my_cobs == my_cobs
assert -my_cobs == 0 - my_cobs
my_cobs == my_cobs