dvc: introduce local build cache (#3603)
This patch introduces `.dvc/cache/stages`, which is used to store previous
runs and their results, so that they can be reused later when we stumble
upon the same command with the same deps and outs.

The format of build cache entries is single-line JSON, which is readable
by humans and might also be used for the lock files discussed in #1871.

Related to #1871
Local part of #1234
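
For illustration, here is roughly what a build cache entry looks like and
where it lives (a sketch only: the command, paths, and checksum placeholders
below are made up, not the output of a real run):

    # Hypothetical entry for a stage running `python train.py` (sketch).
    # It is stored as single-line JSON at
    #     .dvc/cache/stages/<key[:2]>/<key>/<value>
    # where <key> is a sha256 over the cmd, the dep checksums, and the out
    # *names* (identifying the build), and <value> is a sha256 that also
    # covers the out checksums (identifying one concrete result).
    entry = {
        "cmd": "python train.py",
        "deps": {"data.csv": "<md5-of-data.csv>"},
        "outs": {"model.pkl": "<md5-of-model.pkl>"},
    }

On a later run of the same command over the same deps, `StageCache.restore()`
puts the saved checksums back onto the stage, so it can be detected as
already cached and its outs checked out from the regular cache instead of
re-running the command.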
efiop authored Apr 29, 2020
1 parent 8aefbac commit 18e8f07
Showing 10 changed files with 223 additions and 17 deletions.
3 changes: 3 additions & 0 deletions dvc/repo/__init__.py
@@ -72,6 +72,7 @@ def __init__(self, root_dir=None):
         from dvc.repo.params import Params
         from dvc.scm.tree import WorkingTree
         from dvc.utils.fs import makedirs
+        from dvc.stage.cache import StageCache

         root_dir = self.find_root(root_dir)

@@ -104,6 +105,8 @@ def __init__(self, root_dir=None):
         self.cache = Cache(self)
         self.cloud = DataCloud(self)

+        self.stage_cache = StageCache(self.cache.local.cache_dir)
+
         self.metrics = Metrics(self)
         self.params = Params(self)

9 changes: 2 additions & 7 deletions dvc/repo/reproduce.py
@@ -97,12 +97,7 @@ def reproduce(


 def _reproduce_stages(
-    G,
-    stages,
-    downstream=False,
-    ignore_build_cache=False,
-    single_item=False,
-    **kwargs
+    G, stages, downstream=False, single_item=False, **kwargs
 ):
     r"""Derive the evaluation of the given node for the given graph.
@@ -172,7 +167,7 @@ def _reproduce_stages(
         try:
             ret = _reproduce_stage(stage, **kwargs)

-            if len(ret) != 0 and ignore_build_cache:
+            if len(ret) != 0 and kwargs.get("ignore_build_cache", False):
                 # NOTE: we are walking our pipeline from the top to the
                 # bottom. If one stage is changed, it will be reproduced,
                 # which tells us that we should force reproducing all of
5 changes: 4 additions & 1 deletion dvc/repo/run.py
@@ -68,6 +68,9 @@ def run(self, fname=None, no_exec=False, **kwargs):
         raise OutputDuplicationError(exc.output, set(exc.stages) - {stage})

     if not no_exec:
-        stage.run(no_commit=kwargs.get("no_commit", False))
+        stage.run(
+            no_commit=kwargs.get("no_commit", False),
+            ignore_build_cache=kwargs.get("ignore_build_cache", False),
+        )
     dvcfile.dump(stage, update_pipeline=True)
     return stage
12 changes: 10 additions & 2 deletions dvc/stage/__init__.py
@@ -491,6 +491,8 @@ def save(self):

         self.md5 = self._compute_md5()

+        self.repo.stage_cache.save(self)
+
     @staticmethod
     def _changed_entries(entries):
         return [
@@ -617,7 +619,9 @@ def _run(self):
             raise StageCmdFailedError(self)

     @rwlocked(read=["deps"], write=["outs"])
-    def run(self, dry=False, no_commit=False, force=False):
+    def run(
+        self, dry=False, no_commit=False, force=False, ignore_build_cache=False
+    ):
         if (self.cmd or self.is_import) and not self.locked and not dry:
             self.remove_outs(ignore_remove=False, force=False)

@@ -650,16 +654,20 @@ def run(
             self.check_missing_outputs()

         else:
-            logger.info("Running command:\n\t{}".format(self.cmd))
             if not dry:
+                if not force and not ignore_build_cache:
+                    self.repo.stage_cache.restore(self)
+
                 if (
                     not force
                     and not self.is_callback
                     and not self.always_changed
                     and self._already_cached()
                 ):
+                    logger.info("Stage is cached, skipping.")
                     self.checkout()
                 else:
+                    logger.info("Running command:\n\t{}".format(self.cmd))
                     self._run()

         if not dry:
124 changes: 124 additions & 0 deletions dvc/stage/cache.py
@@ -0,0 +1,124 @@
import os
import yaml
import logging

from voluptuous import Schema, Required, Invalid

from dvc.utils.fs import makedirs
from dvc.utils import relpath, dict_sha256

logger = logging.getLogger(__name__)

SCHEMA = Schema(
    {
        Required("cmd"): str,
        Required("deps"): {str: str},
        Required("outs"): {str: str},
    }
)


def _get_cache_hash(cache, key=False):
    return dict_sha256(
        {
            "cmd": cache["cmd"],
            "deps": cache["deps"],
            "outs": list(cache["outs"].keys()) if key else cache["outs"],
        }
    )


def _get_stage_hash(stage):
    if not stage.cmd or not stage.deps or not stage.outs:
        return None

    for dep in stage.deps:
        if dep.scheme != "local" or not dep.def_path or not dep.get_checksum():
            return None

    for out in stage.outs:
        if out.scheme != "local" or not out.def_path or out.persist:
            return None

    return _get_cache_hash(_create_cache(stage), key=True)


def _create_cache(stage):
    return {
        "cmd": stage.cmd,
        "deps": {dep.def_path: dep.get_checksum() for dep in stage.deps},
        "outs": {out.def_path: out.get_checksum() for out in stage.outs},
    }


class StageCache:
    def __init__(self, cache_dir):
        self.cache_dir = os.path.join(cache_dir, "stages")

    def _get_cache_dir(self, key):
        return os.path.join(self.cache_dir, key[:2], key)

    def _get_cache_path(self, key, value):
        return os.path.join(self._get_cache_dir(key), value)

    def _load_cache(self, key, value):
        path = self._get_cache_path(key, value)

        try:
            with open(path, "r") as fobj:
                return SCHEMA(yaml.safe_load(fobj))
        except FileNotFoundError:
            return None
        except (yaml.error.YAMLError, Invalid):
            logger.warning("corrupted cache file '%s'.", relpath(path))
            os.unlink(path)
            return None

    def _load(self, stage):
        key = _get_stage_hash(stage)
        if not key:
            return None

        cache_dir = self._get_cache_dir(key)
        if not os.path.exists(cache_dir):
            return None

        for value in os.listdir(cache_dir):
            cache = self._load_cache(key, value)
            if cache:
                return cache

        return None

    def save(self, stage):
        cache_key = _get_stage_hash(stage)
        if not cache_key:
            return

        cache = _create_cache(stage)
        cache_value = _get_cache_hash(cache)

        if self._load_cache(cache_key, cache_value):
            return

        # sanity check
        SCHEMA(cache)

        path = self._get_cache_path(cache_key, cache_value)
        dpath = os.path.dirname(path)
        makedirs(dpath, exist_ok=True)
        with open(path, "w+") as fobj:
            yaml.dump(cache, fobj)

    def restore(self, stage):
        cache = self._load(stage)
        if not cache:
            return

        deps = {dep.def_path: dep for dep in stage.deps}
        for def_path, checksum in cache["deps"].items():
            deps[def_path].checksum = checksum

        outs = {out.def_path: out for out in stage.outs}
        for def_path, checksum in cache["outs"].items():
            outs[def_path].checksum = checksum
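
To make the key/value split concrete, here is a minimal, self-contained
sketch (it reimplements the hashing inline rather than importing DVC, and
the toy checksums are obviously not real digests): two runs of the same
command over the same deps share a cache *key* (one directory), while each
distinct set of out checksums gets its own cache *value* (one file in it).

    import hashlib
    import json

    def dict_sha256(d):
        # mirrors dvc.utils.dict_hash: canonical JSON, then hex digest
        byts = json.dumps(d, sort_keys=True).encode("utf-8")
        return hashlib.sha256(byts).hexdigest()

    def cache_hash(cache, key=False):
        # mirrors _get_cache_hash above: the key variant hashes only the
        # *names* of the outs, the value variant their checksums too
        return dict_sha256(
            {
                "cmd": cache["cmd"],
                "deps": cache["deps"],
                "outs": list(cache["outs"].keys()) if key else cache["outs"],
            }
        )

    run1 = {"cmd": "cp dep out", "deps": {"dep": "aaa"}, "outs": {"out": "bbb"}}
    run2 = {"cmd": "cp dep out", "deps": {"dep": "aaa"}, "outs": {"out": "ccc"}}

    assert cache_hash(run1, key=True) == cache_hash(run2, key=True)  # same dir
    assert cache_hash(run1) != cache_hash(run2)  # two separate entry files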
16 changes: 12 additions & 4 deletions dvc/utils/__init__.py
@@ -76,8 +76,8 @@ def file_md5(fname):
     return (None, None)


-def bytes_md5(byts):
-    hasher = hashlib.md5()
+def bytes_hash(byts, typ):
+    hasher = getattr(hashlib, typ)()
     hasher.update(byts)
     return hasher.hexdigest()

@@ -100,10 +100,18 @@ def dict_filter(d, exclude=()):
     return d


-def dict_md5(d, exclude=()):
+def dict_hash(d, typ, exclude=()):
     filtered = dict_filter(d, exclude)
     byts = json.dumps(filtered, sort_keys=True).encode("utf-8")
-    return bytes_md5(byts)
+    return bytes_hash(byts, typ)
+
+
+def dict_md5(d, **kwargs):
+    return dict_hash(d, "md5", **kwargs)
+
+
+def dict_sha256(d, **kwargs):
+    return dict_hash(d, "sha256", **kwargs)


 def _split(list_to_split, chunk_size):
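
A quick sanity sketch of the generalized helpers (assuming `dvc` is
importable; digests are compared rather than hard-coded): serialization
uses `json.dumps(..., sort_keys=True)`, so key order never affects the
digest, and both wrappers share the same canonical form.

    from dvc.utils import dict_md5, dict_sha256

    d1 = {"cmd": "echo hi", "deps": {"dep": "abc"}}
    d2 = {"deps": {"dep": "abc"}, "cmd": "echo hi"}  # same content, reordered

    assert dict_md5(d1) == dict_md5(d2)  # canonical JSON => same md5
    assert dict_sha256(d1) == dict_sha256(d2)  # likewise for sha256
    assert len(dict_sha256(d1)) == 64  # sha256 hex digest length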
2 changes: 2 additions & 0 deletions tests/func/test_gc.py
@@ -1,4 +1,5 @@
 import logging
+import shutil
 import os

 import configobj
@@ -341,6 +342,7 @@ def test_gc_not_collect_pipeline_tracked_files(tmp_dir, dvc, run_copy):
     tmp_dir.gen("bar", "bar")

     run_copy("foo", "foo2", name="copy")
+    shutil.rmtree(dvc.stage_cache.cache_dir)
     assert _count_files(dvc.cache.local.cache_dir) == 1
     dvc.gc(workspace=True, force=True)
     assert _count_files(dvc.cache.local.cache_dir) == 1
4 changes: 3 additions & 1 deletion tests/func/test_repro.py
@@ -1295,7 +1295,9 @@ def test(self):
             ["repro", self._get_stage_target(self.stage), "--no-commit"]
         )
         self.assertEqual(ret, 0)
-        self.assertFalse(os.path.exists(self.dvc.cache.local.cache_dir))
+        self.assertEqual(
+            os.listdir(self.dvc.cache.local.cache_dir), ["stages"]
+        )


 class TestReproAlreadyCached(TestRepro):
39 changes: 37 additions & 2 deletions tests/unit/test_stage.py
@@ -1,3 +1,4 @@
+import os
 import signal
 import subprocess
 import threading
@@ -51,8 +52,6 @@ def test_meta_ignored():

 class TestPathConversion(TestCase):
     def test(self):
-        import os
-
         stage = Stage(None, "path")

         stage.wdir = os.path.join("..", "..")
@@ -103,3 +102,39 @@ def test_always_changed(dvc):
     with dvc.lock:
         assert stage.changed()
         assert stage.status()["path"] == ["always changed"]
+
+
+def test_stage_cache(tmp_dir, dvc, run_copy, mocker):
+    tmp_dir.gen("dep", "dep")
+    stage = run_copy("dep", "out")
+
+    with dvc.lock, dvc.state:
+        stage.remove(remove_outs=True, force=True)
+
+    assert not (tmp_dir / "out").exists()
+    assert not (tmp_dir / "out.dvc").exists()
+
+    cache_dir = os.path.join(
+        dvc.stage_cache.cache_dir,
+        "ec",
+        "ec5b6d8dea9136dbb62d93a95c777f87e6c54b0a6bee839554acb99fdf23d2b1",
+    )
+    cache_file = os.path.join(
+        cache_dir,
+        "09f9eb17fdb1ee7f8566b3c57394cee060eaf28075244bc6058612ac91fdf04a",
+    )
+
+    assert os.path.isdir(cache_dir)
+    assert os.listdir(cache_dir) == [os.path.basename(cache_file)]
+    assert os.path.isfile(cache_file)
+
+    run_spy = mocker.spy(stage, "_run")
+    checkout_spy = mocker.spy(stage, "checkout")
+    with dvc.lock, dvc.state:
+        stage.run()
+
+    assert not run_spy.called
+    assert checkout_spy.call_count == 1
+
+    assert (tmp_dir / "out").exists()
+    assert (tmp_dir / "out").read_text() == "dep"
26 changes: 26 additions & 0 deletions tests/unit/utils/test_utils.py
@@ -6,6 +6,7 @@
 from dvc.path_info import PathInfo
 from dvc.utils import (
     file_md5,
+    dict_sha256,
     resolve_output,
     fix_env,
     relpath,
@@ -155,3 +156,28 @@ def test_hint_on_lockfile():
     with pytest.raises(Exception) as exc:
         assert parse_target("pipelines.lock:name")
     assert "pipelines.yaml:name" in str(exc.value)
+
+
+@pytest.mark.parametrize(
+    "d,sha",
+    [
+        (
+            {
+                "cmd": "echo content > out",
+                "deps": {"dep": "2254342becceafbd04538e0a38696791"},
+                "outs": {"out": "f75b8179e4bbe7e2b4a074dcef62de95"},
+            },
+            "f472eda60f09660a4750e8b3208cf90b3a3b24e5f42e0371d829710e9464d74a",
+        ),
+        (
+            {
+                "cmd": "echo content > out",
+                "deps": {"dep": "2254342becceafbd04538e0a38696791"},
+                "outs": ["out"],
+            },
+            "a239b67073bd58affcdb81fff3305d1726c6e7f9c86f3d4fca0e92e8147dc7b0",
+        ),
+    ],
+)
+def test_dict_sha256(d, sha):
+    assert dict_sha256(d) == sha
