diff --git a/dvc/commands/queue/remove.py b/dvc/commands/queue/remove.py index ef68f8ef3e..15641dc4ce 100644 --- a/dvc/commands/queue/remove.py +++ b/dvc/commands/queue/remove.py @@ -12,18 +12,19 @@ class CmdQueueRemove(CmdBase): """Remove exp in queue.""" def run(self): - if self.args.all: - removed_list = self.repo.experiments.celery_queue.clear() - else: - removed_list = self.repo.experiments.celery_queue.remove( - revs=self.args.task - ) + removed_list = self.repo.experiments.celery_queue.remove( + revs=self.args.task, + _all=self.args.all, + success=self.args.success, + queued=self.args.queued, + failed=self.args.failed, + ) if removed_list: removed = ", ".join(removed_list) ui.write(f"Removed tasks in queue: {removed}") else: - ui.write(f"No tasks found in queue named {self.args.task}") + ui.write(f"No tasks found named {self.args.task}") return 0 @@ -39,7 +40,16 @@ def add_parser(queue_subparsers, parent_parser): formatter_class=argparse.RawDescriptionHelpFormatter, ) queue_remove_parser.add_argument( - "--all", action="store_true", help="Remove all tasks in queue." + "--all", action="store_true", help="Remove all tasks." + ) + queue_remove_parser.add_argument( + "--queued", action="store_true", help="Remove all tasks in queue." + ) + queue_remove_parser.add_argument( + "--success", action="store_true", help="Remove all successful tasks." + ) + queue_remove_parser.add_argument( + "--failed", action="store_true", help="Remove all failed tasks." ) queue_remove_parser.add_argument( "task", diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index a5caa9363d..90ec0508c0 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -131,15 +131,27 @@ def put(self, *args, **kwargs) -> QueueEntry: def get(self) -> QueueGetResult: """Pop and return the first item in the queue.""" - def remove(self, revs: Collection[str]) -> List[str]: + def remove( + self, + revs: Collection[str], + _all: bool = False, + queued: bool = False, + **kwargs, + ) -> List[str]: """Remove the specified entries from the queue. Arguments: revs: Stash revisions or queued exp names to be removed. + queued: Remove all queued tasks. + all: Remove all tasks. Returns: Revisions (or names) which were removed. """ + + if _all or queued: + return self.clear() + removed: List[str] = [] to_remove: Dict[str, ExpStashEntry] = {} queue_entries = self.match_queue_entry_by_name( @@ -152,18 +164,15 @@ def remove(self, revs: Collection[str]) -> List[str]: ] removed.append(name) - self._remove_revs(to_remove) + self._remove_revs(to_remove, self.stash) return removed - def clear(self) -> List[str]: - """Remove all entries from the queue. - - Returns: - Revisions which were removed. - """ + def clear(self, **kwargs) -> List[str]: + """Remove all entries from the queue.""" stash_revs = self.stash.stash_revs removed = list(stash_revs) - self._remove_revs(stash_revs) + self._remove_revs(stash_revs, self.stash) + return removed def status(self) -> List[Dict[str, Any]]: @@ -176,27 +185,13 @@ def _get_timestamp(rev: str) -> datetime: commit = self.scm.resolve_commit(rev) return datetime.fromtimestamp(commit.commit_time) - failed_revs: Dict[str, ExpStashEntry] = ( - dict(self.failed_stash.stash_revs) - if self.failed_stash is not None - else {} - ) - def _format_entry( entry: QueueEntry, status: str = "Unknown", - result: Optional[ExecutorResult] = None, ) -> Dict[str, Any]: name = entry.name # NOTE: We fallback to Unknown status for experiments # generated in prior (incompatible) DVC versions - if result is None and entry.stash_rev in failed_revs: - status = "Failed" - elif result is not None: - if result.exp_hash: - if result.ref_info: - status = "Success" - name = result.ref_info.name return { "rev": entry.stash_rev, "name": name, @@ -213,12 +208,17 @@ def _format_entry( for queue_entry in self.iter_queued() ) result.extend( - _format_entry(queue_entry, result=exp_result) - for queue_entry, exp_result in self.iter_done() + _format_entry(queue_entry, status="Failed") + for queue_entry, _ in self.iter_failed() + ) + result.extend( + _format_entry(queue_entry, status="Success") + for queue_entry, _ in self.iter_success() ) return result - def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]): + @staticmethod + def _remove_revs(stash_revs: Mapping[str, ExpStashEntry], stash: ExpStash): """Remove the specified entries from the queue by stash revision.""" for index in sorted( ( @@ -228,7 +228,7 @@ def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]): ), reverse=True, ): - self.stash.drop(index) + stash.drop(index) @abstractmethod def iter_queued(self) -> Generator[QueueEntry, None, None]: @@ -238,9 +238,18 @@ def iter_queued(self) -> Generator[QueueEntry, None, None]: def iter_active(self) -> Generator[QueueEntry, None, None]: """Iterate over items which are being actively processed.""" + @abstractmethod def iter_done(self) -> Generator[QueueDoneResult, None, None]: """Iterate over items which been processed.""" + @abstractmethod + def iter_success(self) -> Generator[QueueDoneResult, None, None]: + """Iterate over items which been success.""" + + @abstractmethod + def iter_failed(self) -> Generator[QueueDoneResult, None, None]: + """Iterate over items which been failed.""" + @abstractmethod def reproduce(self) -> Mapping[str, Mapping[str, str]]: """Reproduce queued experiments sequentially.""" @@ -460,8 +469,8 @@ def _stash_commit_deps(self, *args, **kwargs): data_only=True, ) + @staticmethod def _stash_msg( - self, rev: str, baseline_rev: str, branch: Optional[str] = None, @@ -503,7 +512,8 @@ def _pack_args(self, *args, **kwargs) -> None: ) self.scm.add(self.args_file) - def _format_new_params_msg(self, new_params, config_path): + @staticmethod + def _format_new_params_msg(new_params, config_path): """Format an error message for when new parameters are identified""" new_param_count = len(new_params) pluralise = "s are" if new_param_count > 1 else " is" diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index b5e0356a54..97880ed41f 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -171,14 +171,6 @@ def put(self, *args, **kwargs) -> QueueEntry: def get(self) -> QueueGetResult: raise NotImplementedError - def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]): - try: - for msg, queue_entry in self._iter_queued(): - if queue_entry.stash_rev in stash_revs: - self.celery.reject(msg.delivery_tag) - finally: - super()._remove_revs(stash_revs) - def iter_queued(self) -> Generator[QueueEntry, None, None]: for _, entry in self._iter_queued(): yield entry @@ -225,6 +217,22 @@ def iter_done(self) -> Generator[QueueDoneResult, None, None]: for _, entry in self._iter_done_tasks(): yield QueueDoneResult(entry, self.get_result(entry)) + def iter_success(self) -> Generator[QueueDoneResult, None, None]: + for queue_entry, exp_result in self.iter_done(): + if exp_result and exp_result.exp_hash and exp_result.ref_info: + yield QueueDoneResult(queue_entry, exp_result) + + def iter_failed(self) -> Generator[QueueDoneResult, None, None]: + failed_revs: Dict[str, ExpStashEntry] = ( + dict(self.failed_stash.stash_revs) + if self.failed_stash is not None + else {} + ) + + for queue_entry, exp_result in self.iter_done(): + if exp_result is None and queue_entry.stash_rev in failed_revs: + yield QueueDoneResult(queue_entry, exp_result) + def reproduce(self) -> Mapping[str, Mapping[str, str]]: raise NotImplementedError @@ -341,3 +349,13 @@ def worker_status(self) -> Dict: status = self.celery.control.inspect().active() or {} logger.debug(f"Worker status: {status}") return status + + def clear(self, *args, **kwargs): + from .remove import clear + + return clear(self, *args, **kwargs) + + def remove(self, *args, **kwargs): + from .remove import remove + + return remove(self, *args, **kwargs) diff --git a/dvc/repo/experiments/queue/remove.py b/dvc/repo/experiments/queue/remove.py new file mode 100644 index 0000000000..edf0c89ca4 --- /dev/null +++ b/dvc/repo/experiments/queue/remove.py @@ -0,0 +1,185 @@ +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Union, +) + +from dvc.repo.experiments.exceptions import UnresolvedExpNamesError +from dvc.repo.experiments.queue.base import QueueDoneResult + +if TYPE_CHECKING: + from dvc.repo.experiments.queue.base import QueueEntry + from dvc.repo.experiments.queue.local import LocalCeleryQueue + from dvc.repo.experiments.stash import ExpStashEntry + + +def _remove_queued_tasks( + celery_queue: "LocalCeleryQueue", + queue_entries: Iterable[Optional["QueueEntry"]], +): + """Remove tasks from task queue. + + Arguments: + queue_entries: An iterable list of queued task to remove + """ + # pylint: disable=protected-access + stash_revs: Dict[str, "ExpStashEntry"] = {} + for entry in queue_entries: + if entry: + stash_revs[entry.stash_rev] = celery_queue.stash.stash_revs[ + entry.stash_rev + ] + + try: + for msg, queue_entry in celery_queue._iter_queued(): + if queue_entry.stash_rev in stash_revs: + celery_queue.celery.reject(msg.delivery_tag) + finally: + celery_queue._remove_revs(stash_revs, celery_queue.stash) + + +def _remove_done_tasks( + celery_queue: "LocalCeleryQueue", + queue_entries: Iterable[Optional["QueueEntry"]], +): + """Remove done tasks. + + Arguments: + queue_entries: An iterable list of done task to remove + """ + # pylint: disable=protected-access + from celery.result import AsyncResult + + failed_stash_revs: Dict[str, "ExpStashEntry"] = {} + queue_entry_set: Set["QueueEntry"] = set() + for entry in queue_entries: + if entry: + queue_entry_set.add(entry) + if entry.stash_rev in celery_queue.failed_stash.stash_revs: + failed_stash_revs[ + entry.stash_rev + ] = celery_queue.failed_stash.stash_revs[entry.stash_rev] + + try: + for msg, queue_entry in celery_queue._iter_processed(): + if queue_entry not in queue_entry_set: + continue + task_id = msg.headers["id"] + result: AsyncResult = AsyncResult(task_id) + if result is not None: + result.forget() + celery_queue.celery.purge(msg.delivery_tag) + finally: + celery_queue._remove_revs(failed_stash_revs, celery_queue.failed_stash) + + +def _get_names(entries: Iterable[Union["QueueEntry", "QueueDoneResult"]]): + names: List[str] = [] + for entry in entries: + if isinstance(entry, QueueDoneResult): + if entry.result and entry.result.ref_info: + names.append(entry.result.ref_info.name) + continue + entry = entry.entry + name = entry.name + name = name or entry.stash_rev[:7] + names.append(name) + return names + + +def clear(self: "LocalCeleryQueue", **kwargs) -> List[str]: + """Remove entries from the queue. + + Arguments: + queued: Remove all queued tasks. + failed: Remove all failed tasks. + success: Remove all success tasks. + + Returns: + Revisions which were removed. + """ + queued = kwargs.pop("queued", False) + failed = kwargs.get("failed", False) + success = kwargs.get("success", False) + + removed = [] + if queued: + queue_entries = list(self.iter_queued()) + _remove_queued_tasks(self, queue_entries) + removed.extend(_get_names(queue_entries)) + if failed or success: + done_tasks: List["QueueDoneResult"] = [] + if failed: + done_tasks.extend(self.iter_failed()) + if success: + done_tasks.extend(self.iter_success()) + done_entries = [result.entry for result in done_tasks] + _remove_done_tasks(self, done_entries) + removed.extend(_get_names(done_tasks)) + + return removed + + +def remove( + self: "LocalCeleryQueue", + revs: Collection[str], + queued: bool = False, + failed: bool = False, + success: bool = False, + _all: bool = False, +) -> List[str]: + """Remove the specified entries from the queue. + + Arguments: + revs: Stash revisions or queued exp names to be removed. + queued: Remove all queued tasks. + failed: Remove all failed tasks. + success: Remove all success tasks. + all: Remove all tasks. + + Returns: + Revisions (or names) which were removed. + """ + if _all: + queued = failed = success = True + if queued or failed or success: + return self.clear(failed=failed, success=success, queued=queued) + + # match_queued + queue_match_results = self.match_queue_entry_by_name( + revs, self.iter_queued() + ) + + done_match_results = self.match_queue_entry_by_name(revs, self.iter_done()) + + remained: List[str] = [] + removed: List[str] = [] + queued_to_remove: List["QueueEntry"] = [] + done_to_remove: List["QueueEntry"] = [] + for name in revs: + done_match = done_match_results[name] + if done_match: + done_to_remove.append(done_match) + removed.append(name) + continue + queue_match = queue_match_results[name] + if queue_match: + queued_to_remove.append(queue_match) + removed.append(name) + continue + remained.append(name) + + if remained: + raise UnresolvedExpNamesError(remained) + + if done_to_remove: + _remove_done_tasks(self, done_to_remove) + if queued_to_remove: + _remove_queued_tasks(self, queued_to_remove) + + return removed diff --git a/dvc/repo/experiments/queue/workspace.py b/dvc/repo/experiments/queue/workspace.py index e8c662cd2a..159dfd479f 100644 --- a/dvc/repo/experiments/queue/workspace.py +++ b/dvc/repo/experiments/queue/workspace.py @@ -63,6 +63,12 @@ def iter_active(self) -> Generator[QueueEntry, None, None]: def iter_done(self) -> Generator[QueueDoneResult, None, None]: raise NotImplementedError + def iter_failed(self) -> Generator[QueueDoneResult, None, None]: + raise NotImplementedError + + def iter_success(self) -> Generator[QueueDoneResult, None, None]: + raise NotImplementedError + def reproduce(self) -> Dict[str, Dict[str, str]]: results: Dict[str, Dict[str, str]] = defaultdict(dict) try: diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index de82ade93f..0fce81dd7c 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -121,7 +121,7 @@ def _clear_queue(repo: "Repo") -> List[str]: removed_name_list = [] for entry in repo.experiments.celery_queue.iter_queued(): removed_name_list.append(entry.name or entry.stash_rev[:7]) - repo.experiments.celery_queue.clear() + repo.experiments.celery_queue.clear(queued=True) return removed_name_list diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 54d6eb1de1..d74ef89bb7 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -101,50 +101,6 @@ def test_failed_exp_workspace( dvc.experiments.run(failed_exp_stage.addressing) -def test_failed_exp_celery( - tmp_dir, - scm, - dvc, - failed_exp_stage, - test_queue, - mocker, - capsys, -): - tmp_dir.gen("params.yaml", "foo: 2") - - dvc.experiments.run(failed_exp_stage.addressing, queue=True) - dvc.experiments.run(run_all=True) - output = capsys.readouterr() - assert "Failed to reproduce experiment" in output.err - assert len(dvc.experiments.celery_queue.failed_stash) == 1 - assert ( - first(dvc.experiments.celery_queue.status()).get("status") == "Failed" - ) - - -@pytest.mark.parametrize("follow", [True, False]) -def test_celery_logs( - tmp_dir, - scm, - dvc, - failed_exp_stage, - test_queue, - follow, - mocker, - capsys, -): - dvc.experiments.run(failed_exp_stage.addressing, queue=True) - dvc.experiments.run(run_all=True) - - queue = dvc.experiments.celery_queue - done_result = first(queue.iter_done()) - name = done_result.entry.stash_rev - captured = capsys.readouterr() - queue.logs(name, follow=follow) - captured = capsys.readouterr() - assert "failed to reproduce 'failed-copy-file'" in captured.out - - @pytest.mark.parametrize( "changes, expected", [ diff --git a/tests/func/experiments/test_queue.py b/tests/func/experiments/test_queue.py new file mode 100644 index 0000000000..2303b0cfea --- /dev/null +++ b/tests/func/experiments/test_queue.py @@ -0,0 +1,149 @@ +import pytest +from funcy import first + +from dvc.exceptions import InvalidArgumentError + + +def to_dict(tasks): + status_dict = {} + for task in tasks: + status_dict[task["name"]] = task["status"] + return status_dict + + +@pytest.fixture +def queued_tasks(tmp_dir, dvc, scm, exp_stage): + queue_length = 3 + name_list = [] + for i in range(queue_length): + name = f"queued{i}" + name_list.append(name) + dvc.experiments.run( + exp_stage.addressing, + params=[f"foo={i+2*queue_length}"], + queue=True, + name=name, + ) + return ["queued0", "queued1", "queued2"] + + +@pytest.fixture +def success_tasks(tmp_dir, dvc, scm, test_queue, exp_stage): + queue_length = 3 + name_list = [] + for i in range(queue_length): + name = f"success{i}" + name_list.append(name) + dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], queue=True, name=name + ) + dvc.experiments.run(run_all=True) + return name_list + + +@pytest.fixture +def failed_tasks(tmp_dir, dvc, scm, test_queue, failed_exp_stage, capsys): + queue_length = 3 + name_list = [] + for i in range(queue_length): + name = f"failed{i}" + name_list.append(name) + dvc.experiments.run( + failed_exp_stage.addressing, + params=[f"foo={i+queue_length}"], + queue=True, + name=name, + ) + dvc.experiments.run(run_all=True) + output = capsys.readouterr() + assert "Failed to reproduce experiment" in output.err + return name_list + + +@pytest.mark.parametrize("follow", [True, False]) +def test_celery_logs( + tmp_dir, + scm, + dvc, + failed_exp_stage, + test_queue, + follow, + capsys, +): + dvc.experiments.run(failed_exp_stage.addressing, queue=True) + dvc.experiments.run(run_all=True) + + queue = dvc.experiments.celery_queue + done_result = first(queue.iter_done()) + name = done_result.entry.stash_rev + captured = capsys.readouterr() + queue.logs(name, follow=follow) + captured = capsys.readouterr() + assert "failed to reproduce 'failed-copy-file'" in captured.out + + +def test_queue_status(dvc, failed_tasks, success_tasks, queued_tasks): + assert len(dvc.experiments.stash_revs) == 3 + assert len(dvc.experiments.celery_queue.failed_stash) == 3 + status = to_dict(dvc.experiments.celery_queue.status()) + assert len(status) == 9 + for task in failed_tasks: + assert status[task] == "Failed" + for task in success_tasks: + assert status[task] == "Success" + for task in queued_tasks: + assert status[task] == "Queued" + + +def test_queue_remove(dvc, failed_tasks, success_tasks, queued_tasks): + assert len(dvc.experiments.stash_revs) == 3 + assert len(dvc.experiments.celery_queue.failed_stash) == 3 + assert len(dvc.experiments.celery_queue.status()) == 9 + + with pytest.raises(InvalidArgumentError): + dvc.experiments.celery_queue.remove(failed_tasks[:2] + ["non-exist"]) + assert len(dvc.experiments.celery_queue.status()) == 9 + + to_remove = failed_tasks[:2] + success_tasks[1:] + queued_tasks[1:2] + assert set(dvc.experiments.celery_queue.remove(to_remove)) == set( + to_remove + ) + + assert len(dvc.experiments.stash_revs) == 2 + assert len(dvc.experiments.celery_queue.failed_stash) == 1 + status = to_dict(dvc.experiments.celery_queue.status()) + assert set(status) == set( + queued_tasks[:1] + + queued_tasks[2:] + + success_tasks[:1] + + failed_tasks[2:] + ) + assert status[queued_tasks[0]] == "Queued" + assert status[queued_tasks[2]] == "Queued" + + assert ( + dvc.experiments.celery_queue.remove([], queued=True) + == queued_tasks[:1] + queued_tasks[2:] + ) + + assert len(dvc.experiments.stash_revs) == 0 + assert len(dvc.experiments.celery_queue.failed_stash) == 1 + assert len(dvc.experiments.celery_queue.status()) == 2 + + assert ( + dvc.experiments.celery_queue.remove([], failed=True) + == failed_tasks[2:] + ) + + assert len(dvc.experiments.stash_revs) == 0 + assert len(dvc.experiments.celery_queue.failed_stash) == 0 + assert len(dvc.experiments.celery_queue.status()) == 1 + + assert ( + dvc.experiments.celery_queue.remove([], success=True) + == success_tasks[:1] + ) + + assert len(dvc.experiments.stash_revs) == 0 + assert len(dvc.experiments.celery_queue.failed_stash) == 0 + assert len(dvc.experiments.celery_queue.status()) == 0 diff --git a/tests/unit/command/test_queue.py b/tests/unit/command/test_queue.py index 35d4ae70b7..34a8bf5a68 100644 --- a/tests/unit/command/test_queue.py +++ b/tests/unit/command/test_queue.py @@ -15,23 +15,9 @@ def test_experiments_remove(dvc, scm, mocker): "queue", "remove", "--all", - ] - ) - assert cli_args.func == CmdQueueRemove - - cmd = cli_args.func(cli_args) - m = mocker.patch( - "dvc.repo.experiments.queue.celery.LocalCeleryQueue.clear", - return_value={}, - ) - - assert cmd.run() == 0 - m.assert_called_once_with() - - cli_args = parse_args( - [ - "queue", - "remove", + "--queued", + "--success", + "--failed", "exp1", "exp2", ] @@ -39,13 +25,19 @@ def test_experiments_remove(dvc, scm, mocker): assert cli_args.func == CmdQueueRemove cmd = cli_args.func(cli_args) - m = mocker.patch( + remove_mocker = mocker.patch( "dvc.repo.experiments.queue.celery.LocalCeleryQueue.remove", return_value={}, ) assert cmd.run() == 0 - m.assert_called_once_with(revs=["exp1", "exp2"]) + remove_mocker.assert_called_once_with( + revs=["exp1", "exp2"], + _all=True, + success=True, + failed=True, + queued=True, + ) def test_experiments_kill(dvc, scm, mocker): diff --git a/tests/unit/repo/experiments/queue/test_remove.py b/tests/unit/repo/experiments/queue/test_remove.py new file mode 100644 index 0000000000..3557a9d345 --- /dev/null +++ b/tests/unit/repo/experiments/queue/test_remove.py @@ -0,0 +1,133 @@ +from unittest.mock import call + +from dvc.repo.experiments.queue.base import QueueDoneResult + + +def test_remove_queued(test_queue, mocker): + + queued_test = ["queue1", "queue2", "queue3"] + + stash_dict = {} + for name in queued_test: + stash_dict[name] = mocker.Mock() + + msg_dict = {} + entry_dict = {} + for name in queued_test: + msg_dict[name] = mocker.Mock(delivery_tag=f"msg_{name}") + entry_dict[name] = mocker.Mock(stash_rev=name) + entry_dict[name].name = name + + msg_iter = [(msg_dict[name], entry_dict[name]) for name in queued_test] + entry_iter = [entry_dict[name] for name in queued_test] + + stash = mocker.patch.object( + test_queue, "stash", return_value=mocker.Mock() + ) + stash.stash_revs = stash_dict + mocker.patch.object(test_queue, "_iter_queued", return_value=msg_iter) + mocker.patch.object(test_queue, "iter_queued", return_value=entry_iter) + + remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs") + reject_mocker = mocker.patch.object(test_queue.celery, "reject") + + assert test_queue.remove(["queue2"]) == ["queue2"] + reject_mocker.assert_called_once_with("msg_queue2") + remove_revs_mocker.assert_called_once_with( + {"queue2": stash_dict["queue2"]}, test_queue.stash + ) + + remove_revs_mocker.reset_mock() + reject_mocker.reset_mock() + + assert test_queue.remove([], queued=True) == queued_test + remove_revs_mocker.assert_called_once_with(stash_dict, test_queue.stash) + reject_mocker.assert_has_calls( + [call("msg_queue1"), call("msg_queue2"), call("msg_queue3")] + ) + + +def test_remove_done(test_queue, mocker): + from funcy import concat + + failed_test = ["failed1", "failed2", "failed3"] + success_test = ["success1", "success2", "success3"] + + stash_dict = {} + for name in failed_test: + stash_dict[name] = mocker.Mock() + + msg_dict = {} + entry_dict = {} + for name in concat(failed_test, success_test): + msg_dict[name] = mocker.Mock( + delivery_tag=f"msg_{name}", headers={"id": 0} + ) + entry_dict[name] = mocker.Mock(stash_rev=name) + entry_dict[name].name = name + + msg_iter = [ + (msg_dict[name], entry_dict[name]) + for name in concat(failed_test, success_test) + ] + done_iter = [ + QueueDoneResult(entry_dict[name], None) + for name in concat(failed_test, success_test) + ] + failed_iter = [ + QueueDoneResult(entry_dict[name], None) for name in failed_test + ] + success_iter = [ + QueueDoneResult(entry_dict[name], None) for name in success_test + ] + + stash = mocker.patch.object( + test_queue, "failed_stash", return_value=mocker.Mock() + ) + stash.stash_revs = stash_dict + mocker.patch.object(test_queue, "_iter_processed", return_value=msg_iter) + mocker.patch.object(test_queue, "iter_done", return_value=done_iter) + mocker.patch.object(test_queue, "iter_success", return_value=success_iter) + mocker.patch.object(test_queue, "iter_failed", return_value=failed_iter) + mocker.patch("celery.result.AsyncResult", return_value=mocker.Mock()) + + remove_revs_mocker = mocker.patch.object(test_queue, "_remove_revs") + purge_mocker = mocker.patch.object(test_queue.celery, "purge") + + assert test_queue.remove(["failed3", "success2"]) == [ + "failed3", + "success2", + ] + remove_revs_mocker.assert_called_once_with( + {"failed3": stash_dict["failed3"]}, test_queue.failed_stash + ) + purge_mocker.assert_has_calls([call("msg_failed3"), call("msg_success2")]) + + remove_revs_mocker.reset_mock() + purge_mocker.reset_mock() + + assert set(test_queue.remove([], success=True, failed=True)) == set( + failed_test + ) | set(success_test) + purge_mocker.assert_has_calls( + [ + call("msg_failed1"), + call("msg_failed2"), + call("msg_failed3"), + call("msg_success1"), + call("msg_success2"), + call("msg_success3"), + ], + any_order=True, + ) + remove_revs_mocker.assert_called_once_with( + stash_dict, test_queue.failed_stash + ) + + +def test_remove_all(test_queue, mocker): + clear_mocker = mocker.patch.object( + test_queue, "clear", return_value=mocker.Mock() + ) + test_queue.remove([], _all=True) + assert clear_mocker.called_once_with(queud=True, failed=True, success=True)