From 4372e150e72f8bbd6f069097d010825724bfe5ec Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 14:02:23 -0700 Subject: [PATCH 1/7] Enable script score to work with model based indices Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/plugin/KNNPlugin.java | 2 + .../knn/plugin/script/KNNScoringSpace.java | 11 +++-- .../plugin/script/KNNScoringSpaceUtil.java | 48 +++++++++++++++++++ .../script/KNNScoringSpaceUtilTests.java | 45 +++++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 2e5a55092..bc17e80e7 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -33,6 +33,7 @@ import org.opensearch.knn.plugin.rest.RestTrainModelHandler; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; +import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; @@ -204,6 +205,7 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); + KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 5a8cdb036..3ba8bce63 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -28,6 +28,7 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; public interface KNNScoringSpace { + /** * Return the correct scoring script for a given query. The scoring script * @@ -60,7 +61,7 @@ public L2(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); @@ -96,7 +97,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); SpaceType.COSINESIMIL.validateVector(processedQuery); @@ -191,7 +192,7 @@ public L1(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); @@ -226,7 +227,7 @@ public LInf(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); @@ -263,7 +264,7 @@ public InnerProd(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index c482413fb..888184e54 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -8,6 +8,9 @@ import java.util.List; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -21,6 +24,12 @@ public class KNNScoringSpaceUtil { + private static ModelDao modelDao; + + public static void initialize(ModelDao modelDao) { + KNNScoringSpaceUtil.modelDao = modelDao; + } + /** * Check if the passed in fieldType is of type NumberFieldType with numericType being Long * @@ -137,4 +146,43 @@ public static float getVectorMagnitudeSquared(float[] inputVector) { } return normInputVector; } + + /** + * Get the expected dimensions from a specified knn vector field type. + * + * If the field is model-based, get dimensions from model metadata. + * @param knnVectorFieldType knn vector field type + * @return expected dimensions + */ + public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + int expectedDimensions = knnVectorFieldType.getDimension(); + // Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata. + if (expectedDimensions == -1) { + ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); + expectedDimensions = modelMetadata.getDimension(); + } + return expectedDimensions; + } + + /** + * Returns the model metadata for a specified knn vector field + * + * @param knnVectorField knn vector field + * @return the model metadata from knnVectorField + */ + private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + String modelId = knnVectorField.getModelId(); + + if (modelId == null) { + throw new IllegalArgumentException( + String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + ); + } + + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index b5bc4b95f..1497e3e17 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,10 +6,15 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import java.math.BigInteger; import java.util.ArrayList; @@ -75,4 +80,44 @@ public void testParseKNNVectorQuery() { String invalidObject = "invalidObject"; expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } + + public void testGetExpectedDimensions() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldType.getDimension()).thenReturn(3); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNScoringSpaceUtil.initialize(modelDao); + + assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType)); + assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); + + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); + String fieldName = "test-field"; + when(methodComponentContext.getName()).thenReturn(fieldName); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); + + e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); + } } From 6806555db0a7f4b34a3283409fb6779a3feb223c Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 14:08:16 -0700 Subject: [PATCH 2/7] Add changelog entry Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 782ee9c12..5b3ac9680 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499) ### Bug Fixes * Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630) +* Enable script score to work with model based indices [#1649](https://github.com/opensearch-project/k-NN/pull/1649) ### Infrastructure * Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583) * Add arm64 check when SIMD is disabled [#1618](https://github.com/opensearch-project/k-NN/pull/1618) From b4971f1fc35be34a4c1e11cdb22937698b264614 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 15:29:52 -0700 Subject: [PATCH 3/7] Refactor into KNNVectorFieldMapperUtil and split test into two tests Signed-off-by: Ryan Bogan --- .../mapper/KNNVectorFieldMapperUtil.java | 58 +++++++++++++++++ .../org/opensearch/knn/plugin/KNNPlugin.java | 4 +- .../knn/plugin/script/KNNScoringSpace.java | 11 ++-- .../plugin/script/KNNScoringSpaceUtil.java | 51 +-------------- .../mapper/KNNVectorFieldMapperUtilTests.java | 63 +++++++++++++++++++ .../script/KNNScoringSpaceUtilTests.java | 45 ------------- 6 files changed, 132 insertions(+), 100 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 9b1578a45..c988f9baa 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -21,6 +21,9 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import java.util.Arrays; import java.util.Locale; @@ -34,9 +37,22 @@ import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +/** + * Utility class for KNNVectorFieldMapper + */ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { + private static ModelDao modelDao; + + /** + * Initializes static instance variables + * @param modelDao ModelDao object + */ + public static void initialize(final ModelDao modelDao) { + KNNVectorFieldMapperUtil.modelDao = modelDao; + } + /** * Validate the float vector value and throw exception if it is not a number or not in the finite range * or is not within the FP16 range of [-65504 to 65504]. @@ -171,4 +187,46 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy return vectorDataType.getVectorFromBytesRef(storedVector); } + + /** + * Get the expected dimensions from a specified knn vector field type. + * + * If the field is model-based, get dimensions from model metadata. + * @param knnVectorFieldType knn vector field type + * @return expected dimensions + */ + public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + int expectedDimensions = knnVectorFieldType.getDimension(); + if (isModelBasedIndex(expectedDimensions)) { + ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); + expectedDimensions = modelMetadata.getDimension(); + } + return expectedDimensions; + } + + private static boolean isModelBasedIndex(int expectedDimensions) { + return expectedDimensions == -1; + } + + /** + * Returns the model metadata for a specified knn vector field + * + * @param knnVectorField knn vector field + * @return the model metadata from knnVectorField + */ + private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + String modelId = knnVectorField.getModelId(); + + if (modelId == null) { + throw new IllegalArgumentException( + String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + ); + } + + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index bc17e80e7..f898b622e 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -14,6 +14,7 @@ import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -33,7 +34,6 @@ import org.opensearch.knn.plugin.rest.RestTrainModelHandler; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; -import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; @@ -205,7 +205,7 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); - KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); + KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 3ba8bce63..8105539ba 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -8,6 +8,7 @@ import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import org.opensearch.knn.index.query.KNNWeight; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.index.mapper.MappedFieldType; @@ -61,7 +62,7 @@ public L2(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); @@ -97,7 +98,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); SpaceType.COSINESIMIL.validateVector(processedQuery); @@ -192,7 +193,7 @@ public L1(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); @@ -227,7 +228,7 @@ public LInf(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); @@ -264,7 +265,7 @@ public InnerProd(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 888184e54..889780d7a 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -8,9 +8,6 @@ import java.util.List; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -22,14 +19,11 @@ import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +/** + * Utility class for KNNScoringSpace + */ public class KNNScoringSpaceUtil { - private static ModelDao modelDao; - - public static void initialize(ModelDao modelDao) { - KNNScoringSpaceUtil.modelDao = modelDao; - } - /** * Check if the passed in fieldType is of type NumberFieldType with numericType being Long * @@ -146,43 +140,4 @@ public static float getVectorMagnitudeSquared(float[] inputVector) { } return normInputVector; } - - /** - * Get the expected dimensions from a specified knn vector field type. - * - * If the field is model-based, get dimensions from model metadata. - * @param knnVectorFieldType knn vector field type - * @return expected dimensions - */ - public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { - int expectedDimensions = knnVectorFieldType.getDimension(); - // Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata. - if (expectedDimensions == -1) { - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - expectedDimensions = modelMetadata.getDimension(); - } - return expectedDimensions; - } - - /** - * Returns the model metadata for a specified knn vector field - * - * @param knnVectorField knn vector field - * @return the model metadata from knnVectorField - */ - private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); - - if (modelId == null) { - throw new IllegalArgumentException( - String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) - ); - } - - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); - } - return modelMetadata; - } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 3fa9f2363..ff47dcd69 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,12 +13,20 @@ import org.apache.lucene.document.StoredField; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import java.io.ByteArrayInputStream; import java.util.Arrays; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class KNNVectorFieldMapperUtilTests extends KNNTestCase { private static final String TEST_FIELD_NAME = "test_field_name"; @@ -51,4 +59,59 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() { assertTrue(vector instanceof float[]); assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f); } + + public void testGetExpectedDimensionsSuccess() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldType.getDimension()).thenReturn(3); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNVectorFieldMapperUtil.initialize(modelDao); + + assertEquals(3, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldType)); + assertEquals(4, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + } + + public void testGetExpectedDimensionsFailure() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNVectorFieldMapperUtil.initialize(modelDao); + + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); + + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); + String fieldName = "test-field"; + when(methodComponentContext.getName()).thenReturn(fieldName); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); + + e = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 1497e3e17..b5bc4b95f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,15 +6,10 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import java.math.BigInteger; import java.util.ArrayList; @@ -80,44 +75,4 @@ public void testParseKNNVectorQuery() { String invalidObject = "invalidObject"; expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } - - public void testGetExpectedDimensions() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); - - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); - String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); - - ModelDao modelDao = mock(ModelDao.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.CREATED); - when(modelMetadata.getDimension()).thenReturn(4); - when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); - - KNNScoringSpaceUtil.initialize(modelDao); - - assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType)); - assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); - - when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); - - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); - - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); - String fieldName = "test-field"; - when(methodComponentContext.getName()).thenReturn(fieldName); - when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); - when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); - - e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); - assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); - } } From 613743a8737aaa1ec3053bfdabb10013c77c7970 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 15:44:38 -0700 Subject: [PATCH 4/7] Make parameters final for public methods Signed-off-by: Ryan Bogan --- .../opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index c988f9baa..8bd7eb6f2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -195,7 +195,7 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy * @param knnVectorFieldType knn vector field type * @return expected dimensions */ - public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + public static int getExpectedDimensions(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { int expectedDimensions = knnVectorFieldType.getDimension(); if (isModelBasedIndex(expectedDimensions)) { ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); @@ -214,7 +214,7 @@ private static boolean isModelBasedIndex(int expectedDimensions) { * @param knnVectorField knn vector field * @return the model metadata from knnVectorField */ - private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + private static ModelMetadata getModelMetadataForField(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { String modelId = knnVectorField.getModelId(); if (modelId == null) { From 708ef61632cffd1605805600bf2b2408508729ef Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 25 Apr 2024 11:26:19 -0700 Subject: [PATCH 5/7] Add integration test Signed-off-by: Ryan Bogan --- .../knn/plugin/script/KNNScriptScoringIT.java | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 8d014afec..abc3579c5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -39,9 +39,12 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.common.KNNConstants.*; public class KNNScriptScoringIT extends KNNRestTestCase { + private static final String TEST_MODEL = "test-model"; + public void testKNNL2ScriptScore() throws Exception { testKNNScriptScore(SpaceType.L2); } @@ -550,6 +553,47 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { assertEquals(1, secondQueryCacheMap.get("hit_count")); } + public void testKNNScriptScoreOnModelBasedIndex() throws Exception { + int dimensions = randomIntBetween(2, 10); + String modelName = TEST_MODEL; + String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions); + createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping); + bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions); + + XContentBuilder methodBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 4) + .field(METHOD_PARAMETER_NPROBES, 2) + .endObject() + .endObject(); + Map method = xContentBuilderToMap(methodBuilder); + + trainModel(modelName, TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions, method, "test model for script score"); + assertTrainingSucceeds(modelName, 30, 1000); + + String testMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, modelName) + .endObject() + .endObject() + .endObject() + .toString(); + + for (SpaceType spaceType : SpaceType.values()) { + if (spaceType != SpaceType.HAMMING_BIT) { + final float[] queryVector = randomVector(dimensions); + final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector); + } + } + } + private List createMappers(int dimensions) throws Exception { return List.of( createKnnIndexMapping(FIELD_NAME, dimensions), From 1f2a14d0f5bd53e01f49b3f88e05b5b62dd53a54 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 25 Apr 2024 11:30:44 -0700 Subject: [PATCH 6/7] Transfer variable to constant Signed-off-by: Ryan Bogan --- .../opensearch/knn/plugin/script/KNNScriptScoringIT.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index abc3579c5..0e9f7857b 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -555,7 +555,6 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { public void testKNNScriptScoreOnModelBasedIndex() throws Exception { int dimensions = randomIntBetween(2, 10); - String modelName = TEST_MODEL; String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions); createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping); bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions); @@ -571,15 +570,15 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { .endObject(); Map method = xContentBuilderToMap(methodBuilder); - trainModel(modelName, TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions, method, "test model for script score"); - assertTrainingSucceeds(modelName, 30, 1000); + trainModel(TEST_MODEL, TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions, method, "test model for script score"); + assertTrainingSucceeds(TEST_MODEL, 30, 1000); String testMapping = XContentFactory.jsonBuilder() .startObject() .startObject(PROPERTIES_FIELD) .startObject(FIELD_NAME) .field(TYPE, TYPE_KNN_VECTOR) - .field(MODEL_ID, modelName) + .field(MODEL_ID, TEST_MODEL) .endObject() .endObject() .endObject() From a1375db4dc1a44dc45b9e9234b784c609d30175b Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 25 Apr 2024 12:28:30 -0700 Subject: [PATCH 7/7] Remove star import Signed-off-by: Ryan Bogan --- .../knn/plugin/script/KNNScriptScoringIT.java | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 0e9f7857b..11c626ff7 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -39,7 +39,18 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; -import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; public class KNNScriptScoringIT extends KNNRestTestCase {