Skip to content

Commit

Permalink
Queue remove: Can remove successed,failed tasks msgs
Browse files Browse the repository at this point in the history
Followed from #7592

1. Seperate remove method to a new file.
2. Add queued, failed, processed flags to remove.
3. Add new unit tests for queue status --queue/fail/success
4. Implement methods to remove done tasks.
5. add some new unit and functional test for celery_queue.remove
6. bump into dvc-task version 0.0.13
  • Loading branch information
karajan1001 authored and pmrowla committed Jul 11, 2022
1 parent 28bc583 commit 6dddb31
Show file tree
Hide file tree
Showing 10 changed files with 568 additions and 109 deletions.
26 changes: 18 additions & 8 deletions dvc/commands/queue/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ class CmdQueueRemove(CmdBase):
"""Remove exp in queue."""

def run(self):
if self.args.all:
removed_list = self.repo.experiments.celery_queue.clear()
else:
removed_list = self.repo.experiments.celery_queue.remove(
revs=self.args.task
)
removed_list = self.repo.experiments.celery_queue.remove(
revs=self.args.task,
_all=self.args.all,
success=self.args.success,
queued=self.args.queued,
failed=self.args.failed,
)

if removed_list:
removed = ", ".join(removed_list)
ui.write(f"Removed tasks in queue: {removed}")
else:
ui.write(f"No tasks found in queue named {self.args.task}")
ui.write(f"No tasks found named {self.args.task}")

return 0

Expand All @@ -39,7 +40,16 @@ def add_parser(queue_subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
queue_remove_parser.add_argument(
"--all", action="store_true", help="Remove all tasks in queue."
"--all", action="store_true", help="Remove all tasks."
)
queue_remove_parser.add_argument(
"--queued", action="store_true", help="Remove all tasks in queue."
)
queue_remove_parser.add_argument(
"--success", action="store_true", help="Remove all successful tasks."
)
queue_remove_parser.add_argument(
"--failed", action="store_true", help="Remove all failed tasks."
)
queue_remove_parser.add_argument(
"task",
Expand Down
68 changes: 39 additions & 29 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,27 @@ def put(self, *args, **kwargs) -> QueueEntry:
def get(self) -> QueueGetResult:
"""Pop and return the first item in the queue."""

def remove(self, revs: Collection[str]) -> List[str]:
def remove(
self,
revs: Collection[str],
_all: bool = False,
queued: bool = False,
**kwargs,
) -> List[str]:
"""Remove the specified entries from the queue.
Arguments:
revs: Stash revisions or queued exp names to be removed.
queued: Remove all queued tasks.
all: Remove all tasks.
Returns:
Revisions (or names) which were removed.
"""

if _all or queued:
return self.clear()

removed: List[str] = []
to_remove: Dict[str, ExpStashEntry] = {}
queue_entries = self.match_queue_entry_by_name(
Expand All @@ -152,18 +164,15 @@ def remove(self, revs: Collection[str]) -> List[str]:
]
removed.append(name)

self._remove_revs(to_remove)
self._remove_revs(to_remove, self.stash)
return removed

def clear(self) -> List[str]:
"""Remove all entries from the queue.
Returns:
Revisions which were removed.
"""
def clear(self, **kwargs) -> List[str]:
"""Remove all entries from the queue."""
stash_revs = self.stash.stash_revs
removed = list(stash_revs)
self._remove_revs(stash_revs)
self._remove_revs(stash_revs, self.stash)

return removed

def status(self) -> List[Dict[str, Any]]:
Expand All @@ -176,27 +185,13 @@ def _get_timestamp(rev: str) -> datetime:
commit = self.scm.resolve_commit(rev)
return datetime.fromtimestamp(commit.commit_time)

failed_revs: Dict[str, ExpStashEntry] = (
dict(self.failed_stash.stash_revs)
if self.failed_stash is not None
else {}
)

def _format_entry(
entry: QueueEntry,
status: str = "Unknown",
result: Optional[ExecutorResult] = None,
) -> Dict[str, Any]:
name = entry.name
# NOTE: We fallback to Unknown status for experiments
# generated in prior (incompatible) DVC versions
if result is None and entry.stash_rev in failed_revs:
status = "Failed"
elif result is not None:
if result.exp_hash:
if result.ref_info:
status = "Success"
name = result.ref_info.name
return {
"rev": entry.stash_rev,
"name": name,
Expand All @@ -213,12 +208,17 @@ def _format_entry(
for queue_entry in self.iter_queued()
)
result.extend(
_format_entry(queue_entry, result=exp_result)
for queue_entry, exp_result in self.iter_done()
_format_entry(queue_entry, status="Failed")
for queue_entry, _ in self.iter_failed()
)
result.extend(
_format_entry(queue_entry, status="Success")
for queue_entry, _ in self.iter_success()
)
return result

def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]):
@staticmethod
def _remove_revs(stash_revs: Mapping[str, ExpStashEntry], stash: ExpStash):
"""Remove the specified entries from the queue by stash revision."""
for index in sorted(
(
Expand All @@ -228,7 +228,7 @@ def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]):
),
reverse=True,
):
self.stash.drop(index)
stash.drop(index)

