Skip to content

Commit c47afb3

Browse files
authored
Merge pull request #1 from splitgraph/es-multicorn-agg-pushdown-poc-cu-1t1wycg
Multicorn aggregation/grouping pushdown support
2 parents c176c1c + 1aecbde commit c47afb3

File tree

2 files changed: +123 / -21 lines changed

2 files changed: +123 / -21 lines changed

pg_es_fdw/__init__.py

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@
1010
from multicorn import ForeignDataWrapper
1111
from multicorn.utils import log_to_postgres as log2pg
1212

13-
from ._es_query import quals_to_es
13+
from ._es_query import _PG_TO_ES_AGG_FUNCS, quals_to_es
1414

1515

1616
class ElasticsearchFDW(ForeignDataWrapper):
1717
""" Elastic Search Foreign Data Wrapper """
1818

1919
@property
def rowid_column(self):
    """Return the name of the column that acts as the rowid column.

    The rowid column is used for delete/update operations. It can be
    either an existing column name or a made-up one, and it should be
    present in every returned resultset.
    """
    return self._rowid_column
2929

@@ -73,8 +73,8 @@ def __init__(self, options, columns):
7373
self.scroll_id = None
7474

7575
def get_rel_size(self, quals, columns):
76-
""" Helps the planner by returning costs.
77-
Returns a tuple of the form (number of rows, average row width) """
76+
"""Helps the planner by returning costs.
77+
Returns a tuple of the form (number of rows, average row width)"""
7878

7979
try:
8080
query, _ = self._get_query(quals)
@@ -93,23 +93,41 @@ def get_rel_size(self, quals, columns):
9393
)
9494
return (0, 0)
9595

96-
def explain(self, quals, columns, sortkeys=None, verbose=False):
97-
query, _ = self._get_query(quals)
96+
def can_pushdown_upperrel(self):
    """Describe the aggregation/grouping pushdown this wrapper supports.

    Returns a dict with two entries:

    * ``"groupby_supported"``: always ``True``.
    * ``"agg_functions"``: the ``_PG_TO_ES_AGG_FUNCS`` mapping imported
      from ``._es_query``.
    """
    return {
        "groupby_supported": True,
        "agg_functions": _PG_TO_ES_AGG_FUNCS,
    }
101+
102+
def explain(
    self,
    quals,
    columns,
    sortkeys=None,
    aggs=None,
    group_clauses=None,
    verbose=False,
):
    """Return the lines describing the Elasticsearch request for a scan.

    The query is built by ``self._get_query`` from ``quals``, ``aggs`` and
    ``group_clauses``. ``columns``, ``sortkeys`` and ``verbose`` are
    accepted but not used by this method.

    Returns a two-item list of strings: the first names ``self.client``,
    the second holds the query serialized as JSON indented by 4 spaces.
    """
    query, _ = self._get_query(quals, aggs=aggs, group_clauses=group_clauses)
    return [
        "Elasticsearch query to %s" % self.client,
        "Query: %s" % json.dumps(query, indent=4),
    ]
102116

103-
def execute(self, quals, columns):
117+
def execute(self, quals, columns, aggs=None, group_clauses=None):
104118
""" Execute the query """
105119

106120
try:
107-
query, query_string = self._get_query(quals)
121+
query, query_string = self._get_query(
122+
quals, aggs=aggs, group_clauses=group_clauses
123+
)
124+
125+
is_aggregation = aggs or group_clauses
108126

109127
if query:
110128
response = self.client.search(
111-
size=self.scroll_size,
112-
scroll=self.scroll_duration,
129+
size=self.scroll_size if not is_aggregation else 0,
130+
scroll=self.scroll_duration if not is_aggregation else None,
113131
body=query,
114132
**self.arguments
115133
)
@@ -118,7 +136,13 @@ def execute(self, quals, columns):
118136
size=self.scroll_size, scroll=self.scroll_duration, **self.arguments
119137
)
120138

121-
if not response["hits"]["hits"]:
139+
if not response["hits"]["hits"] and not is_aggregation:
140+
return
141+
142+
if is_aggregation:
143+
yield from self._handle_aggregation_response(
144+
query, response, aggs, group_clauses
145+
)
122146
return
123147

124148
while True:
@@ -221,7 +245,7 @@ def delete(self, document_id):
221245
)
222246
return (0, 0)
223247

224-
def _get_query(self, quals):
248+
def _get_query(self, quals, aggs=None, group_clauses=None):
225249
ignore_columns = []
226250
if self.query_column:
227251
ignore_columns.append(self.query_column)
@@ -230,10 +254,16 @@ def _get_query(self, quals):
230254

