Skip to content

[fix] fixes APPLY / SORTBY / GROUPBY / REDUCE order on FT.AGGREGATE s… #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 28, 2019
8 changes: 7 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ jobs:
name: run tests
command: |
. venv/bin/activate
REDIS_PORT=6379 python test/test.py
REDIS_PORT=6379 python test/test.py

- run:
    name: run query builder tests
    command: |
      . venv/bin/activate
      # Run the pure-python query-builder tests added in this change
      # (test_builder.py); the previous step already runs test/test.py.
      python test/test_builder.py

# no need for store_artifacts on nightly builds

Expand Down
105 changes: 76 additions & 29 deletions redisearch/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,58 @@ def __init__(self, fields, reducers):
self.limit = Limit()

def build_args(self):
    """
    Serialize this GROUPBY step (and its reducers) into a flat list of
    FT.AGGREGATE arguments.

    Returns ``['GROUPBY', <nfields>, <fields...>]`` followed by one
    ``['REDUCE', <name>, <nargs>, <args...>]`` run per reducer, plus
    ``['AS', <alias>]`` when the reducer carries an explicit alias.
    """
    ret = ['GROUPBY', str(len(self.fields))]
    ret.extend(self.fields)
    for reducer in self.reducers:
        ret += ['REDUCE', reducer.NAME, str(len(reducer.args))]
        ret.extend(reducer.args)
        # `is not None` (not truthiness): an empty-string alias is still
        # emitted, only a missing alias suppresses the AS clause.
        if reducer._alias is not None:
            ret += ['AS', reducer._alias]
    return ret

class Projection(object):
    """
    Represents a single APPLY step of an aggregation pipeline.
    Instances are created automatically by `AggregateRequest.apply()`.
    """

    def __init__(self, projector, alias=None):
        self.alias = alias
        self.projector = projector

    def build_args(self):
        """Return this step as ``['APPLY', <expr>]`` plus an optional AS clause."""
        args = ['APPLY', self.projector]
        if self.alias is not None:
            args.extend(['AS', self.alias])
        return args

class SortBy(object):
    """
    Represents a single SORTBY step of an aggregation pipeline.
    Instances are created automatically by `AggregateRequest.sort_by()`.
    """

    def __init__(self, fields, max=0):
        self.fields = fields
        self.max = max

    def build_args(self):
        """
        Return this step as ``['SORTBY', <nargs>, <field/dir pairs...>]``
        plus an optional ``['MAX', <n>]`` clause when max > 0.
        """
        flattened = []
        for field in self.fields:
            # A SortDirection contributes a (field, direction) pair;
            # a bare field name is passed through as-is.
            if isinstance(field, SortDirection):
                flattened.extend([field.field, field.DIRSTRING])
            else:
                flattened.append(field)

        args = ['SORTBY', str(len(flattened))] + flattened
        if self.max > 0:
            args.extend(['MAX', str(self.max)])
        return args


class AggregateRequest(object):
"""
Expand All @@ -127,11 +170,9 @@ def __init__(self, query='*'):
return the object itself, making them useful for chaining.
"""
self._query = query
self._groups = []
self._projections = []
self._aggregateplan = []
self._loadfields = []
self._limit = Limit()
self._sortby = []
self._max = 0
self._with_schema = False
self._verbatim = False
Expand Down Expand Up @@ -162,7 +203,7 @@ def group_by(self, fields, *reducers):
`aggregation` module.
"""
group = Group(fields, reducers)
self._groups.append(group)
self._aggregateplan.extend(group.build_args())

return self

Expand All @@ -177,7 +218,8 @@ def apply(self, **kwexpr):
expression itself, for example `apply(square_root="sqrt(@foo)")`
"""
for alias, expr in kwexpr.items():
self._projections.append([alias, expr])
projection = Projection(expr, alias )
self._aggregateplan.extend(projection.build_args())

return self

Expand Down Expand Up @@ -224,10 +266,7 @@ def limit(self, offset, num):

"""
limit = Limit(offset, num)
if self._groups:
self._groups[-1].limit = limit
else:
self._limit = limit
self._limit = limit
return self

def sort_by(self, *fields, **kwargs):
Expand Down Expand Up @@ -258,16 +297,34 @@ def sort_by(self, *fields, **kwargs):
.sort_by(Desc('@paid'), max=10)
```
"""
self._max = kwargs.get('max', 0)
if isinstance(fields, (string_types, SortDirection)):
fields = [fields]
for f in fields:
if isinstance(f, SortDirection):
self._sortby += [f.field, f.DIRSTRING]
else:
self._sortby.append(f)

max = kwargs.get('max', 0)
sortby = SortBy(fields, max)

self._aggregateplan.extend(sortby.build_args())
return self

def filter(self, expressions):
    """
    Specify filter for post-query results using predicates relating to
    values in the result set.

    ### Parameters

    - **expressions**: FILTER expressions to apply. This can either be a
      single string, or a list of strings; each expression is appended to
      the pipeline as its own `FILTER <expression>` step, in order.
    """
    if isinstance(expressions, (string_types)):
        expressions = [expressions]

    for expression in expressions:
        self._aggregateplan.extend(['FILTER', expression])

    # Return self so calls can be chained, matching the other pipeline
    # builder methods on this class.
    return self



def with_schema(self):
"""
If set, the `schema` property will contain a list of `[field, type]`
Expand Down Expand Up @@ -312,18 +369,8 @@ def build_args(self):
ret.append('LOAD')
ret.append(str(len(self._loadfields)))
ret.extend(self._loadfields)
for group in self._groups:
ret += ['GROUPBY'] + group.build_args() + group.limit.build_args()
for alias, projector in self._projections:
ret += ['APPLY', projector]
if alias:
ret += ['AS', alias]

if self._sortby:
ret += ['SORTBY', str(len(self._sortby))]
ret += self._sortby
if self._max:
ret += ['MAX', str(self._max)]

ret.extend(self._aggregateplan)

ret += self._limit.build_args()

Expand Down
45 changes: 37 additions & 8 deletions test/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest import TestCase
import unittest
import redisearch.aggregation as a
import redisearch.querystring as q
import redisearch.reducers as r

class QueryBuilderTest(TestCase):
class QueryBuilderTest(unittest.TestCase):
def testBetween(self):
b = q.between(1, 10)
self.assertEqual('[1 10]', str(b))
Expand Down Expand Up @@ -42,16 +42,16 @@ def testGroup(self):
# Single field, single reducer
g = a.Group('foo', r.count())
ret = g.build_args()
self.assertEqual(['1', 'foo', 'REDUCE', 'COUNT', '0'], ret)
self.assertEqual(['GROUPBY', '1', 'foo', 'REDUCE', 'COUNT', '0'], ret)

# Multiple fields, single reducer
g = a.Group(['foo', 'bar'], r.count())
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
g.build_args())

# Multiple fields, multiple reducers
g = a.Group(['foo', 'bar'], [r.count(), r.count_distinct('@fld1')])
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
g.build_args())

def testAggRequest(self):
Expand All @@ -62,13 +62,38 @@ def testAggRequest(self):
req = a.AggregateRequest().group_by('@foo', r.count())
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args())

# Test with group_by and alias on reducer
req = a.AggregateRequest().group_by('@foo', r.count().alias('foo_count'))
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'AS', 'foo_count'], req.build_args())

# Test with limit
req = a.AggregateRequest().\
group_by('@foo', r.count()).\
req = a.AggregateRequest(). \
group_by('@foo', r.count()). \
sort_by('@foo')
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1',
'@foo'], req.build_args())

# Test with apply
req = a.AggregateRequest(). \
apply(foo="@bar / 2"). \
group_by('@foo', r.count())

self.assertEqual(['*', 'APPLY', '@bar / 2', 'AS', 'foo', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
req.build_args())

# Test with filter
req = a.AggregateRequest().group_by('@foo', r.count()).filter( "@foo=='bar'")
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'FILTER', "@foo=='bar'" ], req.build_args())

# Test with filter on different state of the pipeline
req = a.AggregateRequest().filter("@foo=='bar'").group_by('@foo', r.count())
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'GROUPBY', '1', '@foo','REDUCE', 'COUNT', '0' ], req.build_args())

# Test with filter on different state of the pipeline
req = a.AggregateRequest().filter(["@foo=='bar'","@foo2=='bar2'"]).group_by('@foo', r.count())
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'FILTER', "@foo2=='bar2'", 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
req.build_args())

# Test with sort_by
req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date')
# print req.build_args()
Expand Down Expand Up @@ -105,4 +130,8 @@ def test_reducers(self):
self.assertEqual(('f1', 'BY', 'f2', 'ASC'), r.first_value('f1', a.Asc('f2')).args)
self.assertEqual(('f1', 'BY', 'f1', 'ASC'), r.first_value('f1', a.Asc).args)

self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)
self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)

# Allow running this test module directly: python test/test_builder.py
if __name__ == '__main__':
    unittest.main()