Skip to content

Commit

Permalink
Fixing the backward incompatible changes coming from core in ScoreScr…
Browse files Browse the repository at this point in the history
…ipt class

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 23, 2023
1 parent 507bee3 commit 9ba967f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
20 changes: 12 additions & 8 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.fielddata.ScriptDocValues;
Expand Down Expand Up @@ -32,9 +33,9 @@ public KNNScoreScript(
String field,
BiFunction<T, T, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext, IndexSearcher searcher
) {
super(params, lookup, leafContext);
super(params, lookup, searcher, leafContext);
this.queryValue = queryValue;
this.field = field;
this.scoringMethod = scoringMethod;
Expand All @@ -51,9 +52,10 @@ public LongType(
String field,
BiFunction<Long, Long, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down Expand Up @@ -84,9 +86,10 @@ public BigIntegerType(
String field,
BiFunction<BigInteger, BigInteger, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down Expand Up @@ -118,9 +121,10 @@ public KNNVectorType(
String field,
BiFunction<float[], float[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.script.ScoreScript;
Expand All @@ -21,13 +22,16 @@ public class KNNScoreScriptFactory implements ScoreScript.LeafFactory {
private Object query;
private KNNScoringSpace knnScoringSpace;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup) {
private IndexSearcher searcher;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();
this.params = params;
this.lookup = lookup;
this.field = getValue(params, "field").toString();
this.similaritySpace = getValue(params, "space_type").toString();
this.query = getValue(params, "query_value");
this.searcher = searcher;

this.knnScoringSpace = KNNScoringSpaceFactory.create(
this.similaritySpace,
Expand Down Expand Up @@ -60,6 +64,6 @@ public boolean needs_score() {
*/
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return knnScoringSpace.getScoreScript(params, field, lookup, ctx);
return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher);
}
}
47 changes: 29 additions & 18 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.core.index.Index;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -29,14 +31,16 @@ public interface KNNScoringSpace {
/**
* Return the correct scoring script for a given query. The scoring script
*
* @param params Map of parameters
* @param field Fieldname
* @param lookup SearchLookup
* @param ctx ctx LeafReaderContext to be used for scoring documents
* @param params Map of parameters
* @param field Fieldname
* @param lookup SearchLookup
* @param ctx ctx LeafReaderContext to be used for scoring documents
* @param searcher IndexSearcher
* @return ScoreScript for this query
* @throws IOException throws IOException if ScoreScript cannot be constructed
*/
ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException;
ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx,
IndexSearcher searcher) throws IOException;

class L2 implements KNNScoringSpace {

Expand All @@ -62,9 +66,10 @@ public L2(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup,
ctx, searcher);
}
}

Expand Down Expand Up @@ -94,9 +99,10 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup,
ctx, searcher);
}
}

Expand Down Expand Up @@ -127,7 +133,8 @@ public HammingBit(Object query, MappedFieldType fieldType) {
}

@SuppressWarnings("unchecked")
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup,
LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
if (this.processedQuery instanceof Long) {
return new KNNScoreScript.LongType(
Expand All @@ -136,7 +143,8 @@ public ScoreScript getScoreScript(Map<String, Object> params, String field, Sear
field,
(BiFunction<Long, Long, Float>) this.scoringMethod,
lookup,
ctx
ctx,
searcher
);
}

Expand All @@ -146,7 +154,7 @@ public ScoreScript getScoreScript(Map<String, Object> params, String field, Sear
field,
(BiFunction<BigInteger, BigInteger, Float>) this.scoringMethod,
lookup,
ctx
ctx, searcher
);
}
}
Expand Down Expand Up @@ -175,9 +183,10 @@ public L1(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup,
ctx, searcher);
}
}

Expand Down Expand Up @@ -205,9 +214,10 @@ public LInf(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup,
ctx, searcher);
}
}

Expand Down Expand Up @@ -238,9 +248,10 @@ public InnerProd(Object query, MappedFieldType fieldType) {
}

@Override
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup,
ctx, searcher);
}
}
}

0 comments on commit 9ba967f

Please sign in to comment.