diff --git a/.gitignore b/.gitignore index 5fd5046..733ac26 100644 --- a/.gitignore +++ b/.gitignore @@ -150,4 +150,7 @@ fabric.properties .idea/httpRequests # Android studio 3.1+ serialized cache file -.idea/caches/build_file_checksums.ser \ No newline at end of file +.idea/caches/build_file_checksums.ser + +# Dependencies +deps/* diff --git a/redisearch/query.py b/redisearch/query.py index d2c379c..8e97c2b 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -5,13 +5,13 @@ class Query(object): Query is used to build complex queries that have more parameters than just the query string. The query string is set in the constructor, and other options have setter functions. - The setter functions return the query object, so they can be chained, + The setter functions return the query object, so they can be chained, i.e. `Query("foo").verbatim().filter(...)` etc. """ def __init__(self, query_string): """ - Create a new query object. + Create a new query object. The query string is set in the constructor, and other options have setter functions. """ @@ -24,6 +24,7 @@ def __init__(self, query_string): self._verbatim = False self._with_payloads = False self._with_scores = False + self._scorer = False self._filters = list() self._ids = None self._slop = -1 @@ -129,6 +130,14 @@ def in_order(self): self._in_order = True return self + def scorer(self, scorer): + """ + Use a different scoring function to evaluate document relevance. Default is `TFIDF` + :param scorer: The scoring function to use (e.g. `TFIDF.DOCNORM` or `BM25`) + """ + self._scorer = scorer + return self + def get_args(self): """ Format the redis arguments for this query and return them @@ -144,7 +153,7 @@ def get_args(self): args.append('INFIELDS') args.append(len(self._fields)) args += self._fields - + if self._verbatim: args.append('VERBATIM') @@ -159,9 +168,12 @@ def get_args(self): if self._with_payloads: args.append('WITHPAYLOADS') + if self._scorer: + args += ['SCORER', self._scorer] + if self._with_scores: args.append('WITHSCORES') - + if self._ids: args.append('INKEYS') args.append(len(self._ids)) @@ -217,7 +229,7 @@ def no_content(self): def no_stopwords(self): """ - Prevent the query from being filtered for stopwords. + Prevent the query from being filtered for stopwords. Only useful in very big queries that you are certain contain no stopwords. """ self._no_stopwords = True @@ -236,7 +248,7 @@ def with_scores(self): """ self._with_scores = True return self - + def limit_fields(self, *fields): """ Limit the search to specific TEXT fields only @@ -248,7 +260,7 @@ def limit_fields(self, *fields): def add_filter(self, flt): """ - Add a numeric or geo filter to the query. + Add a numeric or geo filter to the query. **Currently only one of each filter is supported by the engine** - **flt**: A NumericFilter or GeoFilter object, used on a corresponding field @@ -273,7 +285,7 @@ class Filter(object): def __init__(self, keyword, field, *args): self.args = [keyword, field] + list(args) - + class NumericFilter(Filter): INF = '+inf' @@ -303,4 +315,4 @@ class SortbyField(object): def __init__(self, field, asc=True): - self.args = [field, 'ASC' if asc else 'DESC'] \ No newline at end of file + self.args = [field, 'ASC' if asc else 'DESC'] diff --git a/test/test.py b/test/test.py index 16b5b98..ea465a8 100644 --- a/test/test.py +++ b/test/test.py @@ -53,8 +53,8 @@ def createIndex(self, client, num_docs = 100, definition=None): assert isinstance(client, Client) try: - client.create_index((TextField('play', weight=5.0), - TextField('txt'), + client.create_index((TextField('play', weight=5.0), + TextField('txt'), NumericField('chapter')), definition=definition) except redis.ResponseError: client.dropindex(delete_documents=True) @@ -161,7 +161,7 @@ def testClient(self): self.assertEqual(len(subset), docs.total) ids = [x.id for x in docs.docs] self.assertEqual(set(ids), set(subset)) - + # self.assertRaises(redis.ResponseError, client.search, Query('henry king').return_fields('play', 'nonexist')) # test slop and in order @@ -272,7 +272,7 @@ def testScores(self): #self.assertEqual(0.2, res.docs[1].score) def testReplace(self): - + conn = self.redis() with conn as r: @@ -296,7 +296,7 @@ def testReplace(self): self.assertEqual(1, res.total) self.assertEqual('doc1', res.docs[0].id) - def testStopwords(self): + def testStopwords(self): # Creating a client with a given index name client = self.getCleanClient('idx') @@ -324,7 +324,7 @@ def testFilters(self): for i in r.retry_with_rdb_reload(): waitForIndex(r, 'idx') - # Test numerical filter + # Test numerical filter q1 = Query("foo").add_filter(NumericFilter('num', 0, 2)).no_content() q2 = Query("foo").add_filter(NumericFilter('num', 2, NumericFilter.INF, minExclusive=True)).no_content() res1, res2 = client.search(q1), client.search(q2) @@ -338,11 +338,11 @@ def testFilters(self): q1 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 100)).no_content() res1, res2 = client.search(q1), client.search(q2) - + self.assertEqual(1, res1.total) self.assertEqual(2, res2.total) self.assertEqual('doc1', res1.docs[0].id) - + # Sort results, after RDB reload order may change list = [res2.docs[0].id, res2.docs[1].id] list.sort() @@ -371,7 +371,7 @@ def testSortby(self): # Creating a client with a given index name client = Client('idx', port=conn.port) client.redis.flushdb() - + client.create_index((TextField('txt'), NumericField('num', sortable=True))) client.add_document('doc1', txt = 'foo bar', num = 1) client.add_document('doc2', txt = 'foo baz', num = 2) @@ -381,7 +381,7 @@ def testSortby(self): q1 = Query("foo").sort_by('num', asc=True).no_content() q2 = Query("foo").sort_by('num', asc=False).no_content() res1, res2 = client.search(q1), client.search(q2) - + self.assertEqual(3, res1.total) self.assertEqual('doc1', res1.docs[0].id) self.assertEqual('doc2', res1.docs[1].id) @@ -417,7 +417,7 @@ def testExample(self): # Creating a client with a given index name client = Client('myIndex', port=conn.port) client.redis.flushdb() - + # Creating the index definition and schema client.create_index((TextField('title', weight=5.0), TextField('body'))) @@ -552,7 +552,7 @@ def testNoCreate(self): # values res = client.search('@f3:f3_val @f2:f2_val @f1:f1_val') self.assertEqual(1, res.total) - + with self.assertRaises(redis.ResponseError) as error: client.add_document('doc3', f2='f2_val', f3='f3_val', no_create=True) @@ -578,7 +578,7 @@ def testSummarize(self): doc.txt) q = Query('king henry').paging(0, 1).summarize().highlight() - + doc = sorted(client.search(q).docs)[0] self.assertEqual('Henry ... ', doc.play) self.assertEqual('ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... ', @@ -600,7 +600,7 @@ def testAlias(self): def1 =IndexDefinition(prefix=['index1:'],score_field='name') def2 =IndexDefinition(prefix=['index2:'],score_field='name') - + index1.create_index((TextField('name'),),definition=def1) index2.create_index((TextField('name'),),definition=def2) @@ -628,7 +628,7 @@ def testAlias(self): with self.assertRaises(Exception) as context: alias_client2.search('*').docs[0] self.assertEqual('spaceballs: no such index', str(context.exception)) - + else: # Creating a client with one index @@ -808,6 +808,36 @@ def testPhoneticMatcher(self): self.assertEqual(2, len(res.docs)) self.assertEqual(['John', 'Jon'], sorted([d.name for d in res.docs])) + def testScorer(self): + # Creating a client with a given index name + client = self.getCleanClient('idx') + + client.create_index((TextField('description'),)) + + client.add_document('doc1', description='The quick brown fox jumps over the lazy dog') + client.add_document('doc2', description='Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.') + + # default scorer is TFIDF + res = client.search(Query('quick').with_scores()) + self.assertEqual(1.0, res.docs[0].score) + res = client.search(Query('quick').scorer('TFIDF').with_scores()) + self.assertEqual(1.0, res.docs[0].score) + + res = client.search(Query('quick').scorer('TFIDF.DOCNORM').with_scores()) + self.assertEqual(0.1111111111111111, res.docs[0].score) + + res = client.search(Query('quick').scorer('BM25').with_scores()) + self.assertEqual(0.17699114465425977, res.docs[0].score) + + res = client.search(Query('quick').scorer('DISMAX').with_scores()) + self.assertEqual(2.0, res.docs[0].score) + + res = client.search(Query('quick').scorer('DOCSCORE').with_scores()) + self.assertEqual(1.0, res.docs[0].score) + + res = client.search(Query('quick').scorer('HAMMING').with_scores()) + self.assertEqual(0.0, res.docs[0].score) + def testGet(self): client = self.getCleanClient('idx') client.create_index((TextField('f1'), TextField('f2')))