diff --git a/.gitignore b/.gitignore
index 5fd5046..733ac26 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,4 +150,7 @@ fabric.properties
.idea/httpRequests
# Android studio 3.1+ serialized cache file
-.idea/caches/build_file_checksums.ser
\ No newline at end of file
+.idea/caches/build_file_checksums.ser
+
+# Dependencies
+deps/*
diff --git a/redisearch/query.py b/redisearch/query.py
index d2c379c..8e97c2b 100644
--- a/redisearch/query.py
+++ b/redisearch/query.py
@@ -5,13 +5,13 @@ class Query(object):
Query is used to build complex queries that have more parameters than just the query string.
The query string is set in the constructor, and other options have setter functions.
- The setter functions return the query object, so they can be chained,
+ The setter functions return the query object, so they can be chained,
i.e. `Query("foo").verbatim().filter(...)` etc.
"""
def __init__(self, query_string):
"""
- Create a new query object.
+ Create a new query object.
The query string is set in the constructor, and other options have setter functions.
"""
@@ -24,6 +24,7 @@ def __init__(self, query_string):
self._verbatim = False
self._with_payloads = False
self._with_scores = False
+ self._scorer = False
self._filters = list()
self._ids = None
self._slop = -1
@@ -129,6 +130,14 @@ def in_order(self):
self._in_order = True
return self
+ def scorer(self, scorer):
+ """
+ Use a different scoring function to evaluate document relevance. Default is `TFIDF`. Returns the query object so calls can be chained.
+ :param scorer: The scoring function to use (e.g. `TFIDF.DOCNORM` or `BM25`); emitted by `get_args` after the `SCORER` keyword.
+ """
+ self._scorer = scorer
+ return self
+
def get_args(self):
"""
Format the redis arguments for this query and return them
@@ -144,7 +153,7 @@ def get_args(self):
args.append('INFIELDS')
args.append(len(self._fields))
args += self._fields
-
+
if self._verbatim:
args.append('VERBATIM')
@@ -159,9 +168,12 @@ def get_args(self):
if self._with_payloads:
args.append('WITHPAYLOADS')
+ if self._scorer:
+ args += ['SCORER', self._scorer]
+
if self._with_scores:
args.append('WITHSCORES')
-
+
if self._ids:
args.append('INKEYS')
args.append(len(self._ids))
@@ -217,7 +229,7 @@ def no_content(self):
def no_stopwords(self):
"""
- Prevent the query from being filtered for stopwords.
+ Prevent the query from being filtered for stopwords.
Only useful in very big queries that you are certain contain no stopwords.
"""
self._no_stopwords = True
@@ -236,7 +248,7 @@ def with_scores(self):
"""
self._with_scores = True
return self
-
+
def limit_fields(self, *fields):
"""
Limit the search to specific TEXT fields only
@@ -248,7 +260,7 @@ def limit_fields(self, *fields):
def add_filter(self, flt):
"""
- Add a numeric or geo filter to the query.
+ Add a numeric or geo filter to the query.
**Currently only one of each filter is supported by the engine**
- **flt**: A NumericFilter or GeoFilter object, used on a corresponding field
@@ -273,7 +285,7 @@ class Filter(object):
def __init__(self, keyword, field, *args):
self.args = [keyword, field] + list(args)
-
+
class NumericFilter(Filter):
INF = '+inf'
@@ -303,4 +315,4 @@ class SortbyField(object):
def __init__(self, field, asc=True):
- self.args = [field, 'ASC' if asc else 'DESC']
\ No newline at end of file
+ self.args = [field, 'ASC' if asc else 'DESC']
diff --git a/test/test.py b/test/test.py
index 16b5b98..ea465a8 100644
--- a/test/test.py
+++ b/test/test.py
@@ -53,8 +53,8 @@ def createIndex(self, client, num_docs = 100, definition=None):
assert isinstance(client, Client)
try:
- client.create_index((TextField('play', weight=5.0),
- TextField('txt'),
+ client.create_index((TextField('play', weight=5.0),
+ TextField('txt'),
NumericField('chapter')), definition=definition)
except redis.ResponseError:
client.dropindex(delete_documents=True)
@@ -161,7 +161,7 @@ def testClient(self):
self.assertEqual(len(subset), docs.total)
ids = [x.id for x in docs.docs]
self.assertEqual(set(ids), set(subset))
-
+
# self.assertRaises(redis.ResponseError, client.search, Query('henry king').return_fields('play', 'nonexist'))
# test slop and in order
@@ -272,7 +272,7 @@ def testScores(self):
#self.assertEqual(0.2, res.docs[1].score)
def testReplace(self):
-
+
conn = self.redis()
with conn as r:
@@ -296,7 +296,7 @@ def testReplace(self):
self.assertEqual(1, res.total)
self.assertEqual('doc1', res.docs[0].id)
- def testStopwords(self):
+ def testStopwords(self):
# Creating a client with a given index name
client = self.getCleanClient('idx')
@@ -324,7 +324,7 @@ def testFilters(self):
for i in r.retry_with_rdb_reload():
waitForIndex(r, 'idx')
- # Test numerical filter
+ # Test numerical filter
q1 = Query("foo").add_filter(NumericFilter('num', 0, 2)).no_content()
q2 = Query("foo").add_filter(NumericFilter('num', 2, NumericFilter.INF, minExclusive=True)).no_content()
res1, res2 = client.search(q1), client.search(q2)
@@ -338,11 +338,11 @@ def testFilters(self):
q1 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 10)).no_content()
q2 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 100)).no_content()
res1, res2 = client.search(q1), client.search(q2)
-
+
self.assertEqual(1, res1.total)
self.assertEqual(2, res2.total)
self.assertEqual('doc1', res1.docs[0].id)
-
+
# Sort results, after RDB reload order may change
list = [res2.docs[0].id, res2.docs[1].id]
list.sort()
@@ -371,7 +371,7 @@ def testSortby(self):
# Creating a client with a given index name
client = Client('idx', port=conn.port)
client.redis.flushdb()
-
+
client.create_index((TextField('txt'), NumericField('num', sortable=True)))
client.add_document('doc1', txt = 'foo bar', num = 1)
client.add_document('doc2', txt = 'foo baz', num = 2)
@@ -381,7 +381,7 @@ def testSortby(self):
q1 = Query("foo").sort_by('num', asc=True).no_content()
q2 = Query("foo").sort_by('num', asc=False).no_content()
res1, res2 = client.search(q1), client.search(q2)
-
+
self.assertEqual(3, res1.total)
self.assertEqual('doc1', res1.docs[0].id)
self.assertEqual('doc2', res1.docs[1].id)
@@ -417,7 +417,7 @@ def testExample(self):
# Creating a client with a given index name
client = Client('myIndex', port=conn.port)
client.redis.flushdb()
-
+
# Creating the index definition and schema
client.create_index((TextField('title', weight=5.0), TextField('body')))
@@ -552,7 +552,7 @@ def testNoCreate(self):
# values
res = client.search('@f3:f3_val @f2:f2_val @f1:f1_val')
self.assertEqual(1, res.total)
-
+
with self.assertRaises(redis.ResponseError) as error:
client.add_document('doc3', f2='f2_val', f3='f3_val', no_create=True)
@@ -578,7 +578,7 @@ def testSummarize(self):
doc.txt)
q = Query('king henry').paging(0, 1).summarize().highlight()
-
+
doc = sorted(client.search(q).docs)[0]
self.assertEqual('Henry ... ', doc.play)
self.assertEqual('ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... ',
@@ -600,7 +600,7 @@ def testAlias(self):
def1 =IndexDefinition(prefix=['index1:'],score_field='name')
def2 =IndexDefinition(prefix=['index2:'],score_field='name')
-
+
index1.create_index((TextField('name'),),definition=def1)
index2.create_index((TextField('name'),),definition=def2)
@@ -628,7 +628,7 @@ def testAlias(self):
with self.assertRaises(Exception) as context:
alias_client2.search('*').docs[0]
self.assertEqual('spaceballs: no such index', str(context.exception))
-
+
else:
# Creating a client with one index
@@ -808,6 +808,36 @@ def testPhoneticMatcher(self):
self.assertEqual(2, len(res.docs))
self.assertEqual(['John', 'Jon'], sorted([d.name for d in res.docs]))
+    def testScorer(self):
+        """Check the top hit's score under the default scorer and each named SCORER."""
+        client = self.getCleanClient('idx')
+
+        client.create_index((TextField('description'),))
+
+        client.add_document('doc1', description='The quick brown fox jumps over the lazy dog')
+        client.add_document('doc2', description='Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.')
+
+        # default scorer is TFIDF, so omitting scorer() and naming it explicitly must agree
+        res = client.search(Query('quick').with_scores())
+        self.assertAlmostEqual(1.0, res.docs[0].score)
+        res = client.search(Query('quick').scorer('TFIDF').with_scores())
+        self.assertAlmostEqual(1.0, res.docs[0].score)
+        # Scores are floats: compare approximately (7 places) instead of bit-for-bit
+        res = client.search(Query('quick').scorer('TFIDF.DOCNORM').with_scores())
+        self.assertAlmostEqual(0.1111111111111111, res.docs[0].score)
+
+        res = client.search(Query('quick').scorer('BM25').with_scores())
+        self.assertAlmostEqual(0.17699114465425977, res.docs[0].score)
+
+        res = client.search(Query('quick').scorer('DISMAX').with_scores())
+        self.assertAlmostEqual(2.0, res.docs[0].score)
+
+        res = client.search(Query('quick').scorer('DOCSCORE').with_scores())
+        self.assertAlmostEqual(1.0, res.docs[0].score)
+
+        res = client.search(Query('quick').scorer('HAMMING').with_scores())
+        self.assertAlmostEqual(0.0, res.docs[0].score)
+
def testGet(self):
client = self.getCleanClient('idx')
client.create_index((TextField('f1'), TextField('f2')))