231255
query = quals_to_es(
232256
quals,
257+
aggs=aggs,
258+
group_clauses=group_clauses,
233259
ignore_columns=ignore_columns,
234260
column_map={self._rowid_column: "_id"} if self._rowid_column else None,
235261
)
236262

263+
if group_clauses is not None:
264+
# Configure pagination for GROUP BY's
265+
query["aggs"]["group_buckets"]["composite"]["size"] = self.scroll_size
266+
237267
if not self.query_column:
238268
return query, None
239269

@@ -283,3 +313,34 @@ def _convert_response_column(self, column, row_data):
283313
if isinstance(value, (list, dict)):
284314
return json.dumps(value)
285315
return value
316+
317+
def _handle_aggregation_response(self, query, response, aggs, group_clauses):
    """Yield result rows (dicts) from an Elasticsearch aggregation response.

    Without ``group_clauses`` a single row is yielded, mapping every name
    in ``aggs`` to ``response["aggregations"][name]["value"]``.

    With ``group_clauses`` one row is yielded per bucket of the
    ``"group_buckets"`` aggregation. Each row holds the bucket key value
    for every grouping column plus, when ``aggs`` is not None, the
    ``"value"`` of every aggregation in that bucket. While the response
    carries an ``"after_key"``, it is written into ``query`` (mutated in
    place) as the composite ``"after"`` value and the search is re-issued
    through ``self.client.search`` to fetch the next page.
    """
    if group_clauses is None:
        # Plain aggregation: exactly one row with all aggregate values.
        yield {
            agg_name: response["aggregations"][agg_name]["value"]
            for agg_name in aggs
        }
        return

    while True:
        group_buckets = response["aggregations"]["group_buckets"]

        for bucket in group_buckets["buckets"]:
            row = {column: bucket["key"][column] for column in group_clauses}

            if aggs is not None:
                for agg_name in aggs:
                    row[agg_name] = bucket[agg_name]["value"]

            yield row

        # No "after_key" means there is no further page to request.
        if "after_key" not in group_buckets:
            break

        query["aggs"]["group_buckets"]["composite"]["after"] = group_buckets[
            "after_key"
        ]
        response = self.client.search(size=0, body=query, **self.arguments)

pg_es_fdw/_es_query.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
"<=": "lte",
1313
}
1414

15+
# Maps a PostgreSQL aggregate function name (key) to the Elasticsearch
# aggregation type used for it (value). Note that "count" maps to
# "value_count" rather than to an identically named aggregation.
_PG_TO_ES_AGG_FUNCS = {
    "avg": "avg",
    "max": "max",
    "min": "min",
    "sum": "sum",
    "count": "value_count",
}
22+
1523

1624
def _base_qual_to_es(col, op, value, column_map=None):
1725
if column_map:
@@ -65,14 +73,47 @@ def _qual_to_es(qual, column_map=None):
6573
}
6674
}
6775
else:
68-
return _base_qual_to_es(
69-
qual.field_name, qual.operator, qual.value, column_map
70-
)
76+
return _base_qual_to_es(qual.field_name, qual.operator, qual.value, column_map)
7177

7278

73-
def quals_to_es(quals, ignore_columns=None, column_map=None):
79+
def quals_to_es(
80+
quals, aggs=None, group_clauses=None, ignore_columns=None, column_map=None
81+
):
7482
"""Convert a list of Multicorn quals to an ElasticSearch query"""
7583
ignore_columns = ignore_columns or []
84+
85+
# Aggregation/grouping queries
86+
if aggs is not None:
87+
aggs_query = {
88+
agg_name: {
89+
_PG_TO_ES_AGG_FUNCS[agg_props["function"]]: {
90+
"field": agg_props["column"]
91+
}
92+
}
93+
for agg_name, agg_props in aggs.items()
94+
}
95+
96+
if group_clauses is None:
97+
return {"aggs": aggs_query}
98+
99+
if group_clauses is not None:
100+
group_query = {
101+
"group_buckets": {
102+
"composite": {
103+
"sources": [
104+
{column: {"terms": {"field": column}}}
105+
for column in group_clauses
106+
]
107+
}
108+
}
109+
}
110+
111+
if aggs is not None:
112+
group_query["group_buckets"]["aggregations"] = aggs_query
113+
114+
return {"aggs": group_query}
115+
116+
# Regular query
76117
return {
77118
"query": {
78119
"bool": {

0 commit comments

Comments
 (0)