diff --git a/custodian/custodian.py b/custodian/custodian.py index ff7f6f96..c064484a 100644 --- a/custodian/custodian.py +++ b/custodian/custodian.py @@ -22,7 +22,7 @@ from monty.shutil import gzip_dir from monty.tempfile import ScratchDir -from .utils import get_execution_host_info +from .utils import get_execution_host_info, tracked_lru_cache __author__ = "Shyue Ping Ong, William Davidson Richards" __copyright__ = "Copyright 2012, The Materials Project" @@ -683,6 +683,8 @@ def _do_check(self, handlers, terminate_func=None): self.run_log[-1]["corrections"].extend(corrections) # We do a dump of the run log after each check. dumpfn(self.run_log, Custodian.LOG_FILE, cls=MontyEncoder, indent=4) + # Clear all the cached values to avoid reusing them in a subsequent check + tracked_lru_cache.tracked_cache_clear() return len(corrections) > 0 diff --git a/custodian/tests/test_utils.py b/custodian/tests/test_utils.py new file mode 100644 index 00000000..e7fc4de8 --- /dev/null +++ b/custodian/tests/test_utils.py @@ -0,0 +1,38 @@ +import unittest + +from custodian.utils import tracked_lru_cache + + +class TrackedLruCacheTest(unittest.TestCase): + def setUp(self): + # clear cache before and after each test to avoid + # unexpected caching from other tests + tracked_lru_cache.tracked_cache_clear() + + def test_cache_and_clear(self): + n_calls = 0 + + @tracked_lru_cache + def some_func(x): + nonlocal n_calls + n_calls += 1 + return x + + assert some_func(1) == 1 + assert n_calls == 1 + assert some_func(2) == 2 + assert n_calls == 2 + assert some_func(1) == 1 + assert n_calls == 2 + + assert len(tracked_lru_cache.cached_functions) == 1 + + tracked_lru_cache.tracked_cache_clear() + + assert len(tracked_lru_cache.cached_functions) == 0 + + assert some_func(1) == 1 + assert n_calls == 3 + + def tearDown(self): + tracked_lru_cache.tracked_cache_clear() diff --git a/custodian/utils.py b/custodian/utils.py index ecebced1..785c8bc9 100644 --- a/custodian/utils.py +++ 
b/custodian/utils.py @@ -1,5 +1,6 @@ """Utility function and classes.""" +import functools import logging import os import tarfile @@ -44,3 +45,47 @@ def get_execution_host_info(): except Exception: pass return host or "unknown", cluster or "unknown" + + +class tracked_lru_cache: + """ + Decorator wrapping the functools.lru_cache adding a tracking of the + functions that have been wrapped. + + Exposes a method to clear the cache of all the wrapped functions. + + Used to cache the parsed outputs in handlers/validators, to avoid + multiple parsing of the same file. + Allows Custodian to clear the cache after all the checks have been performed. + """ + + cached_functions: set = set() + + def __init__(self, func): + """ + Args: + func: function to be decorated + """ + self.func = functools.lru_cache(func) + functools.update_wrapper(self, func) + + # expose standard lru_cache functions + self.cache_info = self.func.cache_info + self.cache_clear = self.func.cache_clear + + def __call__(self, *args, **kwargs): + """ + Call the decorated function + """ + result = self.func(*args, **kwargs) + self.cached_functions.add(self.func) + return result + + @classmethod + def tracked_cache_clear(cls): + """ + Clear the cache of all the decorated functions. 
+ """ + while cls.cached_functions: + f = cls.cached_functions.pop() + f.cache_clear() diff --git a/custodian/vasp/handlers.py b/custodian/vasp/handlers.py index 014e6dfb..33eff11c 100644 --- a/custodian/vasp/handlers.py +++ b/custodian/vasp/handlers.py @@ -22,7 +22,7 @@ from monty.serialization import loadfn from pymatgen.core.structure import Structure from pymatgen.io.vasp.inputs import Incar, Kpoints, Poscar, VaspInput -from pymatgen.io.vasp.outputs import Oszicar, Outcar, Vasprun +from pymatgen.io.vasp.outputs import Oszicar from pymatgen.io.vasp.sets import MPScanRelaxSet from pymatgen.transformations.standard_transformations import SupercellTransformation @@ -31,6 +31,7 @@ from custodian.custodian import ErrorHandler from custodian.utils import backup from custodian.vasp.interpreter import VaspModder +from custodian.vasp.io import load_outcar, load_vasprun __author__ = ( "Shyue Ping Ong, William Davidson Richards, Anubhav Jain, Wei Chen, " @@ -214,7 +215,7 @@ def correct(self): # error count to 1 to skip first fix if self.error_count["brmix"] == 0: try: - assert Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False + assert load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False except Exception: self.error_count["brmix"] += 1 @@ -510,7 +511,7 @@ def correct(self): # resources, seems to be to just increase NCORE slightly. That's what I do here. 
nprocs = multiprocessing.cpu_count() try: - nelect = Outcar("OUTCAR").nelect + nelect = load_outcar(os.path.join(os.getcwd(), "OUTCAR")).nelect except Exception: nelect = 1 # dummy value if nelect < nprocs: @@ -706,7 +707,7 @@ def correct(self): if ( "lrf_comm" in self.errors - and Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False + and load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False and not vi["INCAR"].get("LPEAD") ): actions.append({"dict": "INCAR", "action": {"_set": {"LPEAD": True}}}) @@ -897,7 +898,7 @@ def check(self): self.max_drift = incar["EDIFFG"] * -1 try: - outcar = Outcar("OUTCAR") + outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR")) except Exception: # Can't perform check if Outcar not valid return False @@ -917,7 +918,7 @@ def correct(self): vi = VaspInput.from_directory(".") incar = vi["INCAR"] - outcar = Outcar("OUTCAR") + outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR")) # Move CONTCAR to POSCAR actions.append({"file": "CONTCAR", "action": {"_file_copy": {"dest": "POSCAR"}}}) @@ -988,7 +989,7 @@ def check(self): return False try: - v = Vasprun(self.output_vasprun) + v = load_vasprun(os.path.join(os.getcwd(), self.output_vasprun)) if v.converged: return False except Exception: @@ -1031,7 +1032,7 @@ def __init__(self, output_filename: str = "vasprun.xml"): def check(self): """Check for error.""" try: - v = Vasprun(self.output_filename) + v = load_vasprun(os.path.join(os.getcwd(), self.output_filename)) if not v.converged: return True except Exception: @@ -1040,7 +1041,7 @@ def check(self): def correct(self): """Perform corrections.""" - v = Vasprun(self.output_filename) + v = load_vasprun(os.path.join(os.getcwd(), self.output_filename)) algo = v.incar.get("ALGO", "Normal").lower() actions = [] if not v.converged_electronic: @@ -1139,7 +1140,7 @@ def __init__(self, output_filename: str = "vasprun.xml"): def check(self): """Check for error.""" try: - v = Vasprun(self.output_filename) + v 
= load_vasprun(os.path.join(os.getcwd(), self.output_filename)) # check whether bandgap is zero, tetrahedron smearing was used # and relaxation is performed. if v.eigenvalue_band_properties[0] == 0 and v.incar.get("ISMEAR", 1) < -3 and v.incar.get("NSW", 0) > 1: @@ -1186,7 +1187,7 @@ def __init__(self, output_filename: str = "vasprun.xml"): def check(self): """Check for error.""" try: - v = Vasprun(self.output_filename) + v = load_vasprun(os.path.join(os.getcwd(), self.output_filename)) # check whether bandgap is zero and KSPACING is too large # using 0 as fallback value for KSPACING so that this handler does not trigger if KSPACING is not set if v.eigenvalue_band_properties[0] == 0 and v.incar.get("KSPACING", 0) > 0.22: @@ -1244,7 +1245,7 @@ def check(self): """Check for error.""" incar = Incar.from_file("INCAR") try: - outcar = Outcar("OUTCAR") + outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR")) except Exception: # Can't perform check if Outcar not valid return False @@ -1601,7 +1602,7 @@ def check(self): if self.wall_time: run_time = datetime.datetime.now() - self.start_time total_secs = run_time.total_seconds() - outcar = Outcar("OUTCAR") + outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR")) if not self.electronic_step_stop: # Determine max time per ionic step. outcar.read_pattern({"timings": r"LOOP\+.+real time(.+)"}, postprocess=float) diff --git a/custodian/vasp/io.py b/custodian/vasp/io.py new file mode 100644 index 00000000..9b954936 --- /dev/null +++ b/custodian/vasp/io.py @@ -0,0 +1,38 @@ +""" +Helper functions for dealing with vasp files. +""" + +from pymatgen.io.vasp.outputs import Outcar, Vasprun + +from custodian.utils import tracked_lru_cache + + +@tracked_lru_cache +def load_vasprun(filepath, **vasprun_kwargs): + """ + Load Vasprun object from file path. + Caches the output for reuse. + + Args: + filepath: path to the vasprun.xml file. + **vasprun_kwargs: kwargs arguments passed to the Vasprun init. 
+ + Returns: + The Vasprun object + """ + return Vasprun(filepath, **vasprun_kwargs) + + +@tracked_lru_cache +def load_outcar(filepath): + """ + Load Outcar object from file path. + Caches the output for reuse. + + Args: + filepath: path to the OUTCAR file. + + Returns: + The Outcar object + """ + return Outcar(filepath) diff --git a/custodian/vasp/tests/conftest.py b/custodian/vasp/tests/conftest.py index f8728305..bf67d0e7 100644 --- a/custodian/vasp/tests/conftest.py +++ b/custodian/vasp/tests/conftest.py @@ -12,3 +12,13 @@ def _patch_get_potential_energy(monkeypatch): Monkeypatch the multiprocessing.cpu_count() function to always return 64 """ monkeypatch.setattr(multiprocessing, "cpu_count", lambda *args, **kwargs: 64) + + +@pytest.fixture(autouse=True) +def _clear_tracked_cache(): + """ + Clear the cache of the stored functions between the tests. + """ + from custodian.utils import tracked_lru_cache + + tracked_lru_cache.tracked_cache_clear() diff --git a/custodian/vasp/tests/test_handlers.py b/custodian/vasp/tests/test_handlers.py index 6dd3a8f5..791d7172 100644 --- a/custodian/vasp/tests/test_handlers.py +++ b/custodian/vasp/tests/test_handlers.py @@ -8,6 +8,7 @@ from pymatgen.io.vasp.inputs import Incar, Kpoints, Structure, VaspInput from pymatgen.util.testing import PymatgenTest +from custodian.utils import tracked_lru_cache from custodian.vasp.handlers import ( AliasingErrorHandler, DriftErrorHandler, @@ -599,6 +600,7 @@ def test_check_correct_electronic(self): "actions": [{"action": {"_set": {"ALGO": "Normal"}}, "dict": "INCAR"}], "errors": ["Unconverged"], } + tracked_lru_cache.tracked_cache_clear() shutil.copy("vasprun.xml.electronic_veryfast", "vasprun.xml") handler = UnconvergedErrorHandler() @@ -606,6 +608,7 @@ def test_check_correct_electronic(self): dct = handler.correct() assert dct["errors"] == ["Unconverged"] assert dct == {"actions": [{"action": {"_set": {"ALGO": "Fast"}}, "dict": "INCAR"}], "errors": ["Unconverged"]} + 
tracked_lru_cache.tracked_cache_clear() shutil.copy("vasprun.xml.electronic_normal", "vasprun.xml") handler = UnconvergedErrorHandler() @@ -613,6 +616,7 @@ def test_check_correct_electronic(self): dct = handler.correct() assert dct["errors"] == ["Unconverged"] assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]} + tracked_lru_cache.tracked_cache_clear() shutil.copy("vasprun.xml.electronic_metagga_fast", "vasprun.xml") handler = UnconvergedErrorHandler() @@ -620,6 +624,7 @@ def test_check_correct_electronic(self): dct = handler.correct() assert dct["errors"] == ["Unconverged"] assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]} + tracked_lru_cache.tracked_cache_clear() shutil.copy("vasprun.xml.electronic_hybrid_fast", "vasprun.xml") handler = UnconvergedErrorHandler() @@ -627,6 +632,7 @@ def test_check_correct_electronic(self): dct = handler.correct() assert dct["errors"] == ["Unconverged"] assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]} + tracked_lru_cache.tracked_cache_clear() shutil.copy("vasprun.xml.electronic_hybrid_all", "vasprun.xml") handler = UnconvergedErrorHandler() diff --git a/custodian/vasp/tests/test_io.py b/custodian/vasp/tests/test_io.py new file mode 100644 index 00000000..41098cf8 --- /dev/null +++ b/custodian/vasp/tests/test_io.py @@ -0,0 +1,27 @@ +import os +import unittest + +from custodian.utils import tracked_lru_cache +from custodian.vasp.io import load_outcar, load_vasprun + +test_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "test_files") + + +class IOTest(unittest.TestCase): + def test_load_outcar(self): + outcar = load_outcar(os.path.join(test_dir, "large_sigma", "OUTCAR")) + assert outcar is not None + outcar2 = load_outcar(os.path.join(test_dir, "large_sigma", "OUTCAR")) + + assert outcar is outcar2 + + assert 
len(tracked_lru_cache.cached_functions) == 1 + + def test_load_vasprun(self): + vr = load_vasprun(os.path.join(test_dir, "large_sigma", "vasprun.xml")) + assert vr is not None + vr2 = load_vasprun(os.path.join(test_dir, "large_sigma", "vasprun.xml")) + + assert vr is vr2 + + assert len(tracked_lru_cache.cached_functions) == 1 diff --git a/custodian/vasp/validators.py b/custodian/vasp/validators.py index 8bb588f4..a4b8505f 100644 --- a/custodian/vasp/validators.py +++ b/custodian/vasp/validators.py @@ -4,9 +4,10 @@ import os from collections import deque -from pymatgen.io.vasp import Chgcar, Incar, Outcar, Vasprun +from pymatgen.io.vasp import Chgcar, Incar from custodian.custodian import Validator +from custodian.vasp.io import load_outcar, load_vasprun class VasprunXMLValidator(Validator): @@ -27,7 +28,7 @@ def __init__(self, output_file="vasp.out", stderr_file="std_err.txt"): def check(self): """Check for error.""" try: - Vasprun("vasprun.xml") + load_vasprun(os.path.join(os.getcwd(), "vasprun.xml")) except Exception: exception_context = {} @@ -88,7 +89,7 @@ def check(self): if not is_npt: return False - outcar = Outcar("OUTCAR") + outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR")) patterns = {"MDALGO": r"MDALGO\s+=\s+([\d]+)"} outcar.read_pattern(patterns=patterns) if outcar.data["MDALGO"] == [["3"]]: