Skip to content

Commit

Permalink
exp: implement internal queue.kill support(#7587)
Browse files Browse the repository at this point in the history
fix: #7587
1. Add implement kill method for local queue class
2. Add a test to make sure the following job will be success after the
   original job was killed.
3. Some refactoring work on `exp remove`
  • Loading branch information
karajan1001 authored and pmrowla committed Jul 12, 2022
1 parent 22deca3 commit 4df5a13
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 29 deletions.
33 changes: 33 additions & 0 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
This method blocks until the specified item has been collected.
"""

@abstractmethod
def kill(self, revs: str) -> None:
"""Kill the specified running entries in the queue.
Arguments:
revs: Stash revs or running exp name to be killed.
"""

@abstractmethod
def shutdown(self, kill: bool = False):
"""Shutdown the queue worker.
Expand Down Expand Up @@ -495,3 +503,28 @@ def on_diverged(ref: str, checkpoint: bool):
executor.collect_cache(exp.repo, exec_result.ref_info)

return results

def get_queue_entry_by_names(
self,
exp_names: Collection[str],
) -> Dict[str, Optional[QueueEntry]]:
from scmrepo.exceptions import RevError as InternalRevError

exp_name_set = set(exp_names)
result: Dict[str, Optional[QueueEntry]] = {}
rev_entries = {}

for entry in self.iter_queued():
if entry.name in exp_name_set:
result[entry.name] = entry
else:
rev_entries[entry.stash_rev] = entry

for exp_name in exp_name_set.difference(result.keys()):
try:
rev = self.scm.resolve_rev(exp_name)
if rev in rev_entries:
result[exp_name] = rev_entries[rev]
except InternalRevError:
result[exp_name] = None
return result
29 changes: 28 additions & 1 deletion dvc/repo/experiments/queue/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Generator,
List,
Expand All @@ -20,7 +21,7 @@
from dvc.daemon import daemonize
from dvc.exceptions import DvcException

from ..exceptions import ExpQueueEmptyError
from ..exceptions import ExpQueueEmptyError, UnresolvedExpNamesError
from ..executor.base import (
EXEC_PID_DIR,
EXEC_TMP_DIR,
Expand Down Expand Up @@ -200,6 +201,29 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
pass
time.sleep(1)

def kill(self, revs: Collection[str]) -> None:
to_kill: Set[QueueEntry] = set()
not_active: List[str] = []
name_dict: Dict[
str, Optional[QueueEntry]
] = self.get_queue_entry_by_names(set(revs))

missing_rev: List[str] = []
active_queue_entry = set(self.iter_active())
for rev, queue_entry in name_dict.items():
if queue_entry is None:
missing_rev.append(rev)
elif queue_entry not in active_queue_entry:
not_active.append(rev)
else:
to_kill.add(queue_entry)

if missing_rev:
raise UnresolvedExpNamesError(missing_rev)

for queue_entry in to_kill:
self.proc.kill(queue_entry.name)

def _shutdown_handler(self, task_id: str = None, **kwargs):
if task_id in self._shutdown_task_ids:
self._shutdown_task_ids.remove(task_id)
Expand Down Expand Up @@ -336,5 +360,8 @@ def collect_executor( # pylint: disable=unused-argument
def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
raise NotImplementedError

def kill(self, revs: Collection[str]) -> None:
raise NotImplementedError

def shutdown(self, kill: bool = False):
raise NotImplementedError
30 changes: 4 additions & 26 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Union

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand Down Expand Up @@ -85,7 +85,9 @@ def _resolve_exp_by_name(
commit_ref_dict[exp_ref] = exp_name

if not git_remote:
_named_entries = _get_queue_entry_by_names(repo, remained)
_named_entries = (
repo.experiments.celery_queue.get_queue_entry_by_names(remained)
)
for exp_name, entry in _named_entries.items():
if entry is not None:
queue_entry_dict[exp_name] = entry
Expand Down Expand Up @@ -127,30 +129,6 @@ def _clear_all_commits(repo, git_remote) -> List:
return _remove_commited_exps(repo.scm, ref_infos, git_remote)


def _get_queue_entry_by_names(
repo: "Repo",
exp_name_set: Set[str],
) -> Dict[str, Optional[QueueEntry]]:
from scmrepo.exceptions import RevError as InternalRevError

result = {}
rev_entries = {}
for entry in repo.experiments.celery_queue.iter_queued():
if entry.name in exp_name_set:
result[entry.name] = entry
else:
rev_entries[entry.stash_rev] = entry

for exp_name in exp_name_set.difference(result.keys()):
try:
rev = repo.scm.resolve_rev(exp_name)
if rev in rev_entries:
result[exp_name] = rev_entries[rev]
except InternalRevError:
result[exp_name] = None
return result


def _remove_commited_exps(
scm: "Git", exp_ref_dict: Mapping[ExpRefInfo, str], remote: Optional[str]
) -> List[str]:
Expand Down
36 changes: 34 additions & 2 deletions tests/unit/repo/experiments/queue/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def test_shutdown_no_tasks(test_queue, mocker):


@shared_task
def _foo():
return "foo"
def _foo(arg="foo"):
return arg


def test_shutdown_active_tasks(test_queue, mocker):
Expand All @@ -35,3 +35,35 @@ def test_shutdown_active_tasks(test_queue, mocker):
assert "foo" == result.get()
assert result.id not in test_queue._shutdown_task_ids
shutdown_spy.assert_called_once()


def test_post_run_after_kill(test_queue, mocker):
import time

from celery import chain

sig_bar = test_queue.proc.run(
["python3", "-c", "import time; time.sleep(5)"], name="bar"
)
result_bar = sig_bar.freeze()
sig_foo = _foo.s()
result_foo = sig_foo.freeze()
run_chain = chain(sig_bar, sig_foo)

run_chain.delay()
timeout = time.time() + 10

while True:
if result_bar.status == "RUNNING":
break
if time.time() > timeout:
raise AssertionError()

assert result_foo.status == "PENDING"
test_queue.proc.kill("bar")

while True:
if result_foo.status == "SUCCESS":
break
if time.time() > timeout:
raise AssertionError()

0 comments on commit 4df5a13

Please sign in to comment.