Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support neural query type #674

Merged
merged 5 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

import javax.annotation.Nullable;
import java.util.function.Function;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final int k;
@Nullable
private final String modelId;


private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
}

public static NeuralQuery of(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return fn.apply(new NeuralQuery.Builder()).build();
}

/**
* Query variant kind.
*
* @return The query variant kind.
*/
@Override
public Query.Kind _queryKind() {
return null;
reta marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Required - The target field.
*
* @return The target field.
*/
public final String field() {
return this.field;
}

/**
* Required - Search query text.
*
* @return Search query text.
*/
public final String queryText() {
return this.queryText;
}

/**
* Required - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
*/

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @return The model_id field.
*/
@Nullable
public final String modelId() {
return this.modelId;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);

if (this.modelId != null) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);

generator.writeEnd();
}

/**
* Builder for {@link NeuralQuery}.
*/
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
@Nullable
reta marked this conversation as resolved.
Show resolved Hide resolved
private String field;
@Nullable
private String queryText;
@Nullable
private String modelId;
@Nullable
private Integer k;

/**
* Required - The target field.
*
* @param field The target field.
* @return This builder.
*/
public NeuralQuery.Builder field(@Nullable String field) {
this.field = field;
return this;
}

/**
* Required - Search query text.
*
* @param queryText Search query text.
* @return This builder.
*/
public NeuralQuery.Builder queryText(@Nullable String queryText) {
this.queryText = queryText;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @param modelId The model_id field.
* @return This builder.
*/
public NeuralQuery.Builder modelId(@Nullable String modelId) {
this.modelId = modelId;
return this;
}

/**
* Required - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
*/
public NeuralQuery.Builder k(@Nullable Integer k) {
this.k = k;
return this;
}

@Override
protected NeuralQuery.Builder self() {
return this;
}

/**
* Builds a {@link NeuralQuery}.
*
* @return The built {@link NeuralQuery}.
*/
@Override
public NeuralQuery build() {
_checkSingleUse();

return new NeuralQuery(this);
}
}

public static final JsonpDeserializer<NeuralQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
NeuralQuery.Builder::new,
NeuralQuery::setupNeuralQueryDeserializer
);

protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuery.Builder> op) {
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");


op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ public enum Kind implements JsonEnum {

Nested("nested"),

Neural("neural"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -706,6 +708,23 @@ public NestedQuery nested() {
return TaggedUnionUtils.get(this, Kind.Nested);
}

/**
* Is this variant instance of kind {@code neural}?
*/
public boolean isNeural() {
return _kind == Kind.Neural;
}

/**
* Get the {@code neural} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code neural} kind.
*/
public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1450,6 +1469,16 @@ public ObjectBuilder<Query> nested(Function<NestedQuery.Builder, ObjectBuilder<N
return this.nested(fn.apply(new NestedQuery.Builder()).build());
}

public ObjectBuilder<Query> neural(NeuralQuery v) {
this._kind = Kind.Neural;
this._value = v;
return this;
}

public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1747,6 +1776,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::moreLikeThis, MoreLikeThisQuery._DESERIALIZER, "more_like_this");
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ public static NestedQuery.Builder nested() {
return new NestedQuery.Builder();
}

/**
* Creates a builder for the {@link NeuralQuery nested} {@code Query} variant.
*/
public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,48 @@ public void testNestedVariantsWithContainerProperties() {
assertEquals("m1 value", search.aggregations().get("agg1").meta().get("m1").to(String.class));
assertEquals("m2 value", search.aggregations().get("agg1").meta().get("m2").to(String.class));
}

@Test
public void testNeuralQuery() {

SearchRequest searchRequest = SearchRequest.of(
s -> s.query(
q -> q.neural(n -> n.field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
)
)
);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}

@Test
public void testNeuralQueryFromJson() {

String json = "{\n" +
" \"from\": 0,\n" +
" \"size\": 100,\n" +
" \"query\": {\n" +
" \"neural\": {\n" +
" \"passage_embedding\": {\n" +
" \"query_text\": \"Hi world\",\n" +
" \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n" +
" \"k\": 100\n" +
" }\n" +
" }\n" +
" }\n" +
"}";

SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}
}
Loading