Skip to content

Commit

Permalink
Add text embedding processor (opensearch-project#1007)
Browse files Browse the repository at this point in the history
* Add text embedding processor

Signed-off-by: miguel-vila <miguelvilag@gmail.com>

* add changelog entry

Signed-off-by: miguel-vila <miguelvilag@gmail.com>

* add (de)serialization test

Signed-off-by: miguel-vila <miguelvilag@gmail.com>

* fix files headers

Signed-off-by: miguel-vila <miguelvilag@gmail.com>

* fix for java 8

Signed-off-by: miguel-vila <miguelvilag@gmail.com>

---------

Signed-off-by: miguel-vila <miguelvilag@gmail.com>
  • Loading branch information
miguel-vila authored Jun 3, 2024
1 parent cee6818 commit 5b90848
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Added

- Added support for [text embedding processor](https://opensearch.org/docs/latest/ingest-pipelines/processors/text-embedding/) ([#1007](https://github.com/opensearch-project/opensearch-java/pull/1007))

### Dependencies

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ public enum Kind implements JsonEnum {

Inference("inference"),

TextEmbedding("text_embedding"),

;

private final String jsonValue;
Expand Down Expand Up @@ -735,6 +737,23 @@ public InferenceProcessor inference() {
return TaggedUnionUtils.get(this, Kind.Inference);
}

/**
* Is this variant instance of kind {@code text_embedding}?
*/
public boolean isTextEmbedding() {
return _kind == Kind.TextEmbedding;
}

/**
* Get the {@code text_embedding} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code text_embedding} kind.
*/
public TextEmbeddingProcessor textEmbedding() {
return TaggedUnionUtils.get(this, Kind.TextEmbedding);
}

@Override
@SuppressWarnings("unchecked")
public void serialize(JsonGenerator generator, JsonpMapper mapper) {
Expand Down Expand Up @@ -1086,6 +1105,16 @@ public ObjectBuilder<Processor> inference(Function<InferenceProcessor.Builder, O
return this.inference(fn.apply(new InferenceProcessor.Builder()).build());
}

public ObjectBuilder<Processor> textEmbedding(TextEmbeddingProcessor v) {
this._kind = Kind.TextEmbedding;
this._value = v;
return this;
}

public ObjectBuilder<Processor> textEmbedding(Function<TextEmbeddingProcessor.Builder, ObjectBuilder<TextEmbeddingProcessor>> fn) {
return this.textEmbedding(fn.apply(new TextEmbeddingProcessor.Builder()).build());
}

public Processor build() {
_checkSingleUse();
return new Processor(this);
Expand Down Expand Up @@ -1128,6 +1157,7 @@ protected static void setupProcessorDeserializer(ObjectDeserializer<Builder> op)
op.add(Builder::drop, DropProcessor._DESERIALIZER, "drop");
op.add(Builder::circle, CircleProcessor._DESERIALIZER, "circle");
op.add(Builder::inference, InferenceProcessor._DESERIALIZER, "inference");
op.add(Builder::textEmbedding, TextEmbeddingProcessor._DESERIALIZER, "text_embedding");

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,12 @@ public static InferenceProcessor.Builder inference() {
return new InferenceProcessor.Builder();
}

/**
* Creates a builder for the {@link TextEmbeddingProcessor text_embedding}
* {@code Processor} variant.
*/
public static TextEmbeddingProcessor.Builder textEmbedding() {
return new TextEmbeddingProcessor.Builder();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.client.opensearch.ingest;

import jakarta.json.stream.JsonGenerator;
import java.util.Map;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

// typedef: ingest._types.TextEmbeddingProcessor

@JsonpDeserializable
public class TextEmbeddingProcessor extends ProcessorBase implements ProcessorVariant {
private final String modelId;

private final Map<String, String> fieldMap;

@Nullable
private final String description;

// ---------------------------------------------------------------------------------------------

private TextEmbeddingProcessor(Builder builder) {
super(builder);

this.modelId = ApiTypeHelper.requireNonNull(builder.modelId, this, "modelId");
this.fieldMap = ApiTypeHelper.unmodifiableRequired(builder.fieldMap, this, "fieldMap");
this.description = builder.description;

}

public static TextEmbeddingProcessor of(Function<Builder, ObjectBuilder<TextEmbeddingProcessor>> fn) {
return fn.apply(new Builder()).build();
}

/**
* Processor variant kind.
*/
@Override
public Processor.Kind _processorKind() {
return Processor.Kind.Inference;
}

/**
* Required - API name: {@code model_id}
*/
public final String modelId() {
return this.modelId;
}

/**
* API name: {@code field_map}
*/
public final Map<String, String> fieldMap() {
return this.fieldMap;
}

/**
* API name: {@code description}
*/
@Nullable
public final String description() {
return this.description;
}

protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);
generator.writeKey("model_id");
generator.write(this.modelId);

if (ApiTypeHelper.isDefined(this.fieldMap)) {
generator.writeKey("field_map");
generator.writeStartObject();
for (Map.Entry<String, String> item0 : this.fieldMap.entrySet()) {
generator.writeKey(item0.getKey());
generator.write(item0.getValue());
}
generator.writeEnd();

}
if (this.description != null) {
generator.writeKey("description");
generator.write(this.description);
}

}

// ---------------------------------------------------------------------------------------------

/**
* Builder for {@link TextEmbeddingProcessor}.
*/

public static class Builder extends ProcessorBase.AbstractBuilder<Builder> implements ObjectBuilder<TextEmbeddingProcessor> {
private String modelId;

@Nullable
private Map<String, String> fieldMap;

@Nullable
private String description;

/**
* Required - API name: {@code model_id}
*/
public final Builder modelId(String value) {
this.modelId = value;
return this;
}

/**
* API name: {@code field_map}
* <p>
* Adds all entries of <code>map</code> to <code>fieldMap</code>.
*/
public final Builder fieldMap(Map<String, String> map) {
this.fieldMap = _mapPutAll(this.fieldMap, map);
return this;
}

/**
* API name: {@code field_map}
* <p>
* Adds an entry to <code>fieldMap</code>.
*/
public final Builder fieldMap(String key, String value) {
this.fieldMap = _mapPut(this.fieldMap, key, value);
return this;
}

/**
* API name: {@code description}
*/
public final Builder description(@Nullable String value) {
this.description = value;
return this;
}

@Override
protected Builder self() {
return this;
}

/**
* Builds a {@link TextEmbeddingProcessor}.
*
* @throws NullPointerException
* if some of the required fields are null.
*/
public TextEmbeddingProcessor build() {
_checkSingleUse();

return new TextEmbeddingProcessor(this);
}
}

// ---------------------------------------------------------------------------------------------

/**
* Json deserializer for {@link TextEmbeddingProcessor}
*/
public static final JsonpDeserializer<TextEmbeddingProcessor> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
Builder::new,
TextEmbeddingProcessor::setupTextEmbeddingProcessorDeserializer
);

protected static void setupTextEmbeddingProcessorDeserializer(ObjectDeserializer<TextEmbeddingProcessor.Builder> op) {
ProcessorBase.setupProcessorBaseDeserializer(op);
op.add(Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(Builder::fieldMap, JsonpDeserializer.stringMapDeserializer(JsonpDeserializer.stringDeserializer()), "field_map");
op.add(Builder::description, JsonpDeserializer.stringDeserializer(), "description");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.client.opensearch.ingest;

import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.opensearch.client.opensearch.model.ModelTestCase;

public class TextEmbeddingProcessorTest extends ModelTestCase {

private static final Map<String, String> baseFieldMap = new HashMap<>();
static {
baseFieldMap.put("input_field", "vector_field");
}

private static TextEmbeddingProcessor.Builder baseTextEmbeddingProcessor() {
return new TextEmbeddingProcessor.Builder().modelId("modelId").fieldMap(baseFieldMap).tag("some-tag");
}

@Test
public void testJsonRoundtripWithDescription() {
Processor processor = new Processor.Builder().textEmbedding(
baseTextEmbeddingProcessor().description("processor-description").build()
).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertEquals("processor-description", deserialized.description());
assertEquals("some-tag", deserialized.tag());
}

@Test
public void testJsonRoundtripWithoutDescription() {
Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().build()).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"}}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertNull(deserialized.description());
assertEquals("some-tag", deserialized.tag());
}
}

0 comments on commit 5b90848

Please sign in to comment.