10
10
from multicorn import ForeignDataWrapper
11
11
from multicorn .utils import log_to_postgres as log2pg
12
12
13
- from ._es_query import quals_to_es
13
+ from ._es_query import _PG_TO_ES_AGG_FUNCS , quals_to_es
14
14
15
15
16
16
class ElasticsearchFDW (ForeignDataWrapper ):
17
17
""" Elastic Search Foreign Data Wrapper """
18
18
19
19
@property
def rowid_column(self):
    """Name of the column used as the row identifier.

    The row identifier is what delete/update operations use to address
    a row. It may be a real column of the table or an invented name;
    either way, every returned resultset is expected to carry it.
    """

    return self._rowid_column
29
29
@@ -73,8 +73,8 @@ def __init__(self, options, columns):
73
73
self .scroll_id = None
74
74
75
75
def get_rel_size (self , quals , columns ):
76
- """ Helps the planner by returning costs.
77
- Returns a tuple of the form (number of rows, average row width) """
76
+ """Helps the planner by returning costs.
77
+ Returns a tuple of the form (number of rows, average row width)"""
78
78
79
79
try :
80
80
query , _ = self ._get_query (quals )
@@ -93,23 +93,41 @@ def get_rel_size(self, quals, columns):
93
93
)
94
94
return (0 , 0 )
95
95
96
- def explain (self , quals , columns , sortkeys = None , verbose = False ):
97
- query , _ = self ._get_query (quals )
96
def can_pushdown_upperrel(self):
    """Describe the aggregation/grouping pushdown this wrapper supports.

    Returns a dict with two entries: ``"groupby_supported"`` (always
    ``True``) and ``"agg_functions"`` (the ``_PG_TO_ES_AGG_FUNCS``
    mapping imported from ``_es_query``).

    NOTE(review): presumably consumed by the Multicorn planner to decide
    what to push down -- confirm against the Multicorn API in use.
    """
    return dict(
        groupby_supported=True,
        agg_functions=_PG_TO_ES_AGG_FUNCS,
    )
101
+
102
def explain(
    self,
    quals,
    columns,
    sortkeys=None,
    aggs=None,
    group_clauses=None,
    verbose=False,
):
    """Return the lines describing the Elasticsearch query for a plan.

    The query is built by ``self._get_query`` from ``quals``, ``aggs``
    and ``group_clauses``; ``columns``, ``sortkeys`` and ``verbose`` are
    accepted but not used here.

    Returns a two-element list: a line naming ``self.client`` and a line
    with the query rendered as indented JSON.
    """
    es_query, _query_string = self._get_query(
        quals, aggs=aggs, group_clauses=group_clauses
    )

    target_line = "Elasticsearch query to %s" % self.client
    # Pretty-print so nested aggregation bodies stay readable.
    query_line = "Query: %s" % json.dumps(es_query, indent=4)

    return [target_line, query_line]
102
116
103
- def execute (self , quals , columns ):
117
+ def execute (self , quals , columns , aggs = None , group_clauses = None ):
104
118
""" Execute the query """
105
119
106
120
try :
107
- query , query_string = self ._get_query (quals )
121
+ query , query_string = self ._get_query (
122
+ quals , aggs = aggs , group_clauses = group_clauses
123
+ )
124
+
125
+ is_aggregation = aggs or group_clauses
108
126
109
127
if query :
110
128
response = self .client .search (
111
- size = self .scroll_size ,
112
- scroll = self .scroll_duration ,
129
+ size = self .scroll_size if not is_aggregation else 0 ,
130
+ scroll = self .scroll_duration if not is_aggregation else None ,
113
131
body = query ,
114
132
** self .arguments
115
133
)
@@ -118,7 +136,13 @@ def execute(self, quals, columns):
118
136
size = self .scroll_size , scroll = self .scroll_duration , ** self .arguments
119
137
)
120
138
121
- if not response ["hits" ]["hits" ]:
139
+ if not response ["hits" ]["hits" ] and not is_aggregation :
140
+ return
141
+
142
+ if is_aggregation :
143
+ yield from self ._handle_aggregation_response (
144
+ query , response , aggs , group_clauses
145
+ )
122
146
return
123
147
124
148
while True :
@@ -221,7 +245,7 @@ def delete(self, document_id):
221
245
)
222
246
return (0 , 0 )
223
247
224
- def _get_query (self , quals ):
248
+ def _get_query (self , quals , aggs = None , group_clauses = None ):
225
249
ignore_columns = []
226
250
if self .query_column :
227
251
ignore_columns .append (self .query_column )
@@ -230,10 +254,16 @@ def _get_query(self, quals):
230
254
231
255
query = quals_to_es (
232
256
quals ,
257
+ aggs = aggs ,
258
+ group_clauses = group_clauses ,
233
259
ignore_columns = ignore_columns ,
234
260
column_map = {self ._rowid_column : "_id" } if self ._rowid_column else None ,
235
261
)
236
262
263
+ if group_clauses is not None :
264
+ # Configure pagination for GROUP BY's
265
+ query ["aggs" ]["group_buckets" ]["composite" ]["size" ] = self .scroll_size
266
+
237
267
if not self .query_column :
238
268
return query , None
239
269
@@ -283,3 +313,34 @@ def _convert_response_column(self, column, row_data):
283
313
if isinstance (value , (list , dict )):
284
314
return json .dumps (value )
285
315
return value
316
+
317
+ def _handle_aggregation_response (self , query , response , aggs , group_clauses ):
318
+ if group_clauses is None :
319
+ result = {}
320
+
321
+ for agg_name in aggs :
322
+ result [agg_name ] = response ["aggregations" ][agg_name ]["value" ]
323
+ yield result
324
+ else :
325
+ while True :
326
+ for bucket in response ["aggregations" ]["group_buckets" ]["buckets" ]:
327
+ result = {}
328
+
329
+ for column in group_clauses :
330
+ result [column ] = bucket ["key" ][column ]
331
+
332
+ if aggs is not None :
333
+ for agg_name in aggs :
334
+ result [agg_name ] = bucket [agg_name ]["value" ]
335
+
336
+ yield result
337
+
338
+ # Check if we need to paginate results
339
+ if "after_key" not in response ["aggregations" ]["group_buckets" ]:
340
+ break
341
+
342
+ query ["aggs" ]["group_buckets" ]["composite" ]["after" ] = response [
343
+ "aggregations"
344
+ ]["group_buckets" ]["after_key" ]
345
+
346
+ response = self .client .search (size = 0 , body = query , ** self .arguments )
0 commit comments