Skip to content

Commit 7f9760d

Browse files
authored
Merge pull request #2 from splitgraph/es-agg-pushdown-v2-cu-1z461e4
ES Agg pushdown v2
2 parents c47afb3 + 4dfffd7 commit 7f9760d

File tree

2 files changed

+62
-16
lines changed

2 files changed

+62
-16
lines changed

pg_es_fdw/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from multicorn import ForeignDataWrapper
1111
from multicorn.utils import log_to_postgres as log2pg
1212

13-
from ._es_query import _PG_TO_ES_AGG_FUNCS, quals_to_es
13+
from ._es_query import _PG_TO_ES_AGG_FUNCS, _OPERATORS_SUPPORTED, quals_to_es
1414

1515

1616
class ElasticsearchFDW(ForeignDataWrapper):
@@ -97,6 +97,7 @@ def can_pushdown_upperrel(self):
9797
return {
9898
"groupby_supported": True,
9999
"agg_functions": _PG_TO_ES_AGG_FUNCS,
100+
"operators_supported": _OPERATORS_SUPPORTED,
100101
}
101102

102103
def explain(
@@ -319,6 +320,12 @@ def _handle_aggregation_response(self, query, response, aggs, group_clauses):
319320
result = {}
320321

321322
for agg_name in aggs:
323+
if agg_name == "count.*":
324+
# COUNT(*) is a special case, since it doesn't have a
325+
# corresponding aggregation primitive in ES
326+
result[agg_name] = response["hits"]["total"]["value"]
327+
continue
328+
322329
result[agg_name] = response["aggregations"][agg_name]["value"]
323330
yield result
324331
else:
@@ -331,6 +338,12 @@ def _handle_aggregation_response(self, query, response, aggs, group_clauses):
331338

332339
if aggs is not None:
333340
for agg_name in aggs:
341+
if agg_name == "count.*":
342+
# In general case with GROUP BY clauses COUNT(*)
343+
# is taken from the bucket's doc_count field
344+
result[agg_name] = bucket["doc_count"]
345+
continue
346+
334347
result[agg_name] = bucket[agg_name]["value"]
335348

336349
yield result

pg_es_fdw/_es_query.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
try:
23
from multicorn import ANY
34
except ImportError:
@@ -18,8 +19,31 @@
1819
"min": "min",
1920
"sum": "sum",
2021
"count": "value_count",
22+
"count.*": None # not mapped to a particular function
2123
}
2224

25+
_OPERATORS_SUPPORTED = [">", ">=", "<", "<=", "=", "<>", "!=", "~~"]
26+
27+
28+
def _convert_pattern_match_to_es(expr):
29+
def _pg_es_pattern_map(matchobj):
30+
if matchobj.group(0) == "%":
31+
return "*"
32+
elif matchobj.group(0) == "_":
33+
return "?"
34+
elif matchobj.group(0) == "\%":
35+
return "%"
36+
elif matchobj.group(0) == "\_":
37+
return "_"
38+
elif matchobj.group(0) == "*":
39+
return "\\*"
40+
elif matchobj.group(0) == "?":
41+
return "\\?"
42+
elif matchobj.group(0) == "\\\\":
43+
return "\\"
44+
45+
return re.sub(r'\\\\|\\?%|\\?_|\*|\?', _pg_es_pattern_map, fr"{expr}")
46+
2347

2448
def _base_qual_to_es(col, op, value, column_map=None):
2549
if column_map:
@@ -43,7 +67,7 @@ def _base_qual_to_es(col, op, value, column_map=None):
4367
return {"bool": {"must_not": {"term": {col: value}}}}
4468

4569
if op == "~~":
46-
return {"match": {col: value.replace("%", "*")}}
70+
return {"wildcard": {col: _convert_pattern_match_to_es(value)}}
4771

4872
# For unknown operators, get everything
4973
return {"match_all": {}}
@@ -82,6 +106,18 @@ def quals_to_es(
82106
"""Convert a list of Multicorn quals to an ElasticSearch query"""
83107
ignore_columns = ignore_columns or []
84108

109+
query = {
110+
"query": {
111+
"bool": {
112+
"must": [
113+
_qual_to_es(q, column_map)
114+
for q in quals
115+
if q.field_name not in ignore_columns
116+
]
117+
}
118+
}
119+
}
120+
85121
# Aggregation/grouping queries
86122
if aggs is not None:
87123
aggs_query = {
@@ -91,10 +127,18 @@ def quals_to_es(
91127
}
92128
}
93129
for agg_name, agg_props in aggs.items()
130+
if agg_name != "count.*"
94131
}
95132

96133
if group_clauses is None:
97-
return {"aggs": aggs_query}
134+
if "count.*" in aggs:
135+
# There is no particular COUNT(*) equivalent in ES, instead
136+
# for plain aggregations (e.g. no grouping statements), we need
137+
# to enable the track_total_hits option in order to get an
138+
# accuate number of matched docs.
139+
query["track_total_hits"] = True
140+
141+
query["aggs"] = aggs_query
98142

99143
if group_clauses is not None:
100144
group_query = {
@@ -111,17 +155,6 @@ def quals_to_es(
111155
if aggs is not None:
112156
group_query["group_buckets"]["aggregations"] = aggs_query
113157

114-
return {"aggs": group_query}
158+
query["aggs"] = group_query
115159

116-
# Regular query
117-
return {
118-
"query": {
119-
"bool": {
120-
"must": [
121-
_qual_to_es(q, column_map)
122-
for q in quals
123-
if q.field_name not in ignore_columns
124-
]
125-
}
126-
}
127-
}
160+
return query

0 commit comments

Comments
 (0)