Merge pull request #537 from BCDA-APS/526-listruns
optimize listruns
prjemian authored Sep 24, 2021
2 parents 40180d1 + 3ddeae3 commit 5d03adc
Showing 2 changed files with 135 additions and 77 deletions.
107 changes: 66 additions & 41 deletions apstools/utils.py
@@ -73,6 +73,7 @@
from bluesky.callbacks.best_effort import BestEffortCallback
from collections import defaultdict
from collections import OrderedDict
from databroker._drivers.mongo_normalized import BlueskyMongoCatalog
from dataclasses import dataclass
from email.mime.text import MIMEText
from event_model import NumpyEncoder
@@ -95,6 +96,7 @@
import sys
import threading
import time
import typing
import warnings
import zipfile

@@ -860,7 +862,6 @@ class ListRuns:
~_get_by_key
~_check_cat
~_apply_search_filters
~_sorter
~_check_keys
"""
@@ -875,6 +876,7 @@ class ListRuns:
sortby: str = "time"
timefmt: str = "%Y-%m-%d %H:%M:%S"
until: str = None
ids: "typing.Any" = None

_default_keys = "scan_id time plan_name detectors"

@@ -927,6 +929,7 @@ def _check_cat(self):
self.cat = getCatalog()

def _apply_search_filters(self):
"""Search for runs from the catalog."""
since = self.since or FIRST_DATA
until = self.until or LAST_DATA
self._check_cat()
@@ -938,29 +941,55 @@ def _apply_search_filters(self):

def parse_runs(self):
"""Parse the runs for the given metadata keys. Return a dict."""
self._check_cat()
self._check_keys()
cat = self._apply_search_filters()

def _sort(uid):
"""Sort runs in desired order based on metadata key."""
md = self.cat[uid].metadata
for doc in "start stop".split():
if md[doc] and self.sortby in md[doc]:
return md[doc][self.sortby] or self.missing
return self.missing

num_runs_requested = min(abs(self.num), len(cat))
dd = {
key: [
self._get_by_key(run.metadata, key)
for _, run in sorted(
cat.items(), key=self._sorter, reverse=self.reverse
)[:num_runs_requested]
]
for key in self.keys
}
return dd

def _sorter(self, args):
"""Sort runs in desired order based on metadata key."""
# args : (uid, run)
md = args[1].metadata
for doc in "start stop".split():
if md[doc] and self.sortby in md[doc]:
return md[doc][self.sortby] or self.missing
return self.missing
results = {k: [] for k in self.keys}
sequence = () # iterable of run uids

if self.ids is not None:
sequence = []
for k in self.ids:
try:
run = cat[k]
sequence.append(k)
except Exception as exc:
logger.warning(
"Could not find run %s in search of catalog %s: %s",
k,
self.cat.name,
exc,
)
else:
if isinstance(cat, BlueskyMongoCatalog) and self.sortby == "time":
if self.reverse:
# the default rendering: from MongoDB in reverse time order
sequence = iter(cat)
else:
# by increasing time order
sequence = [uid for uid in cat][::-1]
else:
# full search in Python
sequence = sorted(cat.keys(), key=_sort, reverse=self.reverse)

count = 0
for uid in sequence:
run = cat[uid]
for k in self.keys:
results[k].append(self._get_by_key(run.metadata, k))
count += 1
if count >= num_runs_requested:
break
return results

def _check_keys(self):
"""Check that self.keys is a list of strings."""
@@ -971,13 +1000,11 @@ def _check_keys(self):

def to_dataframe(self):
"""Output as pandas DataFrame object"""
self._check_keys()
dd = self.parse_runs()
return pd.DataFrame(dd, columns=self.keys)

def to_table(self, fmt=None):
"""Output as pyRestTable object."""
self._check_keys()
dd = self.parse_runs()

table = pyRestTable.Table()
@@ -1002,6 +1029,7 @@ def listruns(
tablefmt="dataframe",
timefmt="%Y-%m-%d %H:%M:%S",
until=None,
ids=None,
**query,
):
"""
@@ -1028,6 +1056,15 @@
*str*:
Text to report when a value is not available.
(default: ``""``)
ids
*[int]* or *[str]*:
List of ``uid`` or ``scan_id`` value(s).
Can mix different kinds in the same list.
Also can specify offsets (e.g., ``-1``).
According to the rules for ``databroker`` catalogs,
a string is a ``uid`` (partial representations allowed),
an int is ``scan_id`` if positive or an offset if negative.
(default: ``None``)
num
*int* :
Make the table include the ``num`` most recent runs.
@@ -1112,30 +1149,18 @@
sortby=sortby,
timefmt=timefmt,
until=until,
ids=ids,
)

# fmt: off
table_format_function = dict(
dataframe=lr.to_dataframe,
table=lr.to_table,
).get(tablefmt or "dataframe", lr.to_table)
# fmt: on
obj = table_format_function()
tablefmt = tablefmt or "dataframe"
if tablefmt == "dataframe":
obj = lr.to_dataframe()
else:
obj = lr.to_table()

do_print = False
if printing:
if lr.cat is not None:
print(f"catalog: {lr.cat.name}")
if printing == "smart":
try:
get_ipython() # console or notebook will handle
except NameError:
do_print = True # we print it here
if tablefmt == "table":
do_print = True
else:
do_print = True
if do_print:
print(obj)
return
return obj
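A hypothetical usage sketch of the ``ids`` argument documented above; the catalog lookup uses the project's ``getCatalog()`` helper, and the identifier values themselves are placeholders, not values from this change:

    from apstools.utils import getCatalog, listruns

    # getCatalog() returns the default databroker catalog, assuming one is configured.
    cat = getCatalog()

    # ids may mix positive scan_id values, (partial) uid strings, and negative offsets.
    # The values below are placeholders; substitute identifiers from your own catalog.
    table = listruns(
        cat=cat,
        ids=[130, "bb7e0", -2],
        tablefmt="table",
        printing=False,
    )
    print(table)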
105 changes: 69 additions & 36 deletions tests/test_listruns_class.py
@@ -25,9 +25,7 @@ def lr():
def test_getDefaultCatalog_none_found():
with pytest.raises(ValueError) as exinfo:
APS_utils.getDefaultCatalog()
assert "Multiple catalog configurations available." in str(
exinfo.value
)
assert "Multiple catalog configurations available." in str(exinfo.value)


def test_getDefaultCatalog(cat):
Expand All @@ -44,9 +42,7 @@ def test_getDefaultCatalog_many_found(cat):

with pytest.raises(ValueError) as exinfo:
APS_utils.getDefaultCatalog()
assert "Multiple catalog objects available." in str(
exinfo.value
)
assert "Multiple catalog objects available." in str(exinfo.value)


def test_getCatalog():
@@ -183,24 +179,47 @@ def test_ListRuns_to_num(lr):
assert 0 <= len(dd["time"]) <= lr.num


def test_ListRuns_query(lr):
def test_ListRuns_query_count(lr):
lr.query = dict(plan_name="count")
dd = lr.parse_runs()
assert 0 <= len(dd["time"]) <= lr.num
for v in dd["plan_name"]:
assert v == "count"
# TODO: more?


def test_ListRuns_reverse(lr):
# include with some data missing or None
# default
# fmt: off
@pytest.mark.parametrize(
"nruns, query",
[
(27, dict(scan_id={"$lt": 20})),
(26, dict(plan_name="count")),
(0, dict(scan_id={"$lt": 20}, plan_name="count")),
(19, dict(scan_id={"$gte": 100})),
(19, dict(scan_id={"$gte": 100}, plan_name="count")),
],
)
def test_ListRuns_query_parametrize(nruns, query, lr):
lr.num = 100
lr.query = query
dd = lr.parse_runs()
assert dd["time"] == sorted(dd["time"], reverse=True)

lr.reverse = False
assert len(dd["time"]) == nruns
# fmt: on


# fmt: off
@pytest.mark.parametrize(
"reverse",
[
(True),
(False),
],
)
def test_ListRuns_reverse(reverse, lr):
# include with some data missing or None
lr.reverse = reverse
dd = lr.parse_runs()
assert dd["time"] == sorted(dd["time"], reverse=False)
assert dd["time"] == sorted(dd["time"], reverse=reverse)
# fmt: on


def test_ListRuns_since(lr):
@@ -250,9 +269,7 @@ def test_ListRuns_timefmt(lr):
lr.timefmt = None
with pytest.raises(TypeError) as exinfo:
lr.parse_runs()
assert "strftime() argument 1 must be str, not None" in str(
exinfo.value
)
assert "strftime() argument 1 must be str, not None" in str(exinfo.value)

# wrong format
lr.timefmt = "no such format"
@@ -311,22 +328,38 @@ def test_ListRuns_until(lr):
assert v < lr.until


def test_listruns(cat):
lr = APS_utils.listruns(cat=cat, printing=False)
assert lr is not None
assert isinstance(lr, pd.DataFrame)

lr = APS_utils.listruns(
cat=cat, tablefmt="dataframe", printing=False)
assert lr is not None
assert isinstance(lr, pd.DataFrame)

lr = APS_utils.listruns(
cat=cat, tablefmt="table", printing=False)
# fmt: off
@pytest.mark.parametrize(
"tablefmt, structure",
[
(None, pd.DataFrame),
("dataframe", pd.DataFrame),
("table", str),
("no such format", str),
],
)
def test_listruns_tablefmt(tablefmt, structure, cat):
lr = APS_utils.listruns(cat=cat, tablefmt=tablefmt, printing=False)
assert lr is not None
assert isinstance(lr, str)

lr = APS_utils.listruns(
cat=cat, tablefmt="no such format", printing=False)
assert lr is not None
assert isinstance(lr, str)
assert isinstance(lr, structure)
# fmt: on


# fmt: off
@pytest.mark.parametrize(
"ids, nresults",
[
([130, 131, 132], 3),
([1234], 0),
(["bb7e0", ], 1),
([-2, -4, -10, -8], 4),
([], 0),
([-2, 131, "3e89a", 1234567, "1234567"], 3),
],
)
def test_ListRuns_ids(ids, nresults, lr):
# include with some data missing or None
lr.ids = ids
dd = lr.parse_runs()
assert len(dd["time"]) == nresults
# fmt: on
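For comparison, the ``ids`` handling exercised by ``test_ListRuns_ids`` can also be driven through the ``ListRuns`` class directly. A minimal sketch, assuming a configured databroker catalog; the identifiers are placeholders:

    from apstools.utils import ListRuns, getCatalog

    lr = ListRuns(cat=getCatalog(), ids=[131, "3e89a", -2])
    dd = lr.parse_runs()    # dict keyed by the default fields: scan_id, time, plan_name, detectors
    df = lr.to_dataframe()  # the same content as a pandas DataFrame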
