From 326beb9598a34a3f5e28b961950f333df1aa84dc Mon Sep 17 00:00:00 2001 From: Thomas Purcell Date: Sun, 8 Sep 2024 07:23:49 -0700 Subject: [PATCH] Make AimsSpeciesFile a dataclass (#4054) * Modify for proper dataclass remove commented code and __init__ function * refactor and use Self return type for from_... methods --------- Co-authored-by: Janosh Riebesell --- src/pymatgen/io/aims/inputs.py | 76 +++++++++++----------------------- 1 file changed, 24 insertions(+), 52 deletions(-) diff --git a/src/pymatgen/io/aims/inputs.py b/src/pymatgen/io/aims/inputs.py index 8dd8d01c1d2..6a5081d7688 100644 --- a/src/pymatgen/io/aims/inputs.py +++ b/src/pymatgen/io/aims/inputs.py @@ -689,54 +689,27 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: return cls(_parameters=decoded["parameters"]) +@dataclass class AimsSpeciesFile: - """An FHI-aims single species' defaults file.""" + """An FHI-aims single species' defaults file. - def __init__(self, data: str, label: str | None = None) -> None: - """ - Args: - data (str): A string of the complete species defaults file - label (str): A string representing the name of species - """ - self.data = data - self.label = label + Attributes: + data (str): A string of the complete species defaults file + label (str): A string representing the name of species + """ + + data: str = "" + label: str | None = None + + def __post_init__(self) -> None: + """Set default label""" if self.label is None: - for line in data.splitlines(): + for line in self.data.splitlines(): if "species" in line: self.label = line.split()[1] - def __eq__(self, other: object) -> bool: - """True if two species are equal.""" - if not isinstance(other, AimsSpeciesFile): - return NotImplemented - return self.data == other.data - - def __lt__(self, other: object) -> bool: - """True if self is less than other.""" - if not isinstance(other, AimsSpeciesFile): - return NotImplemented - return self.data < other.data - - def __le__(self, other: object) -> bool: - """True if self is less than or equal to other.""" - if not isinstance(other, AimsSpeciesFile): - return NotImplemented - return self.data <= other.data - - def __gt__(self, other: object) -> bool: - """True if self is greater than other.""" - if not isinstance(other, AimsSpeciesFile): - return NotImplemented - return self.data > other.data - - def __ge__(self, other: object) -> bool: - """True if self is greater than or equal to other.""" - if not isinstance(other, AimsSpeciesFile): - return NotImplemented - return self.data >= other.data - @classmethod - def from_file(cls, filename: str, label: str | None = None) -> AimsSpeciesFile: + def from_file(cls, filename: str, label: str | None = None) -> Self: """Initialize from file. Args: @@ -744,15 +717,15 @@ def from_file(cls, filename: str, label: str | None = None) -> AimsSpeciesFile: label (str): A string representing the name of species Returns: - The AimsSpeciesFile instance + AimsSpeciesFile """ with zopen(filename, mode="rt") as file: - return cls(file.read(), label) + return cls(data=file.read(), label=label) @classmethod def from_element_and_basis_name( cls, element: str, basis: str, *, species_dir: str | Path | None = None, label: str | None = None - ) -> AimsSpeciesFile: + ) -> Self: """Initialize from element and basis names. Args: @@ -763,7 +736,7 @@ def from_element_and_basis_name( then equal to element Returns: - an AimsSpeciesFile instance + AimsSpeciesFile """ # check if element is in the Periodic Table (+ Emptium) if element != "Emptium": @@ -795,25 +768,24 @@ def from_element_and_basis_name( f"Can't find the species' defaults file for {element} in {basis} basis set. Paths tried: {paths_to_try}" ) - def __str__(self): + def __str__(self) -> str: """String representation of the species' defaults file""" return re.sub(r"^ *species +\w+", f" species {self.label}", self.data, flags=re.MULTILINE) @property def element(self) -> str: - match = re.search(r"^ *species +(\w+)", self.data, flags=re.MULTILINE) - if match is None: - raise ValueError("Can't find element in species' defaults file") - return match.group(1) + if match := re.search(r"^ *species +(\w+)", self.data, flags=re.MULTILINE): + return match[1] + raise ValueError("Can't find element in species' defaults file") def as_dict(self) -> dict[str, Any]: """Dictionary representation of the species' defaults file.""" return {"label": self.label, "data": self.data, "@module": type(self).__module__, "@class": type(self).__name__} @classmethod - def from_dict(cls, dct: dict[str, Any]) -> AimsSpeciesFile: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Deserialization of the AimsSpeciesFile object""" - return AimsSpeciesFile(data=dct["data"], label=dct["label"]) + return cls(**dct) class SpeciesDefaults(list, MSONable):