@abstractmethod
def iter_queued(self) -> Generator[QueueEntry, None, None]:
Expand All @@ -238,9 +238,18 @@ def iter_queued(self) -> Generator[QueueEntry, None, None]:
def iter_active(self) -> Generator[QueueEntry, None, None]:
"""Iterate over items which are being actively processed."""

@abstractmethod
def iter_done(self) -> Generator[QueueDoneResult, None, None]:
"""Iterate over items which been processed."""

@abstractmethod
def iter_success(self) -> Generator[QueueDoneResult, None, None]:
"""Iterate over items which been success."""

@abstractmethod
def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
"""Iterate over items which been failed."""

@abstractmethod
def reproduce(self) -> Mapping[str, Mapping[str, str]]:
"""Reproduce queued experiments sequentially."""
Expand Down Expand Up @@ -460,8 +469,8 @@ def _stash_commit_deps(self, *args, **kwargs):
data_only=True,
)

@staticmethod
def _stash_msg(
self,
rev: str,
baseline_rev: str,
branch: Optional[str] = None,
Expand Down Expand Up @@ -503,7 +512,8 @@ def _pack_args(self, *args, **kwargs) -> None:
)
self.scm.add(self.args_file)

def _format_new_params_msg(self, new_params, config_path):
@staticmethod
def _format_new_params_msg(new_params, config_path):
"""Format an error message for when new parameters are identified"""
new_param_count = len(new_params)
pluralise = "s are" if new_param_count > 1 else " is"
Expand Down
34 changes: 26 additions & 8 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,6 @@ def put(self, *args, **kwargs) -> QueueEntry:
def get(self) -> QueueGetResult:
raise NotImplementedError

def _remove_revs(self, stash_revs: Mapping[str, ExpStashEntry]):
try:
for msg, queue_entry in self._iter_queued():
if queue_entry.stash_rev in stash_revs:
self.celery.reject(msg.delivery_tag)
finally:
super()._remove_revs(stash_revs)

def iter_queued(self) -> Generator[QueueEntry, None, None]:
for _, entry in self._iter_queued():
yield entry
Expand Down Expand Up @@ -225,6 +217,22 @@ def iter_done(self) -> Generator[QueueDoneResult, None, None]:
for _, entry in self._iter_done_tasks():
yield QueueDoneResult(entry, self.get_result(entry))

def iter_success(self) -> Generator[QueueDoneResult, None, None]:
for queue_entry, exp_result in self.iter_done():
if exp_result and exp_result.exp_hash and exp_result.ref_info:
yield QueueDoneResult(queue_entry, exp_result)

def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
failed_revs: Dict[str, ExpStashEntry] = (
dict(self.failed_stash.stash_revs)
if self.failed_stash is not None
else {}
)

for queue_entry, exp_result in self.iter_done():
if exp_result is None and queue_entry.stash_rev in failed_revs:
yield QueueDoneResult(queue_entry, exp_result)

def reproduce(self) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

Expand Down Expand Up @@ -341,3 +349,13 @@ def worker_status(self) -> Dict:
status = self.celery.control.inspect().active() or {}
logger.debug(f"Worker status: {status}")
return status

def clear(self, *args, **kwargs):
from .remove import clear

return clear(self, *args, **kwargs)

def remove(self, *args, **kwargs):
from .remove import remove

return remove(self, *args, **kwargs)
Loading

0 comments on commit 6dddb31

Please sign in to comment.