diff --git a/ipsuite/calculators/ase_geoopt.py b/ipsuite/calculators/ase_geoopt.py index 06b8469d..6b677b7a 100644 --- a/ipsuite/calculators/ase_geoopt.py +++ b/ipsuite/calculators/ase_geoopt.py @@ -8,7 +8,6 @@ import h5py import znh5md import zntrack -from ase.io.trajectory import TrajectoryWriter from ipsuite import base from ipsuite.utils.ase_sim import freeze_copy_atoms @@ -23,6 +22,8 @@ class ASEGeoOpt(base.ProcessSingleAtom): ---------- model: zntrack.Node A node that implements 'get_calculator'. + maxstep: int, optional + Maximum number of steps to perform. """ model = zntrack.deps() @@ -36,6 +37,8 @@ class ASEGeoOpt(base.ProcessSingleAtom): init_kwargs: dict = zntrack.params({}) dump_rate = zntrack.params(1000) + maxstep: int = zntrack.params(None) + traj_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "trajectory.h5") def run(self): @@ -63,7 +66,7 @@ def run(self): optimizer = getattr(ase.optimize, self.optimizer) dyn = optimizer(atoms, **self.init_kwargs) - for _ in dyn.irun(**self.run_kwargs): + for step, _ in enumerate(dyn.irun(**self.run_kwargs)): stop = [] atoms_cache.append(freeze_copy_atoms(atoms)) if len(atoms_cache) == self.dump_rate: @@ -89,6 +92,9 @@ def run(self): dyn.log() break + if self.maxstep is not None and step >= self.maxstep: + break + db.add( znh5md.io.AtomsReader( atoms_cache, @@ -114,51 +120,3 @@ def file_handle(filename): znh5md.FormatHandler, file_handle=file_handle ), ).get_atoms_list() - - -class BatchASEGeoOpt(base.ProcessAtoms): - """Class to run a geometry optimization with ASE. - - Parameters - ---------- - model: zntrack.Node - A node that implements 'get_calculator'. - """ - - model = zntrack.deps() - model_outs = zntrack.outs_path(zntrack.nwd / "model_outs") - optimizer: str = zntrack.params("FIRE") - traj: pathlib.Path = zntrack.outs_path(zntrack.nwd / "optim.traj") - optimized_structures: pathlib.Path = zntrack.outs_path(zntrack.nwd / "final.traj") - - repeat: list = zntrack.params([1, 1, 1]) - run_kwargs: dict = zntrack.params({"fmax": 0.05}) - init_kwargs: dict = zntrack.params({}) - - def run(self): - self.model_outs.mkdir(parents=True, exist_ok=True) - (self.model_outs / "outs.txt").write_text("Lorem Ipsum") - calculator = self.model.get_calculator(directory=self.model_outs) - atoms_list = self.get_data() - - opt_structures = TrajectoryWriter(self.optimized_structures, mode="a") - - for atoms in atoms_list: - atoms = atoms.repeat(self.repeat) - if self.optimizer is not None: - atoms.calc = calculator - optimizer = getattr(ase.optimize, self.optimizer) - - dyn = optimizer( - atoms, trajectory=self.traj.as_posix(), **self.init_kwargs - ) - dyn.run(**self.run_kwargs) - opt_structures.write(atoms) - - @property - def atoms(self): - return list(ase.io.iread(self.optimized_structures.as_posix())) - - @property - def trajectories(self): - return list(ase.io.iread(self.traj.as_posix())) diff --git a/tests/unit_tests/calculators/test_u_ase_geoopt.py b/tests/unit_tests/calculators/test_u_ase_geoopt.py index b4144865..d42f8412 100644 --- a/tests/unit_tests/calculators/test_u_ase_geoopt.py +++ b/tests/unit_tests/calculators/test_u_ase_geoopt.py @@ -49,9 +49,20 @@ def test_ase_geoopt(proj_path, cu_box): run_kwargs={"fmax": 0.05}, ) + opt_max_step = ips.calculators.ASEGeoOpt( + data=data.atoms, + model=model, + optimizer="FIRE", + checker_list=[check], + run_kwargs={"fmax": 0.05}, + maxstep=2, + name="opt_max_step", + ) + project.run(eager=True) assert len(opt.atoms) == n_iterations + 1 + assert len(opt_max_step.atoms) == 3 forces = np.linalg.norm(opt.atoms[0].get_forces(), 2, 1) fmax_start = np.max(forces)