Skip to content

Commit

Permalink
[Feature] Expose term frequency in Painless script score context (#9081)
Browse files Browse the repository at this point in the history
Add the following functions in Painless script score context:
* termfreq
* tf
* totaltermfreq
* sumtotaltermfreq

Each of these maps to a Lucene value source.

Signed-off-by: Louis Chu <clingzhi@amazon.com>
  • Loading branch information
noCharger authored Aug 23, 2023
1 parent 60d272b commit 5d3633c
Show file tree
Hide file tree
Showing 17 changed files with 411 additions and 20 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122]()https://github.com/opensearch-project/OpenSearch/pull/9122)
- [BWC and API enforcement] Define the initial set of annotations, their meaning and relations between them ([#9223](https://github.com/opensearch-project/OpenSearch/pull/9223))
- [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212))
- [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081))

### Dependencies
- Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307))
Expand Down Expand Up @@ -164,4 +165,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Security

[Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public boolean needs_score() {

@Override
public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException {
return new ScoreScript(null, null, null) {
return new ScoreScript(null, null, null, null) {
// Fake the scorer until setScorer is called.
DoubleValues values = source.getValues(leaf, new DoubleValues() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.SpecialPermission;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo

contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() {
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return newScoreScript(expr, lookup, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService
} else if (scriptContext == ScoreScript.CONTEXT) {
return prepareRamIndex(request, (context, leafReaderContext) -> {
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup());
ScoreScript.LeafFactory leafFactory = factory.newFactory(
request.getScript().getParams(),
context.lookup(),
context.searcher()
);
ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext);
scoreScript.setDocument(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import {
}

static_import {
int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq
float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF
long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq
long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq
double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils
double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
---
setup:
- skip:
version: " - 2.9.99"
reason: "termFreq functions for script_score was introduced in 2.10.0"
- do:
indices.create:
index: test
body:
settings:
number_of_shards: 1
mappings:
properties:
f1:
type: keyword
f2:
type: text
- do:
bulk:
refresh: true
body:
- '{"index": {"_index": "test", "_id": "doc1"}}'
- '{"f1": "v0", "f2": "v1"}'
- '{"index": {"_index": "test", "_id": "doc2"}}'
- '{"f2": "v2"}'

---
"Script score function using the termFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "termFreq(params.field, params.term)"
params:
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.0 }

---
"Script score function using the totalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return totalTermFreq(params.field, params.term); }"
params:
default_value: 0.5
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }

---
"Script score function using the sumTotalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return sumTotalTermFreq(params.field); }"
params:
default_value: 0.5
field: "f1"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.ScriptPlugin;
Expand Down Expand Up @@ -120,20 +121,22 @@ public boolean isResultDeterministic() {
@Override
public LeafFactory newFactory(
Map<String, Object> params,
SearchLookup lookup
SearchLookup lookup,
IndexSearcher indexSearcher
) {
return new PureDfLeafFactory(params, lookup);
return new PureDfLeafFactory(params, lookup, indexSearcher);
}
}

private static class PureDfLeafFactory implements LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
private final IndexSearcher indexSearcher;
private final String field;
private final String term;

private PureDfLeafFactory(
Map<String, Object> params, SearchLookup lookup) {
Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
if (params.containsKey("field") == false) {
throw new IllegalArgumentException(
"Missing parameter [field]");
Expand All @@ -144,6 +147,7 @@ private PureDfLeafFactory(
}
this.params = params;
this.lookup = lookup;
this.indexSearcher = indexSearcher;
field = params.get("field").toString();
term = params.get("term").toString();
}
Expand All @@ -163,7 +167,7 @@ public ScoreScript newInstance(LeafReaderContext context)
* the field and/or term don't exist in this segment,
* so always return 0
*/
return new ScoreScript(params, lookup, context) {
return new ScoreScript(params, lookup, indexSearcher, context) {
@Override
public double execute(
ExplanationHolder explanation
Expand All @@ -172,7 +176,7 @@ public double execute(
}
};
}
return new ScoreScript(params, lookup, context) {
return new ScoreScript(params, lookup, indexSearcher, context) {
int currentDocid = -1;
@Override
public void setDocument(int docid) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType;
Expand Down Expand Up @@ -93,15 +94,15 @@ public String getType() {
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
assert scriptSource.equals("explainable_script");
assert context == ScoreScript.CONTEXT;
ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return false;
}

@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return new MyScript(params1, lookup, ctx);
return new MyScript(params1, lookup, indexSearcher, ctx);
}
};
return context.factoryClazz.cast(factory);
Expand All @@ -117,8 +118,8 @@ public Set<ScriptContext<?>> getSupportedContexts() {

static class MyScript extends ScoreScript implements ExplainableScoreScript {

MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
super(params, lookup, leafContext);
MyScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
super(params, lookup, indexSearcher, leafContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ protected int doHashCode() {
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
return new ScriptScoreFunction(
script,
searchScript,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
);
}
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import java.io.IOException;

/**
* An interface representing a term frequency function used to compute document scores
* based on specific term frequency calculations. Implementations of this interface should
* provide a way to execute the term frequency function for a given document ID.
*
* @opensearch.internal
*/
public interface TermFrequencyFunction {
Object execute(int docId) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
* A factory class for creating instances of {@link TermFrequencyFunction}.
* This class provides methods for creating different term frequency functions based on
* the specified function name, field, and term. Each term frequency function is designed
* to compute document scores based on specific term frequency calculations.
*
* @opensearch.internal
*/
public class TermFrequencyFunctionFactory {
public static TermFrequencyFunction createFunction(
TermFrequencyFunctionName functionName,
String field,
String term,
LeafReaderContext readerContext,
IndexSearcher indexSearcher
) throws IOException {
switch (functionName) {
case TERM_FREQ:
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext);
return docId -> functionValues.intVal(docId);
case TF:
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
Map<Object, Object> tfContext = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};
functionValues = tfValueSource.getValues(tfContext, readerContext);
return docId -> functionValues.floatVal(docId);
case TOTAL_TERM_FREQ:
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(
field,
term,
field,
BytesRefs.toBytesRef(term)
);
Map<Object, Object> ttfContext = new HashMap<>();
totalTermFreqValueSource.createWeight(ttfContext, indexSearcher);
functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext);
return docId -> functionValues.longVal(docId);
case SUM_TOTAL_TERM_FREQ:
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
Map<Object, Object> sttfContext = new HashMap<>();
sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher);
functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext);
return docId -> functionValues.longVal(docId);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
}

/**
* An enumeration representing the names of supported term frequency functions.
*/
public enum TermFrequencyFunctionName {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

TermFrequencyFunctionName(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;
}
}
}
Loading

0 comments on commit 5d3633c

Please sign in to comment.