Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp queue: preserve failed exp status #7855

Merged
merged 2 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .queue.base import BaseStashQueue, QueueEntry
from .queue.local import LocalCeleryQueue, WorkspaceQueue
from .refs import (
CELERY_FAILED_STASH,
CELERY_STASH,
EXEC_APPLY,
EXEC_BRANCH,
Expand Down Expand Up @@ -79,7 +80,7 @@ def workspace_queue(self) -> WorkspaceQueue:

@cached_property
def celery_queue(self) -> LocalCeleryQueue:
return LocalCeleryQueue(self.repo, CELERY_STASH)
return LocalCeleryQueue(self.repo, CELERY_STASH, CELERY_FAILED_STASH)

@property
def stash_revs(self) -> Dict[str, ExpStashEntry]:
Expand Down
106 changes: 83 additions & 23 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class QueueEntry:
baseline_rev: str
branch: Optional[str]
name: Optional[str]
head_rev: Optional[str] = None

def __eq__(self, other: object):
return (
Expand All @@ -77,21 +78,30 @@ class QueueGetResult(NamedTuple):
executor: BaseExecutor


class QueueDoneResult(NamedTuple):
    """Pairing of a processed queue entry with its executor result."""

    # The queue entry that finished processing.
    entry: QueueEntry
    # Executor result for the entry; None when no result is available.
    result: Optional[ExecutorResult]


class BaseStashQueue(ABC):
"""Naive Git-stash based experiment queue.

Maps queued experiments to (Git) stash reflog entries.
"""

def __init__(self, repo: "Repo", ref: str):
def __init__(
self, repo: "Repo", ref: str, failed_ref: Optional[str] = None
):
"""Construct a queue.

Arguments:
repo: Repo instance for this queue.
ref: Git stash ref for this queue.
failed_ref: Failed run Git stash ref for this queue.
"""
self.repo = repo
self.ref = ref
self.failed_ref = failed_ref

@property
def scm(self) -> "Git":
Expand All @@ -101,6 +111,10 @@ def scm(self) -> "Git":
def stash(self) -> ExpStash:
return ExpStash(self.scm, self.ref)

@cached_property
def failed_stash(self) -> Optional[ExpStash]:
    """Stash used to track failed runs.

    Returns:
        An ``ExpStash`` built on ``self.failed_ref``, or ``None`` when no
        failed ref was configured for this queue.
    """
    if not self.failed_ref:
        return None
    return ExpStash(self.scm, self.failed_ref)

@cached_property
def pid_dir(self) -> str:
return os.path.join(self.repo.tmp_dir, EXEC_TMP_DIR, EXEC_PID_DIR)
Expand Down Expand Up @@ -149,36 +163,56 @@ def clear(self) -> List[str]:
self._remove_revs(stash_revs)
return removed

def status(self) -> List[Mapping[str, Optional[str]]]:
def status(self) -> List[Dict[str, Any]]:
"""Show the status of exp tasks in queue"""
from datetime import datetime

result: List[Mapping[str, Optional[str]]] = []
result: List[Dict[str, Optional[str]]] = []

def _get_timestamp(rev):
def _get_timestamp(rev: str) -> datetime:
commit = self.scm.resolve_commit(rev)
return datetime.fromtimestamp(commit.commit_time)

for queue_entry in self.iter_active():
result.append(
{
"rev": queue_entry.stash_rev,
"name": queue_entry.name,
"timestamp": _get_timestamp(queue_entry.stash_rev),
"status": "Running",
}
)

for queue_entry in self.iter_queued():
result.append(
{
"rev": queue_entry.stash_rev,
"name": queue_entry.name,
"timestamp": _get_timestamp(queue_entry.stash_rev),
"status": "Queued",
}
)
failed_revs: Dict[str, ExpStashEntry] = (
dict(self.failed_stash.stash_revs)
if self.failed_stash is not None
else {}
)

def _format_entry(
entry: QueueEntry,
status: str = "Unknown",
result: Optional[ExecutorResult] = None,
) -> Dict[str, Any]:
name = entry.name
# NOTE: We fall back to Unknown status for experiments
# generated in prior (incompatible) DVC versions
if result is None and entry.stash_rev in failed_revs:
status = "Failed"
elif result is not None:
if result.exp_hash:
if result.ref_info:
status = "Success"
name = result.ref_info.name
return {
"rev": entry.stash_rev,
"name": name,
"timestamp": _get_timestamp(entry.stash_rev),
"status": status,
}

result.extend(
_format_entry(queue_entry, status="Running")
for queue_entry in self.iter_active()
)
result.extend(
_format_entry(queue_entry, status="Queued")
for queue_entry in self.iter_queued()
)
result.extend(
_format_entry(queue_entry, result=exp_result)
for queue_entry, exp_result in self.iter_done()
)
return result

@abstractmethod
Expand All @@ -193,6 +227,9 @@ def iter_queued(self) -> Generator[QueueEntry, None, None]:
def iter_active(self) -> Generator[QueueEntry, None, None]:
"""Iterate over items which are being actively processed."""

@abstractmethod
def iter_done(self) -> Generator[QueueDoneResult, None, None]:
    """Iterate over items which have been processed.

    Declared abstract, like ``iter_queued`` and ``iter_active``, so that a
    subclass which forgets to implement it fails at instantiation instead
    of silently returning ``None`` (which ``status()`` would then try to
    iterate).
    """

@abstractmethod
def reproduce(self) -> Mapping[str, Mapping[str, str]]:
"""Reproduce queued experiments sequentially."""
Expand Down Expand Up @@ -387,6 +424,7 @@ def _stash_exp(
baseline_rev,
branch,
name,
stash_head,
)

def _stash_commit_deps(self, *args, **kwargs):
Expand Down Expand Up @@ -577,3 +615,25 @@ def match_queue_entry_by_name(
result[exp_name] = None

return result

def stash_failed(self, entry: QueueEntry) -> None:
    """Record a failed queue entry in the failed exp stash.

    Does nothing when this queue has no failed stash.

    Arguments:
        entry: Failed queue entry to add. ``entry.stash_rev`` must be a
            valid Git stash commit.
    """
    failed = self.failed_stash
    if failed is None:
        return

    assert entry.head_rev
    logger.debug("Stashing failed exp '%s'", entry.stash_rev[:7])
    msg = failed.format_message(
        entry.head_rev,
        baseline_rev=entry.baseline_rev,
        name=entry.name,
        branch=entry.branch,
    )
    # Point the failed stash ref at the entry's stash commit, using the
    # formatted stash message as the ref update message.
    self.scm.set_ref(failed.ref, entry.stash_rev, message=f"commit: {msg}")
20 changes: 19 additions & 1 deletion dvc/repo/experiments/queue/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..executor.local import WorkspaceExecutor
from ..refs import EXEC_BRANCH
from ..stash import ExpStashEntry
from .base import BaseStashQueue, QueueEntry, QueueGetResult
from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult
from .tasks import run_exp

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,10 +183,23 @@ def _iter_active_tasks(self) -> Generator[_TaskEntry, None, None]:
if not result.ready():
yield _TaskEntry(task_id, entry)

def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]:
    """Yield a task entry for each processed message whose result is ready."""
    from celery.result import AsyncResult

    for message, queue_entry in self._iter_processed():
        celery_id = message.headers["id"]
        if not AsyncResult(celery_id).ready():
            continue
        yield _TaskEntry(celery_id, queue_entry)

def iter_active(self) -> Generator[QueueEntry, None, None]:
    """Yield the queue entry of every active task."""
    yield from (
        queue_entry for _task_id, queue_entry in self._iter_active_tasks()
    )

def iter_done(self) -> Generator[QueueDoneResult, None, None]:
    """Yield each done task's queue entry paired with its result.

    The result is looked up via ``self.get_result`` for every entry.
    """
    yield from (
        QueueDoneResult(queue_entry, self.get_result(queue_entry))
        for _task_id, queue_entry in self._iter_done_tasks()
    )

def reproduce(self) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

Expand Down Expand Up @@ -286,6 +299,7 @@ def get(self) -> QueueGetResult:
stash_entry.baseline_rev,
stash_entry.branch,
stash_entry.name,
stash_entry.head_rev,
)
executor = self.setup_executor(self.repo.experiments, entry)
return QueueGetResult(entry, executor)
Expand All @@ -311,13 +325,17 @@ def iter_queued(self) -> Generator[QueueEntry, None, None]:
entry.baseline_rev,
entry.branch,
entry.name,
entry.head_rev,
)

