Skip to content

Commit

Permalink
Revert "[Backport 2.x] Add ignore_unmapped support in KNNQueryBuilder (
Browse files Browse the repository at this point in the history
…opensearch-project#1152)"

This reverts commit 39836c6.

Signed-off-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
heemin32 committed Oct 11, 2023
1 parent f807aa2 commit ea41fce
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 78 deletions.
19 changes: 0 additions & 19 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
Expand All @@ -25,7 +24,6 @@

import java.io.File;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand All @@ -34,15 +32,6 @@

public class IndexUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_10_0;
public static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
}
};

/**
* Determines the size of a file on disk in kilobytes
*
Expand Down Expand Up @@ -206,12 +195,4 @@ public static Map<String, Object> getParametersAtLoading(SpaceType spaceType, KN

return Collections.unmodifiableMap(loadParameters);
}

public static boolean isClusterOnOrAfterMinRequiredVersion(String key) {
Version minimalRequiredVersion = minimalRequiredVersionMap.get(key);
if (minimalRequiredVersion == null) {
return false;
}
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion);
}
}
57 changes: 12 additions & 45 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.commons.lang.StringUtils;
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand All @@ -32,7 +33,6 @@
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.*;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;

/**
Expand All @@ -45,7 +45,6 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -58,7 +57,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -92,7 +91,6 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -119,12 +117,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
k = in.readInt();
// We're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -136,7 +131,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
int k = 0;
QueryBuilder filter = null;
boolean ignoreUnmapped = false;
String queryName = null;
String currentFieldName = null;
XContentParser.Token token;
Expand All @@ -159,10 +153,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals("ignore_unmapped")) {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = parser.booleanValue();
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -178,20 +168,20 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
// MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable.
// Here we're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with rolling cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = parseInnerQueryBuilder(parser);
} else {
log.debug(
String.format(
"This version of k-NN doesn't support [filter] field, minimal required version is [%s]",
minimalRequiredVersionMap.get("filter")
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
throw new IllegalArgumentException(
String.format(
"%s field is supported from version %s",
FILTER_FIELD.getPreferredName(),
minimalRequiredVersionMap.get("filter")
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
}
Expand All @@ -214,9 +204,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.queryName(queryName);
if (isClusterOnOrAfterMinRequiredVersion("ignoreUnmapped")) {
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
}
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
}
Expand All @@ -228,12 +215,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeInt(k);
// We're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
if (isClusterOnOrAfterMinRequiredVersion()) {
out.writeOptionalNamedWriteable(filter);
}
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
}

/**
Expand All @@ -258,20 +242,6 @@ public QueryBuilder getFilter() {
return this.filter;
}

/**
* Sets whether the query builder should ignore unmapped paths (and run a
* {@link MatchNoDocsQuery} in place of this query) or throw an exception if
* the path is unmapped.
*/
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
this.ignoreUnmapped = ignoreUnmapped;
return this;
}

public boolean getIgnoreUnmapped() {
return this.ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -282,9 +252,6 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -294,10 +261,6 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down Expand Up @@ -382,4 +345,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinRequiredVersion() {
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -42,7 +41,6 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -299,18 +297,6 @@ public void testSerialization() throws Exception {
assertSerialization(Version.V_2_3_0, Optional.empty());
}

public void testIgnoreUnmapped() throws IOException {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
knnQueryBuilder.ignoreUnmapped(true);
assertTrue(knnQueryBuilder.getIgnoreUnmapped());
Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
knnQueryBuilder.ignoreUnmapped(false);
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class)));
}

private void assertSerialization(final Version version, final Optional<QueryBuilder> queryBuilderOptional) throws Exception {
final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent()
? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get())
Expand Down

0 comments on commit ea41fce

Please sign in to comment.