From 594b117fe149a792e089c626d8a0ff4d13a12108 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 23 May 2022 18:05:24 +0800 Subject: [PATCH] Add more unit test and implement a new method 1. Add some more unit test for `celery_queue.kill` 2. Implement `kill` argument for `celery_queue.shutdown` 3. Use dataclass's internal methods instead of `__eq__` --- dvc/repo/experiments/queue/base.py | 2 +- dvc/repo/experiments/queue/local.py | 7 +- .../unit/repo/experiments/queue/test_local.py | 71 ++++++++++++++++--- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 48d9b5a84c..e772c9571a 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) -@dataclass +@dataclass(frozen=True) class QueueEntry: dvc_root: str scm_root: str diff --git a/dvc/repo/experiments/queue/local.py b/dvc/repo/experiments/queue/local.py index e0b46aad64..3367b5412d 100644 --- a/dvc/repo/experiments/queue/local.py +++ b/dvc/repo/experiments/queue/local.py @@ -238,7 +238,7 @@ def kill(self, revs: Collection[str]) -> None: raise UnresolvedExpNamesError(missing_rev) for queue_entry in to_kill: - self.proc.kill(queue_entry.name) + self.proc.kill(queue_entry.stash_rev) def _shutdown_handler(self, task_id: str = None, **kwargs): if task_id in self._shutdown_task_ids: @@ -249,13 +249,14 @@ def _shutdown_handler(self, task_id: str = None, **kwargs): def shutdown(self, kill: bool = False): from celery.signals import task_postrun - if kill: - raise NotImplementedError tasks = list(self._iter_active_tasks()) if tasks: for task_id, _ in tasks: self._shutdown_task_ids.add(task_id) task_postrun.connect()(self._shutdown_handler) + if kill: + for _, task_entry in tasks: + self.proc.kill(task_entry.stash_rev) else: self.celery.control.shutdown() diff --git a/tests/unit/repo/experiments/queue/test_local.py b/tests/unit/repo/experiments/queue/test_local.py index 10e5b72185..10bf7b29df 100644 --- a/tests/unit/repo/experiments/queue/test_local.py +++ b/tests/unit/repo/experiments/queue/test_local.py @@ -1,5 +1,10 @@ +import time + +import pytest from celery import shared_task +from dvc.repo.experiments.exceptions import UnresolvedExpNamesError + def test_shutdown_no_tasks(test_queue, mocker): shutdown_spy = mocker.spy(test_queue.celery.control, "shutdown") @@ -8,8 +13,8 @@ def test_shutdown_no_tasks(test_queue, mocker): @shared_task -def _foo(arg="foo"): - return arg +def _foo(arg=None): # pylint: disable=unused-argument + return "foo" def test_shutdown_active_tasks(test_queue, mocker): @@ -37,8 +42,32 @@ def test_shutdown_active_tasks(test_queue, mocker): shutdown_spy.assert_called_once() -def test_post_run_after_kill(test_queue, mocker): - import time +def test_shutdown_with_kill(test_queue, mocker): + + sig = _foo.s() + mock_entry = mocker.Mock(stash_rev=_foo.name) + + result = sig.freeze() + + shutdown_spy = mocker.patch("celery.app.control.Control.shutdown") + mocker.patch.object( + test_queue, + "_iter_active_tasks", + return_value=[(result.id, mock_entry)], + ) + kill_spy = mocker.patch.object(test_queue.proc, "kill") + + test_queue.shutdown(kill=True) + + sig.delay() + + assert result.get() == "foo" + assert result.id not in test_queue._shutdown_task_ids + kill_spy.assert_called_once_with(mock_entry.stash_rev) + shutdown_spy.assert_called_once() + + +def test_post_run_after_kill(test_queue): from celery import chain @@ -62,8 +91,32 @@ def test_post_run_after_kill(test_queue, mocker): assert result_foo.status == "PENDING" test_queue.proc.kill("bar") - while True: - if result_foo.status == "SUCCESS": - break - if time.time() > timeout: - raise AssertionError() + assert result_foo.get(timeout=10) == "foo" + + +def test_celery_queue_kill(test_queue, mocker): + + mock_entry = mocker.Mock(stash_rev=_foo.name) + + mocker.patch.object( + test_queue, + "iter_active", + return_value={mock_entry}, + ) + mocker.patch.object( + test_queue, + "get_queue_entry_by_names", + return_value={"bar": None}, + ) + with pytest.raises(UnresolvedExpNamesError): + test_queue.kill("bar") + + mocker.patch.object( + test_queue, + "get_queue_entry_by_names", + return_value={"bar": mock_entry}, + ) + + spy = mocker.patch.object(test_queue.proc, "kill") + test_queue.kill("bar") + assert spy.called_once_with(mock_entry.stash_rev)