-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Expose term frequency in Painless script score context (#9081)
Add the following functions in Painless script score context: * termfreq * tf * totaltermfreq * sumtotaltermfreq Each of these maps to a Lucene value source. Signed-off-by: Louis Chu <clingzhi@amazon.com>
- Loading branch information
Showing
17 changed files
with
411 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
...rc/yamlRestTest/resources/rest-api-spec/test/painless/120_script_score_term_frequency.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
--- | ||
setup: | ||
- skip: | ||
version: " - 2.9.99" | ||
reason: "termFreq functions for script_score was introduced in 2.10.0" | ||
- do: | ||
indices.create: | ||
index: test | ||
body: | ||
settings: | ||
number_of_shards: 1 | ||
mappings: | ||
properties: | ||
f1: | ||
type: keyword | ||
f2: | ||
type: text | ||
- do: | ||
bulk: | ||
refresh: true | ||
body: | ||
- '{"index": {"_index": "test", "_id": "doc1"}}' | ||
- '{"f1": "v0", "f2": "v1"}' | ||
- '{"index": {"_index": "test", "_id": "doc2"}}' | ||
- '{"f2": "v2"}' | ||
|
||
--- | ||
"Script score function using the termFreq function": | ||
- do: | ||
search: | ||
index: test | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
function_score: | ||
query: | ||
match_all: {} | ||
script_score: | ||
script: | ||
source: "termFreq(params.field, params.term)" | ||
params: | ||
field: "f1" | ||
term: "v0" | ||
- match: { hits.total: 2 } | ||
- match: { hits.hits.0._id: "doc1" } | ||
- match: { hits.hits.1._id: "doc2" } | ||
- match: { hits.hits.0._score: 1.0 } | ||
- match: { hits.hits.1._score: 0.0 } | ||
|
||
--- | ||
"Script score function using the totalTermFreq function": | ||
- do: | ||
search: | ||
index: test | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
function_score: | ||
query: | ||
match_all: {} | ||
script_score: | ||
script: | ||
source: "if (doc[params.field].size() == 0) return params.default_value; else { return totalTermFreq(params.field, params.term); }" | ||
params: | ||
default_value: 0.5 | ||
field: "f1" | ||
term: "v0" | ||
- match: { hits.total: 2 } | ||
- match: { hits.hits.0._id: "doc1" } | ||
- match: { hits.hits.1._id: "doc2" } | ||
- match: { hits.hits.0._score: 1.0 } | ||
- match: { hits.hits.1._score: 0.5 } | ||
|
||
--- | ||
"Script score function using the sumTotalTermFreq function": | ||
- do: | ||
search: | ||
index: test | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
function_score: | ||
query: | ||
match_all: {} | ||
script_score: | ||
script: | ||
source: "if (doc[params.field].size() == 0) return params.default_value; else { return sumTotalTermFreq(params.field); }" | ||
params: | ||
default_value: 0.5 | ||
field: "f1" | ||
- match: { hits.total: 2 } | ||
- match: { hits.hits.0._id: "doc1" } | ||
- match: { hits.hits.1._id: "doc2" } | ||
- match: { hits.hits.0._score: 1.0 } | ||
- match: { hits.hits.1._score: 0.5 } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.index.query.functionscore; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* An interface representing a term frequency function used to compute document scores | ||
* based on specific term frequency calculations. Implementations of this interface should | ||
* provide a way to execute the term frequency function for a given document ID. | ||
* | ||
* @opensearch.internal | ||
*/ | ||
public interface TermFrequencyFunction { | ||
Object execute(int docId) throws IOException; | ||
} |
95 changes: 95 additions & 0 deletions
95
.../src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.index.query.functionscore; | ||
|
||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; | ||
import org.apache.lucene.queries.function.valuesource.TFValueSource; | ||
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; | ||
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; | ||
import org.apache.lucene.search.IndexSearcher; | ||
import org.opensearch.common.lucene.BytesRefs; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
/** | ||
* A factory class for creating instances of {@link TermFrequencyFunction}. | ||
* This class provides methods for creating different term frequency functions based on | ||
* the specified function name, field, and term. Each term frequency function is designed | ||
* to compute document scores based on specific term frequency calculations. | ||
* | ||
* @opensearch.internal | ||
*/ | ||
public class TermFrequencyFunctionFactory { | ||
public static TermFrequencyFunction createFunction( | ||
TermFrequencyFunctionName functionName, | ||
String field, | ||
String term, | ||
LeafReaderContext readerContext, | ||
IndexSearcher indexSearcher | ||
) throws IOException { | ||
switch (functionName) { | ||
case TERM_FREQ: | ||
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); | ||
FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext); | ||
return docId -> functionValues.intVal(docId); | ||
case TF: | ||
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term)); | ||
Map<Object, Object> tfContext = new HashMap<>() { | ||
{ | ||
put("searcher", indexSearcher); | ||
} | ||
}; | ||
functionValues = tfValueSource.getValues(tfContext, readerContext); | ||
return docId -> functionValues.floatVal(docId); | ||
case TOTAL_TERM_FREQ: | ||
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource( | ||
field, | ||
term, | ||
field, | ||
BytesRefs.toBytesRef(term) | ||
); | ||
Map<Object, Object> ttfContext = new HashMap<>(); | ||
totalTermFreqValueSource.createWeight(ttfContext, indexSearcher); | ||
functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext); | ||
return docId -> functionValues.longVal(docId); | ||
case SUM_TOTAL_TERM_FREQ: | ||
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field); | ||
Map<Object, Object> sttfContext = new HashMap<>(); | ||
sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher); | ||
functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext); | ||
return docId -> functionValues.longVal(docId); | ||
default: | ||
throw new IllegalArgumentException("Unsupported function: " + functionName); | ||
} | ||
} | ||
|
||
/** | ||
* An enumeration representing the names of supported term frequency functions. | ||
*/ | ||
public enum TermFrequencyFunctionName { | ||
TERM_FREQ("termFreq"), | ||
TF("tf"), | ||
TOTAL_TERM_FREQ("totalTermFreq"), | ||
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq"); | ||
|
||
private final String termFrequencyFunctionName; | ||
|
||
TermFrequencyFunctionName(String termFrequencyFunctionName) { | ||
this.termFrequencyFunctionName = termFrequencyFunctionName; | ||
} | ||
|
||
public String getTermFrequencyFunctionName() { | ||
return termFrequencyFunctionName; | ||
} | ||
} | ||
} |
Oops, something went wrong.