Skip to content

Commit

Permalink
Merge pull request #1083 from lsst/tickets/DM-46401
Browse files Browse the repository at this point in the history
DM-46401: fix support for multiple instruments (etc) in where expressions
  • Loading branch information
TallJimbo committed Sep 23, 2024
2 parents 568f80b + 622876d commit 8c62bcf
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-46401.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix support for multiple-instrument (and multiple-skymap) `where` expressions in the new query system.
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query
where_governors: set[str] = set()
result.predicate.gather_governors(where_governors)
for governor in where_governors:
if governor not in result.constraint_data_id:
if governor not in result.constraint_data_id and governor not in result.governors_referenced:
if governor in self._default_data_id.dimensions:
result.constraint_data_id[governor] = self._default_data_id[governor]
else:
Expand Down
31 changes: 27 additions & 4 deletions python/lsst/daf/butler/direct_query_driver/_query_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,19 @@ class QueryJoinsPlan:
rows.
"""

governors_referenced: set[str] = dataclasses.field(default_factory=set)
"""Governor dimensions referenced directly in the predicate, but not
necessarily constrained to the same value in all logic branches.
"""

def __post_init__(self) -> None:
self.predicate.gather_required_columns(self.columns)
# Extract the data ID implied by the predicate; we can use the governor
# dimensions in that to constrain the collections we search for
# datasets later.
self.predicate.visit(_DataIdExtractionVisitor(self.constraint_data_id, self.messages))
self.predicate.visit(
_DataIdExtractionVisitor(self.constraint_data_id, self.messages, self.governors_referenced)
)

def iter_mandatory(self) -> Iterator[DimensionElement]:
"""Return an iterator over the dimension elements that must be joined
Expand Down Expand Up @@ -304,11 +311,17 @@ class _DataIdExtractionVisitor(
Dictionary to populate in place.
messages : `list` [ `str` ]
List of diagnostic messages to populate in place.
governor_references : `set` [ `str` ]
Set of the names of governor dimension names that were referenced
directly. This includes dimensions that were constrained to different
values in different logic branches, and hence not included in
``data_id``.
"""

def __init__(self, data_id: dict[str, DataIdValue], messages: list[str]):
def __init__(self, data_id: dict[str, DataIdValue], messages: list[str], governor_references: set[str]):
self.data_id = data_id
self.messages = messages
self.governor_references = governor_references

def visit_comparison(
self,
Expand All @@ -317,6 +330,8 @@ def visit_comparison(
b: qt.ColumnExpression,
flags: PredicateVisitFlags,
) -> None:
k_a, v_a = a.visit(self)
k_b, v_b = b.visit(self)
if flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
return None
if flags & PredicateVisitFlags.INVERTED:
Expand All @@ -326,8 +341,6 @@ def visit_comparison(
return None
if operator != "==":
return None
k_a, v_a = a.visit(self)
k_b, v_b = b.visit(self)
if k_a is not None and v_b is not None:
key = k_a
value = v_b
Expand All @@ -341,18 +354,28 @@ def visit_comparison(
return None

def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]:
expression.a.visit(self)
expression.b.visit(self)
return None, None

def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]:
expression.operand.visit(self)
return None, None

def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]:
return None, expression.get_literal_value()

def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]:
if expression.dimension.governor is expression.dimension:
self.governor_references.add(expression.dimension.name)
return expression.dimension.name, None

def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]:
if (
expression.element.governor is expression.element
and expression.field in expression.element.alternate_keys.names
):
self.governor_references.add(expression.element.name)
return None, None

def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]:
Expand Down
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/queries/tree/_column_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def gather_required_columns(self, columns: ColumnSet) -> None:
columns.update_dimensions(self.dimension.minimal_group)

def gather_governors(self, governors: set[str]) -> None:
if self.dimension.governor is not None:
if self.dimension.governor is not None and self.dimension.governor is not self.dimension:
governors.add(self.dimension.governor.name)

@property
Expand Down
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/queries/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class PredicateVisitor(Generic[_A, _O, _L]):
-----
The concrete `PredicateLeaf` types are only semi-public (they appear in
the serialized form of a `Predicate`, but their types should not generally
be referenced directly outside of the module in which they are defined.
be referenced directly outside of the module in which they are defined).
As a result, visiting these objects unpacks their attributes into the
visit method arguments.
"""
Expand Down
33 changes: 33 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,39 @@ def test_dataset_queries(self) -> None:
self.assertEqual(rows[0]["visit"], 1)
self.assertEqual(rows[0]["dt.collection"], "run1")

def test_multiple_instrument_queries(self) -> None:
"""Test that multiple-instrument queries are not rejected as having
governor dimension ambiguities.
"""
butler = self.make_butler("base.yaml")
butler.registry.insertDimensionData("instrument", {"name": "Cam2"})
self.assertCountEqual(
butler.query_data_ids(["detector"], where="instrument='Cam1' OR instrument='Cam2'"),
[
DataCoordinate.standardize(instrument="Cam1", detector=n, universe=butler.dimensions)
for n in range(1, 5)
],
)
self.assertCountEqual(
butler.query_data_ids(
["detector"],
where="(instrument='Cam1' OR instrument='Cam2') AND visit.region OVERLAPS region",
bind={"region": Region.from_ivoa_pos("CIRCLE 320. -0.25 10.")},
explain=False,
),
# No visits in this test dataset means no result, but the point of
# the test is just that the query can be constructed at all.
[],
)
self.assertCountEqual(
butler.query_data_ids(
["instrument"],
where="(instrument='Cam1' AND detector=2) OR (instrument='Cam2' AND detector=500)",
explain=False,
),
[DataCoordinate.standardize(instrument="Cam1", universe=butler.dimensions)],
)


def _get_exposure_ids_from_dimension_records(dimension_records: Iterable[DimensionRecord]) -> list[int]:
output = []
Expand Down

0 comments on commit 8c62bcf

Please sign in to comment.