diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 1a29575c69..1f895a64b5 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -170,6 +170,14 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]: This method blocks until the specified item has been collected. """ + @abstractmethod + def kill(self, revs: str) -> None: + """Kill the specified running entries in the queue. + + Arguments: + revs: Stash revs or running exp name to be killed. + """ + @abstractmethod def shutdown(self, kill: bool = False): """Shutdown the queue worker. @@ -495,3 +503,28 @@ def on_diverged(ref: str, checkpoint: bool): executor.collect_cache(exp.repo, exec_result.ref_info) return results + + def get_queue_entry_by_names( + self, + exp_names: Collection[str], + ) -> Dict[str, Optional[QueueEntry]]: + from scmrepo.exceptions import RevError as InternalRevError + + exp_name_set = set(exp_names) + result: Dict[str, Optional[QueueEntry]] = {} + rev_entries = {} + + for entry in self.iter_queued(): + if entry.name in exp_name_set: + result[entry.name] = entry + else: + rev_entries[entry.stash_rev] = entry + + for exp_name in exp_name_set.difference(result.keys()): + try: + rev = self.scm.resolve_rev(exp_name) + if rev in rev_entries: + result[exp_name] = rev_entries[rev] + except InternalRevError: + result[exp_name] = None + return result diff --git a/dvc/repo/experiments/queue/local.py b/dvc/repo/experiments/queue/local.py index cfe1161b20..944249cb02 100644 --- a/dvc/repo/experiments/queue/local.py +++ b/dvc/repo/experiments/queue/local.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import ( TYPE_CHECKING, + Collection, Dict, Generator, List, @@ -20,7 +21,7 @@ from dvc.daemon import daemonize from dvc.exceptions import DvcException -from ..exceptions import ExpQueueEmptyError +from ..exceptions import ExpQueueEmptyError, UnresolvedExpNamesError from ..executor.base import ( EXEC_PID_DIR, EXEC_TMP_DIR, @@ -200,6 +201,29 @@ def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]: pass time.sleep(1) + def kill(self, revs: Collection[str]) -> None: + to_kill: Set[QueueEntry] = set() + not_active: List[str] = [] + name_dict: Dict[ + str, Optional[QueueEntry] + ] = self.get_queue_entry_by_names(set(revs)) + + missing_rev: List[str] = [] + active_queue_entry = set(self.iter_active()) + for rev, queue_entry in name_dict.items(): + if queue_entry is None: + missing_rev.append(rev) + elif queue_entry not in active_queue_entry: + not_active.append(rev) + else: + to_kill.add(queue_entry) + + if missing_rev: + raise UnresolvedExpNamesError(missing_rev) + + for queue_entry in to_kill: + self.proc.kill(queue_entry.name) + def _shutdown_handler(self, task_id: str = None, **kwargs): if task_id in self._shutdown_task_ids: self._shutdown_task_ids.remove(task_id) @@ -336,5 +360,8 @@ def collect_executor( # pylint: disable=unused-argument def get_result(self, entry: QueueEntry) -> Optional[ExecutorResult]: raise NotImplementedError + def kill(self, revs: Collection[str]) -> None: + raise NotImplementedError + def shutdown(self, kill: bool = False): raise NotImplementedError diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 181b81f5fc..903bd90259 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Union from dvc.repo import locked from dvc.repo.scm_context import scm_context @@ -85,7 +85,9 @@ def _resolve_exp_by_name( commit_ref_dict[exp_ref] = exp_name if not git_remote: - _named_entries = _get_queue_entry_by_names(repo, remained) + _named_entries = ( + repo.experiments.celery_queue.get_queue_entry_by_names(remained) + ) for exp_name, entry in _named_entries.items(): if entry is not None: queue_entry_dict[exp_name] = entry @@ -127,30 +129,6 @@ def _clear_all_commits(repo, git_remote) -> List: return _remove_commited_exps(repo.scm, ref_infos, git_remote) -def _get_queue_entry_by_names( - repo: "Repo", - exp_name_set: Set[str], -) -> Dict[str, Optional[QueueEntry]]: - from scmrepo.exceptions import RevError as InternalRevError - - result = {} - rev_entries = {} - for entry in repo.experiments.celery_queue.iter_queued(): - if entry.name in exp_name_set: - result[entry.name] = entry - else: - rev_entries[entry.stash_rev] = entry - - for exp_name in exp_name_set.difference(result.keys()): - try: - rev = repo.scm.resolve_rev(exp_name) - if rev in rev_entries: - result[exp_name] = rev_entries[rev] - except InternalRevError: - result[exp_name] = None - return result - - def _remove_commited_exps( scm: "Git", exp_ref_dict: Mapping[ExpRefInfo, str], remote: Optional[str] ) -> List[str]: diff --git a/tests/unit/repo/experiments/queue/test_local.py b/tests/unit/repo/experiments/queue/test_local.py index b8d6337aa8..10e5b72185 100644 --- a/tests/unit/repo/experiments/queue/test_local.py +++ b/tests/unit/repo/experiments/queue/test_local.py @@ -8,8 +8,8 @@ def test_shutdown_no_tasks(test_queue, mocker): @shared_task -def _foo(): - return "foo" +def _foo(arg="foo"): + return arg def test_shutdown_active_tasks(test_queue, mocker): @@ -35,3 +35,35 @@ def test_shutdown_active_tasks(test_queue, mocker): assert "foo" == result.get() assert result.id not in test_queue._shutdown_task_ids shutdown_spy.assert_called_once() + + +def test_post_run_after_kill(test_queue, mocker): + import time + + from celery import chain + + sig_bar = test_queue.proc.run( + ["python3", "-c", "import time; time.sleep(5)"], name="bar" + ) + result_bar = sig_bar.freeze() + sig_foo = _foo.s() + result_foo = sig_foo.freeze() + run_chain = chain(sig_bar, sig_foo) + + run_chain.delay() + timeout = time.time() + 10 + + while True: + if result_bar.status == "RUNNING": + break + if time.time() > timeout: + raise AssertionError() + + assert result_foo.status == "PENDING" + test_queue.proc.kill("bar") + + while True: + if result_foo.status == "SUCCESS": + break + if time.time() > timeout: + raise AssertionError()