diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index e146b64ad0..33431fa4f0 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -72,6 +72,7 @@ def __init__(self, root_dir=None): from dvc.repo.params import Params from dvc.scm.tree import WorkingTree from dvc.utils.fs import makedirs + from dvc.stage.cache import StageCache root_dir = self.find_root(root_dir) @@ -104,6 +105,8 @@ def __init__(self, root_dir=None): self.cache = Cache(self) self.cloud = DataCloud(self) + self.stage_cache = StageCache(self.cache.local.cache_dir) + self.metrics = Metrics(self) self.params = Params(self) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 8bf4f84d12..7abf03d0c2 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -97,12 +97,7 @@ def reproduce( def _reproduce_stages( - G, - stages, - downstream=False, - ignore_build_cache=False, - single_item=False, - **kwargs + G, stages, downstream=False, single_item=False, **kwargs ): r"""Derive the evaluation of the given node for the given graph. @@ -172,7 +167,7 @@ def _reproduce_stages( try: ret = _reproduce_stage(stage, **kwargs) - if len(ret) != 0 and ignore_build_cache: + if len(ret) != 0 and kwargs.get("ignore_build_cache", False): # NOTE: we are walking our pipeline from the top to the # bottom. If one stage is changed, it will be reproduced, # which tells us that we should force reproducing all of diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 63e76b5c0f..14042b00e4 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -68,6 +68,9 @@ def run(self, fname=None, no_exec=False, **kwargs): raise OutputDuplicationError(exc.output, set(exc.stages) - {stage}) if not no_exec: - stage.run(no_commit=kwargs.get("no_commit", False)) + stage.run( + no_commit=kwargs.get("no_commit", False), + ignore_build_cache=kwargs.get("ignore_build_cache", False), + ) dvcfile.dump(stage, update_pipeline=True) return stage diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index fcdbeed671..6c78c65de9 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -491,6 +491,8 @@ def save(self): self.md5 = self._compute_md5() + self.repo.stage_cache.save(self) + @staticmethod def _changed_entries(entries): return [ @@ -617,7 +619,9 @@ def _run(self): raise StageCmdFailedError(self) @rwlocked(read=["deps"], write=["outs"]) - def run(self, dry=False, no_commit=False, force=False): + def run( + self, dry=False, no_commit=False, force=False, ignore_build_cache=False + ): if (self.cmd or self.is_import) and not self.locked and not dry: self.remove_outs(ignore_remove=False, force=False) @@ -650,16 +654,20 @@ def run(self, dry=False, no_commit=False, force=False): self.check_missing_outputs() else: - logger.info("Running command:\n\t{}".format(self.cmd)) if not dry: + if not force and not ignore_build_cache: + self.repo.stage_cache.restore(self) + if ( not force and not self.is_callback and not self.always_changed and self._already_cached() ): + logger.info("Stage is cached, skipping.") self.checkout() else: + logger.info("Running command:\n\t{}".format(self.cmd)) self._run() if not dry: diff --git a/dvc/stage/cache.py b/dvc/stage/cache.py new file mode 100644 index 0000000000..e4ceacb2c0 --- /dev/null +++ b/dvc/stage/cache.py @@ -0,0 +1,124 @@ +import os +import yaml +import logging + +from voluptuous import Schema, Required, Invalid + +from dvc.utils.fs import makedirs +from dvc.utils import relpath, dict_sha256 + +logger = logging.getLogger(__name__) + +SCHEMA = Schema( + { + Required("cmd"): str, + Required("deps"): {str: str}, + Required("outs"): {str: str}, + } +) + + +def _get_cache_hash(cache, key=False): + return dict_sha256( + { + "cmd": cache["cmd"], + "deps": cache["deps"], + "outs": list(cache["outs"].keys()) if key else cache["outs"], + } + ) + + +def _get_stage_hash(stage): + if not stage.cmd or not stage.deps or not stage.outs: + return None + + for dep in stage.deps: + if dep.scheme != "local" or not dep.def_path or not dep.get_checksum(): + return None + + for out in stage.outs: + if out.scheme != "local" or not out.def_path or out.persist: + return None + + return _get_cache_hash(_create_cache(stage), key=True) + + +def _create_cache(stage): + return { + "cmd": stage.cmd, + "deps": {dep.def_path: dep.get_checksum() for dep in stage.deps}, + "outs": {out.def_path: out.get_checksum() for out in stage.outs}, + } + + +class StageCache: + def __init__(self, cache_dir): + self.cache_dir = os.path.join(cache_dir, "stages") + + def _get_cache_dir(self, key): + return os.path.join(self.cache_dir, key[:2], key) + + def _get_cache_path(self, key, value): + return os.path.join(self._get_cache_dir(key), value) + + def _load_cache(self, key, value): + path = self._get_cache_path(key, value) + + try: + with open(path, "r") as fobj: + return SCHEMA(yaml.safe_load(fobj)) + except FileNotFoundError: + return None + except (yaml.error.YAMLError, Invalid): + logger.warning("corrupted cache file '%s'.", relpath(path)) + os.unlink(path) + return None + + def _load(self, stage): + key = _get_stage_hash(stage) + if not key: + return None + + cache_dir = self._get_cache_dir(key) + if not os.path.exists(cache_dir): + return None + + for value in os.listdir(cache_dir): + cache = self._load_cache(key, value) + if cache: + return cache + + return None + + def save(self, stage): + cache_key = _get_stage_hash(stage) + if not cache_key: + return + + cache = _create_cache(stage) + cache_value = _get_cache_hash(cache) + + if self._load_cache(cache_key, cache_value): + return + + # sanity check + SCHEMA(cache) + + path = self._get_cache_path(cache_key, cache_value) + dpath = os.path.dirname(path) + makedirs(dpath, exist_ok=True) + with open(path, "w+") as fobj: + yaml.dump(cache, fobj) + + def restore(self, stage): + cache = self._load(stage) + if not cache: + return + + deps = {dep.def_path: dep for dep in stage.deps} + for def_path, checksum in cache["deps"].items(): + deps[def_path].checksum = checksum + + outs = {out.def_path: out for out in stage.outs} + for def_path, checksum in cache["outs"].items(): + outs[def_path].checksum = checksum diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index 2c03f7db25..0d2c6cc02c 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -76,8 +76,8 @@ def file_md5(fname): return (None, None) -def bytes_md5(byts): - hasher = hashlib.md5() +def bytes_hash(byts, typ): + hasher = getattr(hashlib, typ)() hasher.update(byts) return hasher.hexdigest() @@ -100,10 +100,18 @@ def dict_filter(d, exclude=()): return d -def dict_md5(d, exclude=()): +def dict_hash(d, typ, exclude=()): filtered = dict_filter(d, exclude) byts = json.dumps(filtered, sort_keys=True).encode("utf-8") - return bytes_md5(byts) + return bytes_hash(byts, typ) + + +def dict_md5(d, **kwargs): + return dict_hash(d, "md5", **kwargs) + + +def dict_sha256(d, **kwargs): + return dict_hash(d, "sha256", **kwargs) def _split(list_to_split, chunk_size): diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py index e19d5dc41b..0698bea794 100644 --- a/tests/func/test_gc.py +++ b/tests/func/test_gc.py @@ -1,4 +1,5 @@ import logging +import shutil import os import configobj @@ -341,6 +342,7 @@ def test_gc_not_collect_pipeline_tracked_files(tmp_dir, dvc, run_copy): tmp_dir.gen("bar", "bar") run_copy("foo", "foo2", name="copy") + shutil.rmtree(dvc.stage_cache.cache_dir) assert _count_files(dvc.cache.local.cache_dir) == 1 dvc.gc(workspace=True, force=True) assert _count_files(dvc.cache.local.cache_dir) == 1 diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 5ea5dae81c..ba757ac11e 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1295,7 +1295,9 @@ def test(self): ["repro", self._get_stage_target(self.stage), "--no-commit"] ) self.assertEqual(ret, 0) - self.assertFalse(os.path.exists(self.dvc.cache.local.cache_dir)) + self.assertEqual( + os.listdir(self.dvc.cache.local.cache_dir), ["stages"] + ) class TestReproAlreadyCached(TestRepro): diff --git a/tests/unit/test_stage.py b/tests/unit/test_stage.py index 02924029e6..982d549a35 100644 --- a/tests/unit/test_stage.py +++ b/tests/unit/test_stage.py @@ -1,3 +1,4 @@ +import os import signal import subprocess import threading @@ -51,8 +52,6 @@ def test_meta_ignored(): class TestPathConversion(TestCase): def test(self): - import os - stage = Stage(None, "path") stage.wdir = os.path.join("..", "..") @@ -103,3 +102,39 @@ def test_always_changed(dvc): with dvc.lock: assert stage.changed() assert stage.status()["path"] == ["always changed"] + + +def test_stage_cache(tmp_dir, dvc, run_copy, mocker): + tmp_dir.gen("dep", "dep") + stage = run_copy("dep", "out") + + with dvc.lock, dvc.state: + stage.remove(remove_outs=True, force=True) + + assert not (tmp_dir / "out").exists() + assert not (tmp_dir / "out.dvc").exists() + + cache_dir = os.path.join( + dvc.stage_cache.cache_dir, + "ec", + "ec5b6d8dea9136dbb62d93a95c777f87e6c54b0a6bee839554acb99fdf23d2b1", + ) + cache_file = os.path.join( + cache_dir, + "09f9eb17fdb1ee7f8566b3c57394cee060eaf28075244bc6058612ac91fdf04a", + ) + + assert os.path.isdir(cache_dir) + assert os.listdir(cache_dir) == [os.path.basename(cache_file)] + assert os.path.isfile(cache_file) + + run_spy = mocker.spy(stage, "_run") + checkout_spy = mocker.spy(stage, "checkout") + with dvc.lock, dvc.state: + stage.run() + + assert not run_spy.called + assert checkout_spy.call_count == 1 + + assert (tmp_dir / "out").exists() + assert (tmp_dir / "out").read_text() == "dep" diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 8e18cb5422..1d181f56dd 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -6,6 +6,7 @@ from dvc.path_info import PathInfo from dvc.utils import ( file_md5, + dict_sha256, resolve_output, fix_env, relpath, @@ -155,3 +156,28 @@ def test_hint_on_lockfile(): with pytest.raises(Exception) as exc: assert parse_target("pipelines.lock:name") assert "pipelines.yaml:name" in str(exc.value) + + +@pytest.mark.parametrize( + "d,sha", + [ + ( + { + "cmd": "echo content > out", + "deps": {"dep": "2254342becceafbd04538e0a38696791"}, + "outs": {"out": "f75b8179e4bbe7e2b4a074dcef62de95"}, + }, + "f472eda60f09660a4750e8b3208cf90b3a3b24e5f42e0371d829710e9464d74a", + ), + ( + { + "cmd": "echo content > out", + "deps": {"dep": "2254342becceafbd04538e0a38696791"}, + "outs": ["out"], + }, + "a239b67073bd58affcdb81fff3305d1726c6e7f9c86f3d4fca0e92e8147dc7b0", + ), + ], +) +def test_dict_sha256(d, sha): + assert dict_sha256(d) == sha