Skip to content

Commit

Permalink
Remove abinit __all__ module star exports (#804)
Browse files Browse the repository at this point in the history
* refactor CommonEosMaker

* ruff auto-fixes including removing __all__ module star exports
  • Loading branch information
janosh authored Apr 10, 2024
1 parent f26b43b commit f25693d
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:
exclude: ^(.github/|tests/test_data/abinit/)
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.3.4
rev: v0.3.5
hooks:
- id: ruff
args: [--fix]
Expand Down
7 changes: 0 additions & 7 deletions src/atomate2/abinit/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@

from atomate2.abinit.sets.base import AbinitInputGenerator

__all__ = [
"out_to_in",
"fname2ext",
"load_abinit_input",
"write_abinit_input_set",
]


logger = logging.getLogger(__name__)

Expand Down
2 changes: 0 additions & 2 deletions src/atomate2/abinit/jobs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

logger = logging.getLogger(__name__)

__all__ = ["StaticMaker", "NonSCFMaker", "RelaxMaker"]


@dataclass
class StaticMaker(BaseAbinitMaker):
Expand Down
5 changes: 2 additions & 3 deletions src/atomate2/abinit/sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

from pymatgen.core.structure import Structure

__all__ = ["AbinitInputSet", "AbinitInputGenerator"]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +98,7 @@ def write_input(
zip_inputs=zip_inputs,
)
del self.inputs["abinit_input.json"]
indir, outdir, tmpdir = self.set_workdir(workdir=directory)
indir, _outdir, _tmpdir = self.set_workdir(workdir=directory)

if self.input_files:
out_to_in(
Expand Down Expand Up @@ -718,7 +717,7 @@ def _get_kpoints(
if kconfig.get("line_density"):
# handle line density generation
kpath = HighSymmKpath(structure, **kconfig.get("kpath_kwargs", {}))
frac_k_points, k_points_labels = kpath.get_kpoints(
frac_k_points, _k_points_labels = kpath.get_kpoints(
line_density=kconfig["line_density"], coords_are_cartesian=False
)
base_kpoints = KSampling(
Expand Down
16 changes: 5 additions & 11 deletions src/atomate2/abinit/sets/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@
from pymatgen.io.abinit import PseudoTable
from pymatgen.io.abinit.abiobjects import KSampling

__all__ = [
"StaticSetGenerator",
"NonSCFSetGenerator",
"RelaxSetGenerator",
]


GS_RESTART_FROM_DEPS: tuple = (f"{SCF}|{RELAX}|{MOLECULAR_DYNAMICS}:WFK|DEN",)
GS_RESTART_FROM_DEPS = (f"{SCF}|{RELAX}|{MOLECULAR_DYNAMICS}:WFK|DEN",)


@dataclass
Expand Down Expand Up @@ -123,13 +117,13 @@ def _get_nband(self, prev_outputs: list[str] | None) -> int:
f"Should have exactly one previous output. Found {len(abinit_inputs)}"
)
previous_abinit_input = next(iter(abinit_inputs.values()))
nband = previous_abinit_input.get(
n_band = previous_abinit_input.get(
"nband",
previous_abinit_input.structure.num_valence_electrons(
previous_abinit_input.pseudos
),
)
return int(np.ceil(nband * self.nbands_factor))
return int(np.ceil(n_band * self.nbands_factor))


@dataclass
Expand All @@ -154,7 +148,7 @@ class NonScfWfqInputGenerator(AbinitInputGenerator):

calc_type: str = "nscf_wfq"

wfq_tol: dict = field(default_factory=lambda: {"tolwfr": 1.0e-18})
wfq_tol: dict = field(default_factory=lambda: {"tolwfr": 1e-18})

restart_from_deps: tuple = (f"{NSCF}:WFQ",)
prev_outputs_deps: tuple = (f"{SCF}:DEN",)
Expand Down Expand Up @@ -201,7 +195,7 @@ class RelaxSetGenerator(AbinitInputGenerator):
factory: Callable = ion_ioncell_relax_input
restart_from_deps: tuple = GS_RESTART_FROM_DEPS
relax_cell: bool = True
tolmxf: float = 5.0e-5
tolmxf: float = 5e-5

def get_abinit_input(
self,
Expand Down
46 changes: 24 additions & 22 deletions src/atomate2/common/flows/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CommonEosMaker(Maker):
initial_relax_maker : .Maker | None
Maker to relax the input structure, defaults to None (no initial relaxation).
eos_relax_maker : .Maker
Maker to relax deformationed structures for the EOS fit.
Maker to relax deformed structures for the EOS fit.
static_maker : .Maker | None
Maker to generate statics after each relaxation, defaults to None.
strain : tuple[float]
Expand Down Expand Up @@ -122,19 +122,19 @@ def make(self, structure: Structure, prev_dir: str | Path = None) -> Flow:
if self.initial_relax_maker:
# Cell without applied strain already included from relax/equilibrium steps.
# Perturb this point (or these points) if included
zero_strain_mask = np.abs(strain_l) < 1.0e-15
zero_strain_mask = np.abs(strain_l) < 1e-15
if np.any(zero_strain_mask):
nzs = len(strain_l[zero_strain_mask])
shift = strain_delta / (nzs + 1.0) * np.linspace(-1.0, 1.0, nzs)
strain_l[np.abs(strain_l) < 1.0e-15] += shift
strain_l[np.abs(strain_l) < 1e-15] += shift

deformation_l = [(np.identity(3) * (1.0 + eps)).tolist() for eps in strain_l]

# apply strain to structures, return list of transformations
transformations = apply_strain_to_structure(structure, deformation_l)
jobs["utility"] += [transformations]

for idef in range(self.number_of_frames):
for frame_idx in range(self.number_of_frames):
if self._store_transformation_information:
with contextlib.suppress(Exception):
# write details of the transformation to the
Expand All @@ -144,44 +144,46 @@ def make(self, structure: Structure, prev_dir: str | Path = None) -> Flow:
# is automatically converted to a "." in the filename.
self.eos_relax_maker.write_additional_data[
"transformations:json"
] = transformations.output[idef]
] = transformations.output[frame_idx]

relax_job = self.eos_relax_maker.make(
structure=transformations.output[idef].final_structure,
structure=transformations.output[frame_idx].final_structure,
prev_dir=prev_dir,
)
relax_job.name += f" deformation {idef}"
relax_job.name += f" deformation {frame_idx}"
jobs["relax"].append(relax_job)

if self.static_maker:
static_job = self.static_maker.make(
structure=relax_job.output.structure,
prev_dir=relax_job.output.dir_name,
)
static_job.name += f" {idef}"
static_job.name += f" {frame_idx}"
jobs["static"].append(static_job)

for key in job_types:
for i in range(len(jobs[key])):
flow_output[key]["energy"].append(jobs[key][i].output.output.energy)
flow_output[key]["volume"].append(jobs[key][i].output.structure.volume)
flow_output[key]["stress"].append(jobs[key][i].output.output.stress)
for idx in range(len(jobs[key])):
output = jobs[key][idx].output.output
flow_output[key]["energy"] += [output.energy]
flow_output[key]["volume"] += [output.structure.volume]
flow_output[key]["stress"] += [output.stress]

if self.postprocessor is not None:
if len(jobs["relax"]) < self.postprocessor.min_data_points:
min_points = self.postprocessor.min_data_points
if len(jobs["relax"]) < min_points:
raise ValueError(
"To perform least squares EOS fit with "
f"{self.postprocessor.__class__}, you must specify "
f"self.number_of_frames >= {self.postprocessor.min_data_points}."
f"{type(self.postprocessor).__name__}, you must specify "
f"self.number_of_frames >= {min_points}."
)

postprocess = self.postprocessor.make(flow_output)
postprocess.name = self.name + " postprocessing"
flow_output = postprocess.output
jobs["utility"] += [postprocess]
post_process = self.postprocessor.make(flow_output)
post_process.name = self.name + " postprocessing"
flow_output = post_process.output
jobs["utility"] += [post_process]

joblist = []
job_list = []
for key in jobs:
joblist += jobs[key]
job_list += jobs[key]

return Flow(jobs=joblist, output=flow_output, name=self.name)
return Flow(jobs=job_list, output=flow_output, name=self.name)
2 changes: 1 addition & 1 deletion src/atomate2/qchem/sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def write_input(
for k, v in inputs.items():
if v is not None and (overwrite or not (directory / k).exists()):
with zopen(directory / k, "wt") as f:
f.write(v.__str__())
f.write(str(v))
elif not overwrite and (directory / k).exists():
raise FileExistsError(f"{directory / k} already exists.")

Expand Down
38 changes: 10 additions & 28 deletions src/atomate2/vasp/jobs/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@
if TYPE_CHECKING:
from atomate2.vasp.sets.base import VaspInputGenerator

# No prefix, base atomate 2 parameters

copy_wavecar = lambda: {"additional_vasp_files": ("WAVECAR",)} # noqa: E731


# No prefix, base atomate 2 parameters
@dataclass
class EosRelaxMaker(BaseVaspMaker):
"""
Expand Down Expand Up @@ -85,14 +87,10 @@ class EosRelaxMaker(BaseVaspMaker):

name: str = "EOS GGA relax"
input_set_generator: VaspInputGenerator = field(default_factory=EosSetGenerator)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


# MPLegacy prefix, legacy MP PBE-GGA


@dataclass
class MPLegacyEosRelaxMaker(BaseVaspMaker):
"""
Expand Down Expand Up @@ -126,9 +124,7 @@ class MPLegacyEosRelaxMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPLegacyEosRelaxSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


@dataclass
Expand Down Expand Up @@ -164,14 +160,10 @@ class MPLegacyEosStaticMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPLegacyEosStaticSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


# MPGGA prefix, MP PBE-GGA compatible parameters


@dataclass
class MPGGAEosRelaxMaker(BaseVaspMaker):
"""
Expand Down Expand Up @@ -205,9 +197,7 @@ class MPGGAEosRelaxMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPGGAEosRelaxSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


@dataclass
Expand Down Expand Up @@ -243,14 +233,10 @@ class MPGGAEosStaticMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPGGAEosStaticSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


# MPMetaGGA prefix, MP r2SCAN-meta-GGA compatible


@dataclass
class MPMetaGGAEosPreRelaxMaker(BaseVaspMaker):
"""
Expand Down Expand Up @@ -319,9 +305,7 @@ class MPMetaGGAEosRelaxMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPMetaGGAEosRelaxSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)


@dataclass
Expand Down Expand Up @@ -357,6 +341,4 @@ class MPMetaGGAEosStaticMaker(BaseVaspMaker):
input_set_generator: VaspInputGenerator = field(
default_factory=MPMetaGGAEosStaticSetGenerator
)
copy_vasp_kwargs: dict = field(
default_factory=lambda: {"additional_vasp_files": ("WAVECAR",)}
)
copy_vasp_kwargs: dict = field(default_factory=copy_wavecar)
Loading

0 comments on commit f25693d

Please sign in to comment.