diff --git a/apstools/utils.py b/apstools/utils.py index 705813336..850c109c7 100644 --- a/apstools/utils.py +++ b/apstools/utils.py @@ -73,6 +73,7 @@ from bluesky.callbacks.best_effort import BestEffortCallback from collections import defaultdict from collections import OrderedDict +from databroker._drivers.mongo_normalized import BlueskyMongoCatalog from dataclasses import dataclass from email.mime.text import MIMEText from event_model import NumpyEncoder @@ -95,6 +96,7 @@ import sys import threading import time +import typing import warnings import zipfile @@ -860,7 +862,6 @@ class ListRuns: ~_get_by_key ~_check_cat ~_apply_search_filters - ~_sorter ~_check_keys """ @@ -875,6 +876,7 @@ class ListRuns: sortby: str = "time" timefmt: str = "%Y-%m-%d %H:%M:%S" until: str = None + ids: "typing.Any" = None _default_keys = "scan_id time plan_name detectors" @@ -927,6 +929,7 @@ def _check_cat(self): self.cat = getCatalog() def _apply_search_filters(self): + """Search for runs from the catalog.""" since = self.since or FIRST_DATA until = self.until or LAST_DATA self._check_cat() @@ -938,29 +941,55 @@ def _apply_search_filters(self): def parse_runs(self): """Parse the runs for the given metadata keys. Return a dict.""" - self._check_cat() self._check_keys() cat = self._apply_search_filters() + + def _sort(uid): + """Sort runs in desired order based on metadata key.""" + md = self.cat[uid].metadata + for doc in "start stop".split(): + if md[doc] and self.sortby in md[doc]: + return md[doc][self.sortby] or self.missing + return self.missing + num_runs_requested = min(abs(self.num), len(cat)) - dd = { - key: [ - self._get_by_key(run.metadata, key) - for _, run in sorted( - cat.items(), key=self._sorter, reverse=self.reverse - )[:num_runs_requested] - ] - for key in self.keys - } - return dd - - def _sorter(self, args): - """Sort runs in desired order based on metadata key.""" - # args : (uid, run) - md = args[1].metadata - for doc in "start stop".split(): - if md[doc] and self.sortby in md[doc]: - return md[doc][self.sortby] or self.missing - return self.missing + results = {k: [] for k in self.keys} + sequence = () # iterable of run uids + + if self.ids is not None: + sequence = [] + for k in self.ids: + try: + run = cat[k] + sequence.append(k) + except Exception as exc: + logger.warning( + "Could not find run %s in search of catalog %s: %s", + k, + self.cat.name, + exc, + ) + else: + if isinstance(cat, BlueskyMongoCatalog) and self.sortby == "time": + if self.reverse: + # the default rendering: from MongoDB in reverse time order + sequence = iter(cat) + else: + # by increasing time order + sequence = [uid for uid in cat][::-1] + else: + # full search in Python + sequence = sorted(cat.keys(), key=_sort, reverse=self.reverse) + + count = 0 + for uid in sequence: + run = cat[uid] + for k in self.keys: + results[k].append(self._get_by_key(run.metadata, k)) + count += 1 + if count >= num_runs_requested: + break + return results def _check_keys(self): """Check that self.keys is a list of strings.""" @@ -971,13 +1000,11 @@ def _check_keys(self): def to_dataframe(self): """Output as pandas DataFrame object""" - self._check_keys() dd = self.parse_runs() return pd.DataFrame(dd, columns=self.keys) def to_table(self, fmt=None): """Output as pyRestTable object.""" - self._check_keys() dd = self.parse_runs() table = pyRestTable.Table() @@ -1002,6 +1029,7 @@ def listruns( tablefmt="dataframe", timefmt="%Y-%m-%d %H:%M:%S", until=None, + ids=None, **query, ): """ @@ -1028,6 +1056,15 @@ def listruns( *str*: Test to report when a value is not available. (default: ``""``) + ids + *[int]* or *[str]*: + List of ``uid`` or ``scan_id`` value(s). + Can mix different kinds in the same list. + Also can specify offsets (e.g., ``-1``). + According to the rules for ``databroker`` catalogs, + a string is a ``uid`` (partial representations allowed), + an int is ``scan_id`` if positive or an offset if negative. + (default: ``None``) num *int* : Make the table include the ``num`` most recent runs. @@ -1112,30 +1149,18 @@ def listruns( sortby=sortby, timefmt=timefmt, until=until, + ids=ids, ) - # fmt: off - table_format_function = dict( - dataframe=lr.to_dataframe, - table=lr.to_table, - ).get(tablefmt or "dataframe", lr.to_table) - # fmt: on - obj = table_format_function() + tablefmt = tablefmt or "dataframe" + if tablefmt == "dataframe": + obj = lr.to_dataframe() + else: + obj = lr.to_table() - do_print = False if printing: if lr.cat is not None: print(f"catalog: {lr.cat.name}") - if printing == "smart": - try: - get_ipython() # console or notebook will handle - except NameError: - do_print = True # we print it here - if tablefmt == "table": - do_print = True - else: - do_print = True - if do_print: print(obj) return return obj diff --git a/tests/test_listruns_class.py b/tests/test_listruns_class.py index 821d7bfe0..b2a242cb4 100644 --- a/tests/test_listruns_class.py +++ b/tests/test_listruns_class.py @@ -25,9 +25,7 @@ def lr(): def test_getDefaultCatalog_none_found(): with pytest.raises(ValueError) as exinfo: APS_utils.getDefaultCatalog() - assert "Multiple catalog configurations available." in str( - exinfo.value - ) + assert "Multiple catalog configurations available." in str(exinfo.value) def test_getDefaultCatalog(cat): @@ -44,9 +42,7 @@ def test_getDefaultCatalog_many_found(cat): with pytest.raises(ValueError) as exinfo: APS_utils.getDefaultCatalog() - assert "Multiple catalog objects available." in str( - exinfo.value - ) + assert "Multiple catalog objects available." in str(exinfo.value) def test_getCatalog(): @@ -183,24 +179,47 @@ def test_ListRuns_to_num(lr): assert 0 <= len(dd["time"]) <= lr.num -def test_ListRuns_query(lr): +def test_ListRuns_query_count(lr): lr.query = dict(plan_name="count") dd = lr.parse_runs() assert 0 <= len(dd["time"]) <= lr.num for v in dd["plan_name"]: assert v == "count" - # TODO: more? -def test_ListRuns_reverse(lr): - # include with some data missing or None - # default +# fmt: off +@pytest.mark.parametrize( + "nruns, query", + [ + (27, dict(scan_id={"$lt": 20})), + (26, dict(plan_name="count")), + (0, dict(scan_id={"$lt": 20}, plan_name="count")), + (19, dict(scan_id={"$gte": 100})), + (19, dict(scan_id={"$gte": 100}, plan_name="count")), + ], +) +def test_ListRuns_query_parametrize(nruns, query, lr): + lr.num = 100 + lr.query = query dd = lr.parse_runs() - assert dd["time"] == sorted(dd["time"], reverse=True) - - lr.reverse = False + assert len(dd["time"]) == nruns +# fmt: on + + +# fmt: off +@pytest.mark.parametrize( + "reverse", + [ + (True), + (False), + ], +) +def test_ListRuns_reverse(reverse, lr): + # include with some data missing or None + lr.reverse = reverse dd = lr.parse_runs() - assert dd["time"] == sorted(dd["time"], reverse=False) + assert dd["time"] == sorted(dd["time"], reverse=reverse) +# fmt: on def test_ListRuns_since(lr): @@ -250,9 +269,7 @@ def test_ListRuns_timefmt(lr): lr.timefmt = None with pytest.raises(TypeError) as exinfo: lr.parse_runs() - assert "strftime() argument 1 must be str, not None" in str( - exinfo.value - ) + assert "strftime() argument 1 must be str, not None" in str(exinfo.value) # wrong format lr.timefmt = "no such format" @@ -311,22 +328,38 @@ def test_ListRuns_until(lr): assert v < lr.until -def test_listruns(cat): - lr = APS_utils.listruns(cat=cat, printing=False) - assert lr is not None - assert isinstance(lr, pd.DataFrame) - - lr = APS_utils.listruns( - cat=cat, tablefmt="dataframe", printing=False) - assert lr is not None - assert isinstance(lr, pd.DataFrame) - - lr = APS_utils.listruns( - cat=cat, tablefmt="table", printing=False) +# fmt: off +@pytest.mark.parametrize( + "tablefmt, structure", + [ + (None, pd.DataFrame), + ("dataframe", pd.DataFrame), + ("table", str), + ("no such format", str), + ], +) +def test_listruns_tablefmt(tablefmt, structure, cat): + lr = APS_utils.listruns(cat=cat, tablefmt=tablefmt, printing=False) assert lr is not None - assert isinstance(lr, str) - - lr = APS_utils.listruns( - cat=cat, tablefmt="no such format", printing=False) - assert lr is not None - assert isinstance(lr, str) + assert isinstance(lr, structure) +# fmt: on + + +# fmt: off +@pytest.mark.parametrize( + "ids, nresults", + [ + ([130, 131, 132], 3), + ([1234], 0), + (["bb7e0", ], 1), + ([-2, -4, -10, -8], 4), + ([], 0), + ([-2, 131, "3e89a", 1234567, "1234567"], 3), + ], +) +def test_ListRuns_ids(ids, nresults, lr): + # include with some data missing or None + lr.ids = ids + dd = lr.parse_runs() + assert len(dd["time"]) == nresults +# fmt: on