Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize listruns #537

Merged
merged 16 commits into from
Sep 24, 2021
107 changes: 66 additions & 41 deletions apstools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from bluesky.callbacks.best_effort import BestEffortCallback
from collections import defaultdict
from collections import OrderedDict
from databroker._drivers.mongo_normalized import BlueskyMongoCatalog
from dataclasses import dataclass
from email.mime.text import MIMEText
from event_model import NumpyEncoder
Expand All @@ -95,6 +96,7 @@
import sys
import threading
import time
import typing
import warnings
import zipfile

Expand Down Expand Up @@ -860,7 +862,6 @@ class ListRuns:
~_get_by_key
~_check_cat
~_apply_search_filters
~_sorter
~_check_keys

"""
Expand All @@ -875,6 +876,7 @@ class ListRuns:
sortby: str = "time"
timefmt: str = "%Y-%m-%d %H:%M:%S"
until: str = None
ids: "typing.Any" = None

_default_keys = "scan_id time plan_name detectors"

Expand Down Expand Up @@ -927,6 +929,7 @@ def _check_cat(self):
self.cat = getCatalog()

def _apply_search_filters(self):
"""Search for runs from the catalog."""
since = self.since or FIRST_DATA
until = self.until or LAST_DATA
self._check_cat()
Expand All @@ -938,29 +941,55 @@ def _apply_search_filters(self):

def parse_runs(self):
"""Parse the runs for the given metadata keys. Return a dict."""
self._check_cat()
self._check_keys()
cat = self._apply_search_filters()

def _sort(uid):
"""Sort runs in desired order based on metadata key."""
md = self.cat[uid].metadata
for doc in "start stop".split():
if md[doc] and self.sortby in md[doc]:
return md[doc][self.sortby] or self.missing
return self.missing

num_runs_requested = min(abs(self.num), len(cat))
dd = {
key: [
self._get_by_key(run.metadata, key)
for _, run in sorted(
cat.items(), key=self._sorter, reverse=self.reverse
)[:num_runs_requested]
]
for key in self.keys
}
return dd

def _sorter(self, args):
"""Sort runs in desired order based on metadata key."""
# args : (uid, run)
md = args[1].metadata
for doc in "start stop".split():
if md[doc] and self.sortby in md[doc]:
return md[doc][self.sortby] or self.missing
return self.missing
results = {k: [] for k in self.keys}
sequence = () # iterable of run uids

if self.ids is not None:
sequence = []
for k in self.ids:
try:
run = cat[k]
sequence.append(k)
except Exception as exc:
logger.warning(
"Could not find run %s in search of catalog %s: %s",
k,
self.cat.name,
exc,
)
else:
if isinstance(cat, BlueskyMongoCatalog) and self.sortby == "time":
if self.reverse:
# the default rendering: from MongoDB in reverse time order
sequence = iter(cat)
else:
# by increasing time order
sequence = [uid for uid in cat][::-1]
else:
# full search in Python
sequence = sorted(cat.keys(), key=_sort, reverse=self.reverse)

count = 0
for uid in sequence:
run = cat[uid]
for k in self.keys:
results[k].append(self._get_by_key(run.metadata, k))
count += 1
if count >= num_runs_requested:
break
return results

def _check_keys(self):
"""Check that self.keys is a list of strings."""
Expand All @@ -971,13 +1000,11 @@ def _check_keys(self):

def to_dataframe(self):
"""Output as pandas DataFrame object"""
self._check_keys()
dd = self.parse_runs()
return pd.DataFrame(dd, columns=self.keys)

def to_table(self, fmt=None):
"""Output as pyRestTable object."""
self._check_keys()
dd = self.parse_runs()

table = pyRestTable.Table()
Expand All @@ -1002,6 +1029,7 @@ def listruns(
tablefmt="dataframe",
timefmt="%Y-%m-%d %H:%M:%S",
until=None,
ids=None,
**query,
):
"""
Expand All @@ -1028,6 +1056,15 @@ def listruns(
*str*:
Test to report when a value is not available.
(default: ``""``)
ids
*[int]* or *[str]*:
List of ``uid`` or ``scan_id`` value(s).
Can mix different kinds in the same list.
Also can specify offsets (e.g., ``-1``).
According to the rules for ``databroker`` catalogs,
a string is a ``uid`` (partial representations allowed),
an int is ``scan_id`` if positive or an offset if negative.
(default: ``None``)
num
*int* :
Make the table include the ``num`` most recent runs.
Expand Down Expand Up @@ -1112,30 +1149,18 @@ def listruns(
sortby=sortby,
timefmt=timefmt,
until=until,
ids=ids,
)

# fmt: off
table_format_function = dict(
dataframe=lr.to_dataframe,
table=lr.to_table,
).get(tablefmt or "dataframe", lr.to_table)
# fmt: on
obj = table_format_function()
tablefmt = tablefmt or "dataframe"
if tablefmt == "dataframe":
obj = lr.to_dataframe()
else:
obj = lr.to_table()

do_print = False
if printing:
if lr.cat is not None:
print(f"catalog: {lr.cat.name}")
if printing == "smart":
try:
get_ipython() # console or notebook will handle
except NameError:
do_print = True # we print it here
if tablefmt == "table":
do_print = True
else:
do_print = True
if do_print:
print(obj)
return
return obj
Expand Down
105 changes: 69 additions & 36 deletions tests/test_listruns_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def lr():
def test_getDefaultCatalog_none_found():
with pytest.raises(ValueError) as exinfo:
APS_utils.getDefaultCatalog()
assert "Multiple catalog configurations available." in str(
exinfo.value
)
assert "Multiple catalog configurations available." in str(exinfo.value)


def test_getDefaultCatalog(cat):
Expand All @@ -44,9 +42,7 @@ def test_getDefaultCatalog_many_found(cat):

with pytest.raises(ValueError) as exinfo:
APS_utils.getDefaultCatalog()
assert "Multiple catalog objects available." in str(
exinfo.value
)
assert "Multiple catalog objects available." in str(exinfo.value)


def test_getCatalog():
Expand Down Expand Up @@ -183,24 +179,47 @@ def test_ListRuns_to_num(lr):
assert 0 <= len(dd["time"]) <= lr.num


def test_ListRuns_query(lr):
def test_ListRuns_query_count(lr):
lr.query = dict(plan_name="count")
dd = lr.parse_runs()
assert 0 <= len(dd["time"]) <= lr.num
for v in dd["plan_name"]:
assert v == "count"
# TODO: more?


def test_ListRuns_reverse(lr):
# include with some data missing or None
# default
# fmt: off
@pytest.mark.parametrize(
"nruns, query",
[
(27, dict(scan_id={"$lt": 20})),
(26, dict(plan_name="count")),
(0, dict(scan_id={"$lt": 20}, plan_name="count")),
(19, dict(scan_id={"$gte": 100})),
(19, dict(scan_id={"$gte": 100}, plan_name="count")),
],
)
def test_ListRuns_query_parametrize(nruns, query, lr):
lr.num = 100
lr.query = query
dd = lr.parse_runs()
assert dd["time"] == sorted(dd["time"], reverse=True)

lr.reverse = False
assert len(dd["time"]) == nruns
# fmt: on


# fmt: off
@pytest.mark.parametrize(
"reverse",
[
(True),
(False),
],
)
def test_ListRuns_reverse(reverse, lr):
# include with some data missing or None
lr.reverse = reverse
dd = lr.parse_runs()
assert dd["time"] == sorted(dd["time"], reverse=False)
assert dd["time"] == sorted(dd["time"], reverse=reverse)
# fmt: on


def test_ListRuns_since(lr):
Expand Down Expand Up @@ -250,9 +269,7 @@ def test_ListRuns_timefmt(lr):
lr.timefmt = None
with pytest.raises(TypeError) as exinfo:
lr.parse_runs()
assert "strftime() argument 1 must be str, not None" in str(
exinfo.value
)
assert "strftime() argument 1 must be str, not None" in str(exinfo.value)

# wrong format
lr.timefmt = "no such format"
Expand Down Expand Up @@ -311,22 +328,38 @@ def test_ListRuns_until(lr):
assert v < lr.until


def test_listruns(cat):
lr = APS_utils.listruns(cat=cat, printing=False)
assert lr is not None
assert isinstance(lr, pd.DataFrame)

lr = APS_utils.listruns(
cat=cat, tablefmt="dataframe", printing=False)
assert lr is not None
assert isinstance(lr, pd.DataFrame)

lr = APS_utils.listruns(
cat=cat, tablefmt="table", printing=False)
# fmt: off
@pytest.mark.parametrize(
"tablefmt, structure",
[
(None, pd.DataFrame),
("dataframe", pd.DataFrame),
("table", str),
("no such format", str),
],
)
def test_listruns_tablefmt(tablefmt, structure, cat):
lr = APS_utils.listruns(cat=cat, tablefmt=tablefmt, printing=False)
assert lr is not None
assert isinstance(lr, str)

lr = APS_utils.listruns(
cat=cat, tablefmt="no such format", printing=False)
assert lr is not None
assert isinstance(lr, str)
assert isinstance(lr, structure)
# fmt: on


# fmt: off
@pytest.mark.parametrize(
"ids, nresults",
[
([130, 131, 132], 3),
([1234], 0),
(["bb7e0", ], 1),
([-2, -4, -10, -8], 4),
([], 0),
([-2, 131, "3e89a", 1234567, "1234567"], 3),
],
)
def test_ListRuns_ids(ids, nresults, lr):
# include with some data missing or None
lr.ids = ids
dd = lr.parse_runs()
assert len(dd["time"]) == nresults
# fmt: on