Skip to content

Commit

Permalink
Add batch_size param for text_embedding processor (#1298)
Browse files Browse the repository at this point in the history
* Add batchSize parameter for text_embedding processor

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* throw IllegalArgumentException when batchSize is not a positive integer

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* test: add test cases for BatchSize param

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* test: exception when batchSize is zero or negative integer

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* refactor: use assertNotNull for readability & convention

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* update CHANGELOG about #1298 PR

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

* apply code convention to keep the codes spotless

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>

---------

Signed-off-by: YeonghyeonKO <dk02315@gmail.com>
Signed-off-by: Thomas Farr <tsfarr@amazon.com>
Co-authored-by: Thomas Farr <tsfarr@amazon.com>
  • Loading branch information
YeonghyeonKO and Xtansia authored Nov 18, 2024
1 parent baf919d commit 6c3e68f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Added
- Added support for disabling typed keys serialization ([#1296](https://github.com/opensearch-project/opensearch-java/pull/1296))
- Added support for the `batch_size` param on the `text_embedding` processor ([#1298](https://github.com/opensearch-project/opensearch-java/pull/1298))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public class TextEmbeddingProcessor extends ProcessorBase implements ProcessorVa
@Nullable
private final String description;

@Nullable
private final Integer batchSize;

// ---------------------------------------------------------------------------------------------

private TextEmbeddingProcessor(Builder builder) {
Expand All @@ -39,7 +42,7 @@ private TextEmbeddingProcessor(Builder builder) {
this.modelId = ApiTypeHelper.requireNonNull(builder.modelId, this, "modelId");
this.fieldMap = ApiTypeHelper.unmodifiableRequired(builder.fieldMap, this, "fieldMap");
this.description = builder.description;

this.batchSize = builder.batchSize;
}

public static TextEmbeddingProcessor of(Function<Builder, ObjectBuilder<TextEmbeddingProcessor>> fn) {
Expand Down Expand Up @@ -76,6 +79,14 @@ public final String description() {
return this.description;
}

/**
* API name: {@code batch_size}
*/
@Nullable
public final Integer batchSize() {
return this.batchSize;
}

protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);
Expand All @@ -96,7 +107,10 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeKey("description");
generator.write(this.description);
}

if (this.batchSize != null) {
generator.writeKey("batch_size");
generator.write(this.batchSize);
}
}

// ---------------------------------------------------------------------------------------------
Expand All @@ -114,6 +128,9 @@ public static class Builder extends ProcessorBase.AbstractBuilder<Builder> imple
@Nullable
private String description;

@Nullable
private Integer batchSize;

/**
* Required - API name: {@code model_id}
*/
Expand Down Expand Up @@ -150,6 +167,17 @@ public final Builder description(@Nullable String value) {
return this;
}

/**
* API name: {@code batch_size}
*/
public final Builder batchSize(@Nullable Integer value) {
if (value != null && value <= 0) {
throw new IllegalArgumentException("batchSize must be a positive integer");
}
this.batchSize = value;
return this;
}

@Override
protected Builder self() {
return this;
Expand Down Expand Up @@ -183,6 +211,7 @@ protected static void setupTextEmbeddingProcessorDeserializer(ObjectDeserializer
op.add(Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(Builder::fieldMap, JsonpDeserializer.stringMapDeserializer(JsonpDeserializer.stringDeserializer()), "field_map");
op.add(Builder::description, JsonpDeserializer.stringDeserializer(), "description");
op.add(Builder::batchSize, JsonpDeserializer.integerDeserializer(), "batch_size");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,63 @@ private static TextEmbeddingProcessor.Builder baseTextEmbeddingProcessor() {
}

@Test
public void testJsonRoundtripWithDescription() {
public void testJsonRoundtripWithDescriptionAndBatchSize() {
Processor processor = new Processor.Builder().textEmbedding(
baseTextEmbeddingProcessor().description("processor-description").build()
baseTextEmbeddingProcessor().description("processor-description").batchSize(1).build()
).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}";
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\",\"batch_size\":1}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertEquals("processor-description", deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNotNull(deserialized.batchSize());
assertEquals(1, deserialized.batchSize().intValue());
}

@Test
public void testJsonRoundtripWithoutDescription() {
Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().build()).build();
Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(1).build()).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"}}}";
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"batch_size\":1}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertNull(deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNotNull(deserialized.batchSize());
assertEquals(1, deserialized.batchSize().intValue());
}

@Test
public void testJsonRoundtripWithoutBatchSize() {
Processor processor = new Processor.Builder().textEmbedding(
baseTextEmbeddingProcessor().description("processor-description").build()
).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertEquals("processor-description", deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNull(deserialized.batchSize());
}

@Test
public void testInvalidBatchSizeThrowsException() {
IllegalArgumentException exceptionWhenBatchSizeIsZero = assertThrows(IllegalArgumentException.class, () -> {
new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(0).build()).build();
});
assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsZero.getMessage());

IllegalArgumentException exceptionWhenBatchSizeIsNegative = assertThrows(IllegalArgumentException.class, () -> {
new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(-1).build()).build();
});
assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsNegative.getMessage());
}
}

0 comments on commit 6c3e68f

Please sign in to comment.