Skip to content

Commit

Permalink
Support Radial Search (#1166) (#1240)
Browse files Browse the repository at this point in the history
* Generate the ml namespace (#1158)

* Generate ml.register_model_group



* Start neural search sample



* Re-generate ShardStatistics



* Re-generate ShardFailure



* Re-generate Result



* Re-generate WriteResponseBase



* Generate ml.delete_model_group



* Generate ml.register_model



* Exclude legacy license from ml namespace



* Generate ml.get_task



* Generate ml.delete_task



* Generate ml.delete_model



* Generate ml.deploy_model



* Generate ml.undeploy_model



* Complete neural search sample



* Generate ml.get_model



* Add changelog entry



* note



* Fix tests



---------




* Fix copy-paste mistake in NeuralSearch sample (#1161)




* Support Radial Search

Add minScore, maxDistance parameters to KnnQuery in order to support Radial Search, which was introduced in OpenSearch 2.14
https://opensearch.org/docs/latest/search-plugins/knn/radial-search-knn/



* Update CHANGELOG.md




* Update changelog post releasing 2.14.0 (#1162) (#1167)


(cherry picked from commit 2a362a6)

* Reduce required release approvals (#1168)




---------





(cherry picked from commit 3902aef)

Signed-off-by: Thomas Farr <tsfarr@amazon.com>
Signed-off-by: Alex Keeler <akeeler4227@gmail.com>
Signed-off-by: alex-keeler <59743435+alex-keeler@users.noreply.github.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Thomas Farr <tsfarr@amazon.com>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 4a5660a commit 80c3e22
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased 2.x]
### Added
- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
private final float[] vector;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final Query filter;

Expand All @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) {

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.filter = builder.filter;
}

Expand Down Expand Up @@ -66,13 +73,29 @@ public final float[] vector() {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
public final int k() {
public final Integer k() {
return this.k;
}

/**
* Optional - The minimum score allowed for the returned search results.
* @return The minimum score allowed for the returned search results.
*/
private final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance allowed between the vector and each of the returned search results.
* @return The maximum distance allowed between the vector and each ofthe returned search results.
*/
private final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - A query to filter the results of the query.
* @return The filter query.
Expand All @@ -97,7 +120,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}
generator.writeEnd();

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -108,7 +141,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter);
return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter);
}

/**
Expand All @@ -122,6 +155,10 @@ public static class Builder extends QueryBase.AbstractBuilder<Builder> implement
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private Query filter;

/**
Expand Down Expand Up @@ -156,6 +193,28 @@ public Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score allowed for the returned search results.
*
* @param minScore The minimum score allowed for the returned search results.
* @return This builder.
*/
public Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance allowed between the vector and each of the returned search results.
*
* @param maxDistance The maximum distance allowed between the vector and each ofthe returned search results.
* @return This builder.
*/
public Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -201,6 +260,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op)
b.vector(vector);
}, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector");
op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public class KnnQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).build();
KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).minScore(0.0f).maxDistance(1.0f).build();
KnnQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ public void testHybridQuery() {
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
Expand All @@ -304,6 +304,6 @@ public void testHybridQueryFromJson() {
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
}
}

0 comments on commit 80c3e22

Please sign in to comment.