Skip to content

Commit

Permalink
Add more unit test and implement a new method
Browse files Browse the repository at this point in the history
1. Add some more unit test for `celery_queue.kill`
2. Implement `kill` argument for `celery_queue.shutdown`
3. Use dataclass's internal methods instead of `__eq__`
  • Loading branch information
karajan1001 authored and pmrowla committed Jul 5, 2022
1 parent d2ae5c2 commit 594b117
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
logger = logging.getLogger(__name__)


@dataclass
@dataclass(frozen=True)
class QueueEntry:
dvc_root: str
scm_root: str
Expand Down
7 changes: 4 additions & 3 deletions dvc/repo/experiments/queue/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def kill(self, revs: Collection[str]) -> None:
raise UnresolvedExpNamesError(missing_rev)

for queue_entry in to_kill:
self.proc.kill(queue_entry.name)
self.proc.kill(queue_entry.stash_rev)

def _shutdown_handler(self, task_id: str = None, **kwargs):
if task_id in self._shutdown_task_ids:
Expand All @@ -249,13 +249,14 @@ def _shutdown_handler(self, task_id: str = None, **kwargs):
def shutdown(self, kill: bool = False):
from celery.signals import task_postrun

if kill:
raise NotImplementedError
tasks = list(self._iter_active_tasks())
if tasks:
for task_id, _ in tasks:
self._shutdown_task_ids.add(task_id)
task_postrun.connect()(self._shutdown_handler)
if kill:
for _, task_entry in tasks:
self.proc.kill(task_entry.stash_rev)
else:
self.celery.control.shutdown()

Expand Down
71 changes: 62 additions & 9 deletions tests/unit/repo/experiments/queue/test_local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import time

import pytest
from celery import shared_task

from dvc.repo.experiments.exceptions import UnresolvedExpNamesError


def test_shutdown_no_tasks(test_queue, mocker):
shutdown_spy = mocker.spy(test_queue.celery.control, "shutdown")
Expand All @@ -8,8 +13,8 @@ def test_shutdown_no_tasks(test_queue, mocker):


@shared_task
def _foo(arg="foo"):
return arg
def _foo(arg=None): # pylint: disable=unused-argument
return "foo"


def test_shutdown_active_tasks(test_queue, mocker):
Expand Down Expand Up @@ -37,8 +42,32 @@ def test_shutdown_active_tasks(test_queue, mocker):
shutdown_spy.assert_called_once()


def test_post_run_after_kill(test_queue, mocker):
import time
def test_shutdown_with_kill(test_queue, mocker):

sig = _foo.s()
mock_entry = mocker.Mock(stash_rev=_foo.name)

result = sig.freeze()

shutdown_spy = mocker.patch("celery.app.control.Control.shutdown")
mocker.patch.object(
test_queue,
"_iter_active_tasks",
return_value=[(result.id, mock_entry)],
)
kill_spy = mocker.patch.object(test_queue.proc, "kill")

test_queue.shutdown(kill=True)

sig.delay()

assert result.get() == "foo"
assert result.id not in test_queue._shutdown_task_ids
kill_spy.assert_called_once_with(mock_entry.stash_rev)
shutdown_spy.assert_called_once()


def test_post_run_after_kill(test_queue):

from celery import chain

Expand All @@ -62,8 +91,32 @@ def test_post_run_after_kill(test_queue, mocker):
assert result_foo.status == "PENDING"
test_queue.proc.kill("bar")

while True:
if result_foo.status == "SUCCESS":
break
if time.time() > timeout:
raise AssertionError()
assert result_foo.get(timeout=10) == "foo"


def test_celery_queue_kill(test_queue, mocker):

mock_entry = mocker.Mock(stash_rev=_foo.name)

mocker.patch.object(
test_queue,
"iter_active",
return_value={mock_entry},
)
mocker.patch.object(
test_queue,
"get_queue_entry_by_names",
return_value={"bar": None},
)
with pytest.raises(UnresolvedExpNamesError):
test_queue.kill("bar")

mocker.patch.object(
test_queue,
"get_queue_entry_by_names",
return_value={"bar": mock_entry},
)

spy = mocker.patch.object(test_queue.proc, "kill")
test_queue.kill("bar")
assert spy.called_once_with(mock_entry.stash_rev)

0 comments on commit 594b117

Please sign in to comment.