Skip to content

Commit

Permalink
Add filter option for query type (#88)
Browse files Browse the repository at this point in the history
Adds filter option for query type. Filtering support was introduced in
the k-NN plugin in 2.4. Breaks backwards compatibility with OpenSearch
2.4, however, given that 2.4 is experimental, this is okay. Backwards
  compatibility issues will only arise during mixed cluster upgrade.

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored Dec 21, 2022
1 parent 6f31e51 commit de551e2
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
Expand Down Expand Up @@ -80,6 +81,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;

/**
* Constructor from stream input
Expand All @@ -93,6 +95,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.queryText = in.readString();
this.modelId = in.readString();
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}

@Override
Expand All @@ -101,6 +104,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.queryText);
out.writeString(this.modelId);
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}

@Override
Expand All @@ -110,6 +114,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -125,7 +132,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* "model_id": "string",
* "k": int,
* "name": "string", (optional)
* "boost": float (optional)
* "boost": float (optional),
* "filter": map (optional)
* }
* }
*
Expand Down Expand Up @@ -184,6 +192,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -205,7 +217,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
if (vectorSupplier() != null) {
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k());
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter());
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -215,7 +227,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter());
}

@Override
Expand All @@ -233,6 +245,7 @@ protected boolean doEquals(NeuralQueryBuilder obj) {
equalsBuilder.append(queryText, obj.queryText);
equalsBuilder.append(modelId, obj.modelId);
equalsBuilder.append(k, obj.k);
equalsBuilder.append(filter, obj.filter);
return equalsBuilder.isEquals();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchRespons
return (Map<String, Object>) hits2List.get(0);
}

/**
* Parse the total number of hits from the search
*
* @param searchResponseAsMap Complete search response as a map
* @return number of hits from the search
*/
@SuppressWarnings("unchecked")
protected int getHitCount(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
List<Object> hits1List = (List<Object>) hits1map.get("hits");
return hits1List.size();
}

/**
* Create a k-NN index from a list of KNNFieldConfigs
*
Expand Down
Loading

0 comments on commit de551e2

Please sign in to comment.