Skip to content

Commit

Permalink
Refactor knn type and codecs (#439)
Browse files Browse the repository at this point in the history
* Refactor knn type and codecs

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Jul 12, 2022
1 parent 7af37c8 commit 2fc09ba
Show file tree
Hide file tree
Showing 15 changed files with 137 additions and 152 deletions.
33 changes: 4 additions & 29 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;

Expand All @@ -26,31 +28,13 @@
* KNNMethod is used to define the structure of a method supported by a particular k-NN library. It is used to validate
* the KNNMethodContext passed in by the user. It is also used to provide superficial string translations.
*/
@AllArgsConstructor
@Getter
public class KNNMethod {

private final MethodComponent methodComponent;
private final Set<SpaceType> spaces;

/**
* KNNMethod Constructor
*
* @param methodComponent top level method component that is compatible with the underlying library
* @param spaces set of valid space types that the method supports
*/
public KNNMethod(MethodComponent methodComponent, Set<SpaceType> spaces) {
this.methodComponent = methodComponent;
this.spaces = spaces;
}

/**
* getMainMethodComponent
*
* @return mainMethodComponent
*/
public MethodComponent getMethodComponent() {
return methodComponent;
}

/**
* Determines whether the provided space is supported for this method
*
Expand All @@ -61,15 +45,6 @@ public boolean containsSpace(SpaceType space) {
return spaces.contains(space);
}

/**
* Get all valid spaces for this method
*
* @return spaces that can be used with this method
*/
public Set<SpaceType> getSpaces() {
return spaces;
}

/**
* Validate that the configured KNNMethodContext is valid for this method
*
Expand Down
48 changes: 4 additions & 44 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

package org.opensearch.knn.index;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -40,10 +40,10 @@
* KNNMethodContext will contain the information necessary to produce a library index from an Opensearch mapping.
* It will encompass all parameters necessary to build the index.
*/
@AllArgsConstructor
@Getter
public class KNNMethodContext implements ToXContentFragment, Writeable {

private static final Logger logger = LogManager.getLogger(KNNMethodContext.class);

private static KNNMethodContext defaultInstance = null;

public static synchronized KNNMethodContext getDefault() {
Expand All @@ -61,19 +61,6 @@ public static synchronized KNNMethodContext getDefault() {
private final SpaceType spaceType;
private final MethodComponentContext methodComponent;

/**
* Constructor
*
* @param knnEngine engine that this method uses
* @param spaceType space type that this method uses
* @param methodComponent MethodComponent describing the main index
*/
public KNNMethodContext(KNNEngine knnEngine, SpaceType spaceType, MethodComponentContext methodComponent) {
this.knnEngine = knnEngine;
this.spaceType = spaceType;
this.methodComponent = methodComponent;
}

/**
* Constructor from stream.
*
Expand All @@ -86,33 +73,6 @@ public KNNMethodContext(StreamInput in) throws IOException {
this.methodComponent = new MethodComponentContext(in);
}

/**
* Gets the main method component
*
* @return methodComponent
*/
public MethodComponentContext getMethodComponent() {
return methodComponent;
}

/**
* Gets the engine to be used for this context
*
* @return knnEngine
*/
public KNNEngine getEngine() {
return knnEngine;
}

/**
* Gets the space type for this context
*
* @return spaceType
*/
public SpaceType getSpaceType() {
return spaceType;
}

/**
* This method uses the knnEngine to validate that the method is compatible with the engine
*
Expand Down
27 changes: 13 additions & 14 deletions src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -206,7 +207,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
if (knnMethodContext != null) {
return new MethodFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue(), knnMethodContext),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
Expand All @@ -225,7 +226,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, knnMethodContext, modelIdAsString),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
Expand Down Expand Up @@ -296,19 +297,25 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
}
}

@Getter
public static class KNNVectorFieldType extends MappedFieldType {

int dimension;
String modelId;
KNNMethodContext knnMethodContext;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null);
this(name, meta, dimension, null, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
this(name, meta, dimension, knnMethodContext, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
this.knnMethodContext = knnMethodContext;
}

@Override
Expand All @@ -334,14 +341,6 @@ public Query termQuery(Object value, QueryShardContext context) {
);
}

public int getDimension() {
return dimension;
}

public String getModelId() {
return modelId;
}

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
Expand Down Expand Up @@ -623,7 +622,7 @@ private MethodFieldMapper(
this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue());

KNNEngine knnEngine = knnMethodContext.getEngine();
KNNEngine knnEngine = knnMethodContext.getKnnEngine();
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

try {
Expand Down
21 changes: 3 additions & 18 deletions src/main/java/org/opensearch/knn/index/MethodComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.common.TriFunction;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -26,7 +27,9 @@
*/
public class MethodComponent {

@Getter
private String name;
@Getter
private Map<String, Parameter<?>> parameters;
private BiFunction<MethodComponent, MethodComponentContext, Map<String, Object>> mapGenerator;
private TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
Expand All @@ -45,24 +48,6 @@ private MethodComponent(Builder builder) {
this.requiresTraining = builder.requiresTraining;
}

/**
* Get the name of the component
*
* @return name
*/
public String getName() {
return name;
}

/**
* Get the parameters for the component
*
* @return parameters
*/
public Map<String, Parameter<?>> getParameters() {
return parameters;
}

/**
* Parse methodComponentContext into a map that the library can use to configure the method
*
Expand Down
28 changes: 4 additions & 24 deletions src/main/java/org/opensearch/knn/index/MethodComponentContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

package org.opensearch.knn.index;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
Expand All @@ -36,24 +36,13 @@
*
* Each component is composed of a name and a map of parameters.
*/
@AllArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

private static final Logger logger = LogManager.getLogger(MethodComponentContext.class);

@Getter
private final String name;
private final Map<String, Object> parameters;

/**
* Constructor
*
* @param name component name
* @param parameters component parameters
*/
public MethodComponentContext(String name, Map<String, Object> parameters) {
this.name = name;
this.parameters = parameters;
}

/**
* Constructor from stream.
*
Expand Down Expand Up @@ -183,15 +172,6 @@ public int hashCode() {
return new HashCodeBuilder().append(name).append(parameters).toHashCode();
}

/**
* Gets the name of the component
*
* @return name
*/
public String getName() {
return name;
}

/**
* Gets the parameters of the component
*
Expand Down
23 changes: 13 additions & 10 deletions src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,36 @@
import com.google.common.collect.ImmutableMap;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec;
import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.util.CodecBuilder;

import java.lang.reflect.Constructor;
import java.util.Map;

/**
* Factory abstraction for KNN codec
*/
public class KNNCodecFactory {

private static Map<KNNCodecVersion, Class> CODEC_BY_VERSION = ImmutableMap.of(KNNCodecVersion.KNN910, KNN910Codec.class);
private final Map<KNNCodecVersion, CodecBuilder> codecByVersion;

private static KNNCodecVersion LATEST_KNN_CODEC_VERSION = KNNCodecVersion.KNN910;
private static final KNNCodecVersion LATEST_KNN_CODEC_VERSION = KNNCodecVersion.KNN910;

public static Codec createKNNCodec(final Codec userCodec) {
public KNNCodecFactory(MapperService mapperService) {
codecByVersion = ImmutableMap.of(KNNCodecVersion.KNN910, new CodecBuilder.KNN91CodecBuilder(mapperService));
}

public Codec createKNNCodec(final Codec userCodec) {
return getCodec(LATEST_KNN_CODEC_VERSION, userCodec);
}

public static Codec createKNNCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
public Codec createKNNCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
return getCodec(knnCodecVersion, userCodec);
}

private static Codec getCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
private Codec getCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
try {
Constructor<?> constructor = CODEC_BY_VERSION.getOrDefault(knnCodecVersion, CODEC_BY_VERSION.get(LATEST_KNN_CODEC_VERSION))
.getConstructor(Codec.class);
return (Codec) constructor.newInstance(userCodec);
final CodecBuilder codecBuilder = codecByVersion.getOrDefault(knnCodecVersion, codecByVersion.get(LATEST_KNN_CODEC_VERSION));
return codecBuilder.userCodec(userCodec).build();
} catch (Exception ex) {
throw new RuntimeException("Cannot create instance of KNN codec", ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
*/
public class KNNCodecService extends CodecService {

private final KNNCodecFactory knnCodecFactory;

public KNNCodecService(CodecServiceConfig codecServiceConfig) {
super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger());
knnCodecFactory = new KNNCodecFactory(codecServiceConfig.getMapperService());
}

/**
Expand All @@ -26,6 +29,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) {
*/
@Override
public Codec codec(String name) {
return KNNCodecFactory.createKNNCodec(super.codec(name));
return knnCodecFactory.createKNNCodec(super.codec(name));
}
}
Loading

0 comments on commit 2fc09ba

Please sign in to comment.