def iter_active(self) -> Generator[QueueEntry, None, None]:
    """Not supported by this queue.

    Workspace run state is reflected in the workspace itself and does not
    need to be handled via the queue.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def iter_done(self) -> Generator[QueueDoneResult, None, None]:
    """Not supported by this queue.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def reproduce(self) -> Dict[str, Dict[str, str]]:
results: Dict[str, Dict[str, str]] = defaultdict(dict)
try:
Expand Down
14 changes: 5 additions & 9 deletions dvc/repo/experiments/queue/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def setup_exp(entry_dict: Dict[str, Any]) -> str:

@shared_task
def collect_exp(
proc_dict: Dict[str, Any],
proc_dict: Dict[str, Any], # pylint: disable=unused-argument
entry_dict: Dict[str, Any],
) -> str:
"""Collect results for an experiment.
Expand All @@ -51,16 +51,11 @@ def collect_exp(
Directory to be cleaned up after this experiment.
"""
from dvc.repo import Repo
from dvc_task.proc.process import ProcessInfo

proc_info = ProcessInfo.from_dict(proc_dict)
if proc_info.returncode != 0:
# TODO: handle errors, track failed exps separately
pass

entry = QueueEntry.from_dict(entry_dict)
repo = Repo(entry.dvc_root)
infofile = repo.experiments.celery_queue.get_infofile_path(entry.stash_rev)
celery_queue = repo.experiments.celery_queue
infofile = celery_queue.get_infofile_path(entry.stash_rev)
executor_info = ExecutorInfo.load_json(infofile)
logger.debug("Collecting experiment info '%s'", str(executor_info))
executor = TempDirExecutor.from_info(executor_info)
Expand All @@ -73,7 +68,8 @@ def collect_exp(
for rev in results:
logger.debug("Collected experiment '%s'", rev[:7])
else:
logger.debug("Exec result was None")
logger.debug("Experiment failed (Exec result was None)")
celery_queue.stash_failed(entry)
except Exception: # pylint: disable=broad-except
# Log exceptions but do not re-raise so that task chain execution
# continues
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/experiments/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
EXPS_STASH = f"{EXPS_NAMESPACE}/stash"
WORKSPACE_STASH = EXPS_STASH
CELERY_STASH = f"{EXPS_NAMESPACE}/celery/stash"
CELERY_FAILED_STASH = f"{EXPS_NAMESPACE}/celery/failed"
EXEC_NAMESPACE = f"{EXPS_NAMESPACE}/exec"
EXEC_APPLY = f"{EXEC_NAMESPACE}/EXEC_APPLY"
EXEC_CHECKPOINT = f"{EXEC_NAMESPACE}/EXEC_CHECKPOINT"
Expand Down
10 changes: 8 additions & 2 deletions dvc/repo/experiments/stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@ def stash_revs(self) -> Dict[str, ExpStashEntry]:

@classmethod
def format_message(
cls, rev: str, baseline_rev: str, name: Optional[str] = None
cls,
rev: str,
baseline_rev: str,
name: Optional[str] = None,
branch: Optional[str] = None,
) -> str:
return cls.MESSAGE_FORMAT.format(
msg = cls.MESSAGE_FORMAT.format(
rev=rev, baseline_rev=baseline_rev, name=name if name else ""
)
branch_msg = f":{branch}" if branch else ""
return f"{msg}{branch_msg}"
25 changes: 25 additions & 0 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker):
return stage


@pytest.fixture
def failed_exp_stage(tmp_dir, scm, dvc):
    """Add and commit a stage whose command always exits with status 1."""
    tmp_dir.gen("copy.py", COPY_SCRIPT)
    tmp_dir.gen("params.yaml", "foo: 1")

    failing_stage = dvc.stage.add(
        name="copy-file",
        cmd="python -c 'import sys; sys.exit(1)'",
        deps=["copy.py"],
        params=["foo"],
        metrics_no_cache=["metrics.yaml"],
    )

    tracked_paths = [
        "dvc.yaml",
        "dvc.lock",
        "copy.py",
        "params.yaml",
        "metrics.yaml",
        ".gitignore",
    ]
    scm.add(tracked_paths)
    scm.commit("init")
    return failing_stage


@pytest.fixture
def http_auth_patch(mocker):
from dulwich.client import HTTPUnauthorized
Expand Down
35 changes: 29 additions & 6 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from funcy import first

from dvc.dvcfile import PIPELINE_FILE
from dvc.exceptions import DvcException
from dvc.exceptions import DvcException, ReproductionError
from dvc.repo.experiments.queue.base import BaseStashQueue
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
Expand Down Expand Up @@ -88,15 +88,38 @@ def test_file_permissions(tmp_dir, scm, dvc, exp_stage, mocker):
assert stat.S_IMODE(os.stat(tmp_dir / "copy.py").st_mode) == mode


def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, capsys, test_queue):
def test_failed_exp_workspace(
    tmp_dir, scm, dvc, failed_exp_stage, mocker, capsys
):
    """Running the failing stage directly raises ReproductionError."""
    tmp_dir.gen("params.yaml", "foo: 2")
    with pytest.raises(ReproductionError):
        dvc.experiments.run(failed_exp_stage.addressing)


def test_failed_exp_celery(
tmp_dir,
scm,
dvc,
failed_exp_stage,
test_queue,
mocker,
capsys,
):
tmp_dir.gen("params.yaml", "foo: 2")

mocker.patch.object(
dvc.experiments.celery_queue, "get_result", return_value=None
)
dvc.experiments.run(exp_stage.addressing, tmp_dir=True)
dvc.experiments.run(failed_exp_stage.addressing, queue=True)
dvc.experiments.run(run_all=True)
output = capsys.readouterr()
assert "Failed to reproduce experiment" in output.err
assert len(dvc.experiments.celery_queue.failed_stash) == 1
assert (
first(dvc.experiments.celery_queue.status()).get("status") == "Failed"
)


@pytest.mark.parametrize(
Expand Down