Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp: implement internal queue.kill support(#7587) #7714

Merged
merged 1 commit into from
May 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
This method blocks until the specified item has been collected.
"""

@abstractmethod
def kill(self, revs: str) -> None:
"""Kill the specified running entries in the queue.

Arguments:
revs: Stash revs or running exp name to be killed.
"""

@abstractmethod
def shutdown(self, kill: bool = False):
"""Shutdown the queue worker.
Expand Down Expand Up @@ -495,3 +503,28 @@ def on_diverged(ref: str, checkpoint: bool):
executor.collect_cache(exp.repo, exec_result.ref_info)

return results

def get_queue_entry_by_names(
self,
exp_names: Collection[str],
) -> Dict[str, Optional[QueueEntry]]:
from scmrepo.exceptions import RevError as InternalRevError

exp_name_set = set(exp_names)
result: Dict[str, Optional[QueueEntry]] = {}
rev_entries = {}

for entry in self.iter_queued():
if entry.name in exp_name_set:
result[entry.name] = entry
else:
rev_entries[entry.stash_rev] = entry

for exp_name in exp_name_set.difference(result.keys()):
try:
rev = self.scm.resolve_rev(exp_name)
if rev in rev_entries:
result[exp_name] = rev_entries[rev]
except InternalRevError:
result[exp_name] = None
return result
29 changes: 28 additions & 1 deletion dvc/repo/experiments/queue/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Generator,
List,
Expand All @@ -20,7 +21,7 @@
from dvc.daemon import daemonize
from dvc.exceptions import DvcException

from ..exceptions import ExpQueueEmptyError
from ..exceptions import ExpQueueEmptyError, UnresolvedExpNamesError
from ..executor.base import (
EXEC_PID_DIR,
EXEC_TMP_DIR,
Expand Down Expand Up @@ -201,6 +202,29 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
pass
time.sleep(1)

def kill(self, revs: Collection[str]) -> None:
to_kill: Set[QueueEntry] = set()
not_active: List[str] = []
name_dict: Dict[
str, Optional[QueueEntry]
] = self.get_queue_entry_by_names(set(revs))

missing_rev: List[str] = []
active_queue_entry = set(self.iter_active())
for rev, queue_entry in name_dict.items():
if queue_entry is None:
missing_rev.append(rev)
elif queue_entry not in active_queue_entry:
not_active.append(rev)
else:
to_kill.add(queue_entry)

if missing_rev:
raise UnresolvedExpNamesError(missing_rev)

for queue_entry in to_kill:
self.proc.kill(queue_entry.name)

def _shutdown_handler(self, task_id: str = None, **kwargs):
if task_id in self._shutdown_task_ids:
self._shutdown_task_ids.remove(task_id)
Expand Down Expand Up @@ -337,5 +361,8 @@ def collect_executor( # pylint: disable=unused-argument
def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]:
raise NotImplementedError

def kill(self, revs: Collection[str]) -> None:
raise NotImplementedError

def shutdown(self, kill: bool = False):
raise NotImplementedError
30 changes: 4 additions & 26 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Union

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand Down Expand Up @@ -85,7 +85,9 @@ def _resolve_exp_by_name(
commit_ref_dict[exp_ref] = exp_name

if not git_remote:
_named_entries = _get_queue_entry_by_names(repo, remained)
_named_entries = (
repo.experiments.celery_queue.get_queue_entry_by_names(remained)
)
for exp_name, entry in _named_entries.items():
if entry is not None:
queue_entry_dict[exp_name] = entry
Expand Down Expand Up @@ -127,30 +129,6 @@ def _clear_all_commits(repo, git_remote) -> List:
return _remove_commited_exps(repo.scm, ref_infos, git_remote)


def _get_queue_entry_by_names(
repo: "Repo",
exp_name_set: Set[str],
) -> Dict[str, Optional[QueueEntry]]:
from scmrepo.exceptions import RevError as InternalRevError

result = {}
rev_entries = {}
for entry in repo.experiments.celery_queue.iter_queued():
if entry.name in exp_name_set:
result[entry.name] = entry
else:
rev_entries[entry.stash_rev] = entry

for exp_name in exp_name_set.difference(result.keys()):
try:
rev = repo.scm.resolve_rev(exp_name)
if rev in rev_entries:
result[exp_name] = rev_entries[rev]
except InternalRevError:
result[exp_name] = None
return result


def _remove_commited_exps(
scm: "Git", exp_ref_dict: Mapping[ExpRefInfo, str], remote: Optional[str]
) -> List[str]:
Expand Down
36 changes: 34 additions & 2 deletions tests/unit/repo/experiments/queue/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def test_shutdown_no_tasks(test_queue, mocker):


@shared_task
def _foo():
return "foo"
def _foo(arg="foo"):
return arg


def test_shutdown_active_tasks(test_queue, mocker):
Expand All @@ -35,3 +35,35 @@ def test_shutdown_active_tasks(test_queue, mocker):
assert "foo" == result.get()
assert result.id not in test_queue._shutdown_task_ids
shutdown_spy.assert_called_once()


def test_post_run_after_kill(test_queue, mocker):
import time

from celery import chain

sig_bar = test_queue.proc.run(
["python3", "-c", "import time; time.sleep(5)"], name="bar"
)
result_bar = sig_bar.freeze()
sig_foo = _foo.s()
result_foo = sig_foo.freeze()
run_chain = chain(sig_bar, sig_foo)

run_chain.delay()
timeout = time.time() + 10

while True:
if result_bar.status == "RUNNING":
break
if time.time() > timeout:
raise AssertionError()

assert result_foo.status == "PENDING"
test_queue.proc.kill("bar")

while True:
if result_foo.status == "SUCCESS":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, even if we killed the process "bar", the result_bar.status still returns "SUCCESS", only celery.control.revoke would make it fail, and this is why the following "foo" process will continue.

break
if time.time() > timeout:
raise AssertionError()