Skip to content

Commit

Permalink
allow input null for text docs input
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Sep 28, 2023
1 parent 28c6c41 commit aa91556
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
case TEXT_DOCS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
docs.add(parser.text());
if (parser.currentToken() == null || parser.currentToken() == XContentParser.Token.VALUE_NULL) {
docs.add(null);
} else {
docs.add(parser.text());
}
}
break;
case RESULT_FILTER_FIELD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

public class TextDocsMLInputTest {
Expand All @@ -47,22 +48,22 @@ public void parseTextDocsMLInput() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = builder.toString();
parseMLInput(jsonStr);
parseMLInput(jsonStr, 2);
}

@Test
public void parseTextDocsMLInput_OldWay() throws IOException {
String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\" ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}";
parseMLInput(jsonStr);
String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}";
parseMLInput(jsonStr, 3);
}

@Test
public void parseTextDocsMLInput_NewWay() throws IOException {
String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
parseMLInput(jsonStr);
parseMLInput(jsonStr, 2);
}

private void parseMLInput(String jsonStr) throws IOException {
private void parseMLInput(String jsonStr, int docSize) throws IOException {
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand All @@ -72,9 +73,12 @@ private void parseMLInput(String jsonStr) throws IOException {
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
TextDocsInputDataSet inputDataset = (TextDocsInputDataSet) parsedInput.getInputDataset();
assertEquals(2, inputDataset.getDocs().size());
assertEquals(docSize, inputDataset.getDocs().size());
assertEquals("doc1", inputDataset.getDocs().get(0));
assertEquals("doc2", inputDataset.getDocs().get(1));
if (inputDataset.getDocs().size() > 2) {
assertNull(inputDataset.getDocs().get(2));
}
assertNotNull(inputDataset.getResultFilter());
assertTrue(inputDataset.getResultFilter().isReturnBytes());
assertTrue(inputDataset.getResultFilter().isReturnNumber());
Expand Down

0 comments on commit aa91556

Please sign in to comment.