diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java index a61f5a7a20..866e9cadef 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java @@ -42,16 +42,20 @@ public class DataSourceMetadata { @JsonProperty private Map<String, String> properties; + @JsonProperty private String resultIndex; + public DataSourceMetadata( String name, DataSourceType connector, List<String> allowedRoles, - Map<String, String> properties) { + Map<String, String> properties, + String resultIndex) { this.name = name; this.connector = connector; this.description = StringUtils.EMPTY; this.properties = properties; this.allowedRoles = allowedRoles; + this.resultIndex = resultIndex; } public DataSourceMetadata() { @@ -69,6 +73,7 @@ public static DataSourceMetadata defaultOpenSearchDataSourceMetadata() { DEFAULT_DATASOURCE_NAME, DataSourceType.OPENSEARCH, Collections.emptyList(), - ImmutableMap.of()); + ImmutableMap.of(), + null); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index a16d57673e..508567582b 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -199,7 +199,8 @@ public Set<DataSourceMetadata> getDataSourceMetadata(boolean isDefaultDataSource ds.getName(), ds.getConnectorType(), Collections.emptyList(), - ImmutableMap.of())) + ImmutableMap.of(), + null)) .collect(Collectors.toSet()); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java index 358a7b43b5..4aefc5521d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java @@ -64,7 +64,8 @@ void testIterator() { dataSource.getName(), dataSource.getConnectorType(), Collections.emptyList(), - ImmutableMap.of())) + ImmutableMap.of(), + null)) .collect(Collectors.toSet()); when(dataSourceService.getDataSourceMetadata(false)).thenReturn(dataSourceMetadata); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java index 53f0054dc2..261f13870a 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java @@ -32,6 +32,8 @@ public class XContentParserUtils { public static final String PROPERTIES_FIELD = "properties"; public static final String ALLOWED_ROLES_FIELD = "allowedRoles"; + public static final String RESULT_INDEX_FIELD = "resultIndex"; + /** * Convert xcontent parser to DataSourceMetadata.
* @@ -45,6 +47,7 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr DataSourceType connector = null; List<String> allowedRoles = new ArrayList<>(); Map<String, String> properties = new HashMap<>(); + String resultIndex = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -73,6 +76,9 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr properties.put(key, value); } break; + case RESULT_INDEX_FIELD: + resultIndex = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -80,7 +86,8 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr if (name == null || connector == null) { throw new IllegalArgumentException("name and connector are required fields."); } - return new DataSourceMetadata(name, description, connector, allowedRoles, properties); + return new DataSourceMetadata( + name, description, connector, allowedRoles, properties, resultIndex); } /** @@ -122,6 +129,7 @@ public static XContentBuilder convertToXContent(DataSourceMetadata metadata) thr builder.field(entry.getKey(), entry.getValue()); } builder.endObject(); + builder.field(RESULT_INDEX_FIELD, metadata.getResultIndex()); builder.endObject(); return builder; } diff --git a/datasources/src/main/resources/datasources-index-mapping.yml b/datasources/src/main/resources/datasources-index-mapping.yml index cb600ae825..0206a97886 100644 --- a/datasources/src/main/resources/datasources-index-mapping.yml +++ b/datasources/src/main/resources/datasources-index-mapping.yml @@ -14,4 +14,6 @@ properties: keyword: type: keyword connector: + type: keyword + resultIndex: type: keyword \ No newline at end of file diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index eb28495541..c8312e6013 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -344,7 +344,8 @@ void testRemovalOfAuthorizationInfo() { "testDS", DataSourceType.PROMETHEUS, Collections.singletonList("prometheus_access"), - properties); + properties, + null); when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); @@ -398,7 +399,8 @@ void testGetRawDataSourceMetadata() { "testDS", DataSourceType.PROMETHEUS, Collections.singletonList("prometheus_access"), - properties); + properties, + null); when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java index eab0b8e168..d134293456 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java @@ -31,7 +31,7 @@ public void testConvertToXContent() { XContentBuilder contentBuilder = XContentParserUtils.convertToXContent(dataSourceMetadata); String contentString = BytesReference.bytes(contentBuilder).utf8ToString();
Assertions.assertEquals( - "{\"name\":\"testDS\",\"description\":\"\",\"connector\":\"PROMETHEUS\",\"allowedRoles\":[\"prometheus_access\"],\"properties\":{\"prometheus.uri\":\"https://localhost:9090\"}}", + "{\"name\":\"testDS\",\"description\":\"\",\"connector\":\"PROMETHEUS\",\"allowedRoles\":[\"prometheus_access\"],\"properties\":{\"prometheus.uri\":\"https://localhost:9090\"},\"resultIndex\":null}", contentString); } diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 0635c6581b..087629a1f1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -66,7 +66,8 @@ public void createDataSourceAPITest() { "prometheus.auth.username", "username", "prometheus.auth.password", - "password")); + "password"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); @@ -96,7 +97,8 @@ public void updateDataSourceAPITest() { "update_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. @@ -108,7 +110,8 @@ public void updateDataSourceAPITest() { "update_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://randomtest.com:9090")); + ImmutableMap.of("prometheus.uri", "https://randomtest.com:9090"), + null); Request updateRequest = getUpdateDataSourceRequest(updateDSM); Response updateResponse = client().performRequest(updateRequest); Assert.assertEquals(200, updateResponse.getStatusLine().getStatusCode()); @@ -141,7 +144,8 @@ public void deleteDataSourceTest() { "delete_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. @@ -179,7 +183,8 @@ public void getAllDataSourceTest() { "get_all_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. 
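Every call site above passes null for the new trailing resultIndex argument, which keeps the legacy behavior of writing query results to the default .query_execution_result index (SPARK_RESPONSE_BUFFER_INDEX_NAME later in this change). A minimal sketch of a caller opting into a dedicated result index instead; the datasource values mirror the integration tests above, while the index name is purely illustrative:

    import com.google.common.collect.ImmutableList;
    import com.google.common.collect.ImmutableMap;
    import org.opensearch.sql.datasource.model.DataSourceMetadata;
    import org.opensearch.sql.datasource.model.DataSourceType;

    public class ResultIndexExample {
      public static void main(String[] args) {
        DataSourceMetadata metadata =
            new DataSourceMetadata(
                "my_prometheus",
                DataSourceType.PROMETHEUS,
                ImmutableList.of(),
                ImmutableMap.of("prometheus.uri", "http://localhost:9090"),
                "my_custom_result_index"); // query results routed here instead of the default
        System.out.println(metadata.getResultIndex()); // my_custom_result_index
      }
    }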
@@ -215,7 +220,8 @@ public void issue2196() { "prometheus.auth.username", "username", "prometheus.auth.password", - "password")); + "password"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java index cf7cfcdb39..7b694ce222 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java @@ -46,7 +46,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java index 8d72f02e29..b81b7f9517 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java @@ -58,7 +58,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java index 4f31c5c130..c3d2bf5912 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java @@ -46,7 +46,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 486a31bf73..bbb5abdb28 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -7,6 +7,8 @@ import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static 
org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.JobRunState; import java.security.AccessController; @@ -74,11 +76,13 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), clusterName.value())); + asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), - dispatchQueryResponse.isDropIndexQuery())); + dispatchQueryResponse.isDropIndexQuery(), + dispatchQueryResponse.getResultIndex())); return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); } @@ -89,7 +93,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); - if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) { + if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = new DefaultSparkSqlFunctionResponseHandle(jsonObject); List<ExprValue> result = new ArrayList<>(); @@ -97,9 +101,13 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { result.add(sparkSqlFunctionResponseHandle.next()); } return new AsyncQueryExecutionResponse( - JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result); + JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null); } else { - return new AsyncQueryExecutionResponse(jsonObject.getString("status"), null, null); + return new AsyncQueryExecutionResponse( + jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()), + null, + null, + jsonObject.optString(ERROR_FIELD, "")); } } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index 84dcc490ba..d2e54af004 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -18,4 +18,5 @@ public class AsyncQueryExecutionResponse { private final String status; private final ExecutionEngine.Schema schema; private final List<ExprValue> results; + private final String error; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 64a2078066..b470ef989f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -29,11 +29,13 @@ public class AsyncQueryJobMetadata { private String applicationId; private String jobId; private boolean isDropIndexQuery; + private String resultIndex; - public AsyncQueryJobMetadata(String applicationId, String jobId) { + public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { this.applicationId = applicationId; this.jobId = jobId;
this.isDropIndexQuery = false; + this.resultIndex = resultIndex; } @Override @@ -54,6 +56,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) builder.field("jobId", metadata.getJobId()); builder.field("applicationId", metadata.getApplicationId()); builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); + builder.field("resultIndex", metadata.getResultIndex()); builder.endObject(); return builder; } @@ -88,6 +91,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws String jobId = null; String applicationId = null; boolean isDropIndexQuery = false; + String resultIndex = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -102,6 +106,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "isDropIndexQuery": isDropIndexQuery = parser.booleanValue(); break; + case "resultIndex": + resultIndex = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -109,6 +116,6 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery); + return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery, resultIndex); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java index 6d6bce8fbc..c229aa3920 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java @@ -11,19 +11,26 @@ public class AsyncQueryResult extends QueryResult { @Getter private final String status; + @Getter private final String error; public AsyncQueryResult( String status, ExecutionEngine.Schema schema, Collection<ExprValue> exprValues, - Cursor cursor) { + Cursor cursor, + String error) { super(schema, exprValues, cursor); this.status = status; + this.error = error; } public AsyncQueryResult( - String status, ExecutionEngine.Schema schema, Collection<ExprValue> exprValues) { + String status, + ExecutionEngine.Schema schema, + Collection<ExprValue> exprValues, + String error) { super(schema, exprValues); this.status = status; + this.error = error; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index f5ef155c9c..f9f0b8ed8d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -11,42 +11,7 @@ import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; -import static org.opensearch.sql.spark.data.constants.SparkConstants.AWS_SNAPSHOT_REPOSITORY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; -import static 
org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE; -import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER; -import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CATALOG_JAR; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CREDENTIALS_PROVIDER_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_PASSWORD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_USERNAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SQL_EXTENSION; -import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_CATALOG_HIVE_JAR; -import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_HIVE_CATALOG_FACTORY_CLASS; -import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_CLASS_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.JAVA_HOME_LOCATION; -import static org.opensearch.sql.spark.data.constants.SparkConstants.S3_AWS_CREDENTIALS_PROVIDER_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_JAVA_HOME_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_JAVA_HOME_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_PACKAGES_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_REPOSITORIES_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_EXTENSIONS_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_STANDALONE_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.*; import java.net.URI; import java.net.URISyntaxException; @@ -80,8 +45,7 @@ private Builder() { config.put( 
HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - config.put(SPARK_JARS_KEY, GLUE_CATALOG_HIVE_JAR + "," + FLINT_CATALOG_JAR); - config.put(SPARK_JAR_PACKAGES_KEY, SPARK_STANDALONE_PACKAGE); + config.put(SPARK_JAR_PACKAGES_KEY, SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE); config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); @@ -106,6 +70,7 @@ public Builder dataSource(DataSourceMetadata metadata) { config.put(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); config.put(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); config.put("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); + config.put(FLINT_DATA_SOURCE_KEY, metadata.getName()); setFlintIndexStoreHost( parseUri( @@ -115,6 +80,7 @@ public Builder dataSource(DataSourceMetadata metadata) { () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); + config.put("spark.flint.datasource.name", metadata.getName()); return this; } throw new UnsupportedOperationException( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java index 1a8e3203b8..f0a7e76c87 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java @@ -34,6 +34,10 @@ public EmrServerlessClientImplEMR(AWSEMRServerless emrServerless) { @Override public String startJobRun(StartJobRequest startJobRequest) { + String resultIndex = + startJobRequest.getResultIndex() == null + ? SPARK_RESPONSE_BUFFER_INDEX_NAME + : startJobRequest.getResultIndex(); StartJobRunRequest request = new StartJobRunRequest() .withName(startJobRequest.getJobName()) @@ -46,8 +50,7 @@ public String startJobRun(StartJobRequest startJobRequest) { .withSparkSubmit( new SparkSubmit() .withEntryPoint(SPARK_SQL_APPLICATION_JAR) - .withEntryPointArguments( - startJobRequest.getQuery(), SPARK_RESPONSE_BUFFER_INDEX_NAME) + .withEntryPointArguments(startJobRequest.getQuery(), resultIndex) .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams()))); StartJobRunResult startJobRunResult = AccessController.doPrivileged( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index df8f9f61b1..c4382239a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -27,6 +27,8 @@ public class StartJobRequest { /** true if it is Spark Structured Streaming job. */ private final boolean isStructuredStreaming; + private final String resultIndex; + public Long executionTimeout() { return isStructuredStreaming ? 
0L : DEFAULT_JOB_TIMEOUT; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index a0318a9478..7fc71458d0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -8,9 +8,18 @@ public class SparkConstants { public static final String EMR = "emr"; public static final String STEP_ID_FIELD = "stepId.keyword"; - // TODO should be replaced with mvn jar. + + public static final String JOB_ID_FIELD = "jobRunId"; + + public static final String STATUS_FIELD = "status"; + + public static final String DATA_FIELD = "data"; + + public static final String ERROR_FIELD = "error"; + + // EMR-S downloads the JAR into the local Maven/Ivy cache public static final String SPARK_SQL_APPLICATION_JAR = - "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; + "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar"; public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = @@ -26,7 +35,7 @@ public class SparkConstants { public static final String FLINT_DEFAULT_SCHEME = "http"; public static final String FLINT_DEFAULT_AUTH = "noauth"; public static final String FLINT_DEFAULT_REGION = "us-west-2"; - public static final String DEFAULT_CLASS_NAME = "org.opensearch.sql.FlintJob"; + public static final String DEFAULT_CLASS_NAME = "org.apache.spark.sql.FlintJob"; public static final String S3_AWS_CREDENTIALS_PROVIDER_KEY = "spark.hadoop.fs.s3.customAWSCredentialsProvider"; public static final String DRIVER_ENV_ASSUME_ROLE_ARN_KEY = @@ -53,6 +62,7 @@ public class SparkConstants { public static final String FLINT_INDEX_STORE_AWSREGION_KEY = "spark.datasource.flint.region"; public static final String FLINT_CREDENTIALS_PROVIDER_KEY = "spark.datasource.flint.customAWSCredentialsProvider"; + public static final String FLINT_DATA_SOURCE_KEY = "spark.flint.datasource.name"; public static final String SPARK_SQL_EXTENSIONS_KEY = "spark.sql.extensions"; public static final String HIVE_METASTORE_CLASS_KEY = "spark.hadoop.hive.metastore.client.factory.class"; @@ -62,11 +72,14 @@ public class SparkConstants { "com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory"; public static final String SPARK_STANDALONE_PACKAGE = "org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT"; + public static final String SPARK_LAUNCHER_PACKAGE = + "org.opensearch:opensearch-spark-sql-application_2.12:0.1.0-SNAPSHOT"; public static final String AWS_SNAPSHOT_REPOSITORY = "https://aws.oss.sonatype.org/content/repositories/snapshots"; public static final String GLUE_HIVE_CATALOG_FACTORY_CLASS = "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"; - public static final String FLINT_DELEGATE_CATALOG = "org.opensearch.sql.FlintDelegateCatalog"; + public static final String FLINT_DELEGATE_CATALOG = + "org.opensearch.sql.FlintDelegatingSessionCatalog"; public static final String FLINT_SQL_EXTENSION = "org.opensearch.flint.spark.FlintSparkExtensions"; public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 
9c5d4df667..f5ef419294 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -5,6 +5,10 @@ package org.opensearch.sql.spark.dispatcher; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -14,6 +18,7 @@ import org.apache.commons.lang3.RandomStringUtils; import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @@ -63,12 +68,33 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - JSONObject result = new JSONObject(); - if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { - result = - jobExecutionResponseReader.getResultFromOpensearchIndex(asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + JSONObject result = + (jobState.equals(JobRunState.SUCCESS.toString())) + ? jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()) + : new JSONObject(); + + // If the result index document has a status, use it directly; otherwise fall back to the + // EMR-S job status. + // A successful job does not mean the execution was error-free: for example, even if the + // result index mapping is incorrect, we still write the query result and let the job finish.
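The hunk resumes below with the implementation. As a condensed, standalone sketch of the same resolution rule (field names taken from the SparkConstants added in this change; the document shape is illustrative):

    import org.json.JSONObject;

    public class StatusResolutionSketch {
      public static void main(String[] args) {
        String jobState = "SUCCESS"; // EMR-S job state, used only as the fallback

        // Shape of a result-index hit as wrapped by JobExecutionResponseReader:
        // the document source sits under the "data" key.
        JSONObject result =
            new JSONObject()
                .put("data", new JSONObject().put("status", "SUCCESS").put("error", ""));

        JSONObject items = result.optJSONObject("data");
        // Prefer the status/error the job itself recorded; otherwise use the EMR-S state.
        String status = items != null ? items.optString("status", jobState) : jobState;
        String error = items != null ? items.optString("error", "") : "";
        System.out.println(status + " / '" + error + "'"); // SUCCESS / ''
      }
    }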
+ if (result.has(DATA_FIELD)) { + JSONObject items = result.getJSONObject(DATA_FIELD); + + // If items have STATUS_FIELD, use it; otherwise, use jobState + String status = items.optString(STATUS_FIELD, jobState); + result.put(STATUS_FIELD, status); + + // If items have ERROR_FIELD, use it; otherwise, set empty string + String error = items.optString(ERROR_FIELD, ""); + result.put(ERROR_FIELD, error); + } else { + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); } - result.put("status", getJobRunResult.getJobRun().getState()); + return result; } @@ -96,8 +122,9 @@ private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryR private DispatchQueryResponse handleIndexQuery( DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - dataSourceUserAuthorizationHelper.authorizeDataSource( - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); + DataSourceMetadata dataSourceMetadata = + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); @@ -117,14 +144,16 @@ private DispatchQueryResponse handleIndexQuery( .build() .toString(), tags, - indexDetails.getAutoRefresh()); + indexDetails.getAutoRefresh(), + dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { - dataSourceUserAuthorizationHelper.authorizeDataSource( - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); + DataSourceMetadata dataSourceMetadata = + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); StartJobRequest startJobRequest = @@ -140,19 +169,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ .build() .toString(), tags, - false); + false, + dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); } private DispatchQueryResponse handleDropIndexQuery( DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { - dataSourceUserAuthorizationHelper.authorizeDataSource( - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); + DataSourceMetadata dataSourceMetadata = + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobId = flintIndexMetadataReader.getJobIdFromFlintIndexMetadata(indexDetails); emrServerlessClient.cancelJobRun(dispatchQueryRequest.getApplicationId(), jobId); String dropIndexDummyJobId = 
RandomStringUtils.randomAlphanumeric(16); - return new DispatchQueryResponse(dropIndexDummyJobId, true); + return new DispatchQueryResponse( + dropIndexDummyJobId, true, dataSourceMetadata.getResultIndex()); } private static Map<String, String> getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 592f3db4fe..9ee5f156f2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -8,4 +8,5 @@ public class DispatchQueryResponse { private String jobId; private boolean isDropIndexQuery; + private String resultIndex; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index 8abb7cd11f..5da0ef44fe 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -5,8 +5,9 @@ package org.opensearch.sql.spark.response; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.JOB_ID_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -33,13 +34,14 @@ public JobExecutionResponseReader(Client client) { this.client = client; } - public JSONObject getResultFromOpensearchIndex(String jobId) { - return searchInSparkIndex(QueryBuilders.termQuery(STEP_ID_FIELD, jobId)); + public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) { + return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } - private JSONObject searchInSparkIndex(QueryBuilder query) { + private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME); + String searchResultIndex = resultIndex == null ? 
SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; + searchRequest.indices(searchResultIndex); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchRequest.source(searchSourceBuilder); @@ -53,13 +55,13 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { if (searchResponse.status().getStatus() != 200) { throw new RuntimeException( "Fetching result from " - + SPARK_RESPONSE_BUFFER_INDEX_NAME + + searchResultIndex + " index failed with status : " + searchResponse.status()); } else { JSONObject data = new JSONObject(); for (SearchHit searchHit : searchResponse.getHits().getHits()) { - data.put("data", searchHit.getSourceAsMap()); + data.put(DATA_FIELD, searchHit.getSourceAsMap()); } return data; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index c23706b184..5c784cf04c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -61,7 +61,8 @@ protected void doExecute( asyncQueryExecutionResponse.getStatus(), asyncQueryExecutionResponse.getSchema(), asyncQueryExecutionResponse.getResults(), - Cursor.None)); + Cursor.None, + asyncQueryExecutionResponse.getError())); listener.onResponse(new GetAsyncQueryResultActionResponse(responseContent)); } catch (Exception e) { listener.onFailure(e); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java index c9eb5bbf59..3a2a5b110f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -11,6 +11,7 @@ import lombok.Builder; import lombok.Getter; import lombok.RequiredArgsConstructor; +import org.opensearch.core.common.Strings; import org.opensearch.sql.protocol.response.QueryResult; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; @@ -53,6 +54,10 @@ public Object buildJsonObject(AsyncQueryResult response) { json.datarows(fetchDataRows(response)); } json.status(response.getStatus()); + if (!Strings.isEmpty(response.getError())) { + json.error(response.getError()); + } + return json.build(); } @@ -79,6 +84,7 @@ public static class JsonResponse { private Integer total; private Integer size; + private final String error; } @RequiredArgsConstructor diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml index ec2c83a4df..3a39b989a2 100644 --- a/spark/src/main/resources/job-metadata-index-mapping.yml +++ b/spark/src/main/resources/job-metadata-index-mapping.yml @@ -14,6 +14,11 @@ properties: keyword: type: keyword applicationId: + type: text + fields: + keyword: + type: keyword + resultIndex: type: text fields: keyword: diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index df897ec7dc..a053f30f3b 100644 --- 
a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -64,11 +64,11 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID)); + .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null)); verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); verify(settings, times(1)).getSettingValue(Settings.Key.CLUSTER_NAME); verify(sparkQueryDispatcher, times(1)) @@ -106,11 +106,11 @@ void testGetAsyncQueryResultsWithInProgressJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -124,11 +124,11 @@ void testGetAsyncQueryResultsWithInProgressJob() { @Test void testGetAsyncQueryResultsWithSuccessJob() throws IOException { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutorServiceImpl jobExecutorService = @@ -185,8 +185,9 @@ void testCancelJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); - when(sparkQueryDispatcher.cancelJob(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + when(sparkQueryDispatcher.cancelJob( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(EMR_JOB_ID); String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java 
b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index 7097daf13e..7288fd3fc2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -66,7 +66,7 @@ public void testStoreJobMetadata() { Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); @@ -83,7 +83,7 @@ public void testStoreJobMetadataWithOutCreatingIndex() { Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); @@ -105,7 +105,7 @@ public void testStoreJobMetadataWithException() { .thenThrow(new RuntimeException("error while indexing")); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -129,7 +129,7 @@ public void testStoreJobMetadataWithIndexCreationFailed() { .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX)); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -157,7 +157,7 @@ public void testStoreJobMetadataFailedWithNotFoundResponse() { Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -183,7 +183,7 @@ public void testGetJobMetadata() { new SearchHits( new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null); Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); Optional<AsyncQueryJobMetadata> jobMetadataOptional = diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 17d4fe55c0..4655584855 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -45,7 +45,26 @@ void testStartJobRun() { EMRS_EXECUTION_ROLE, 
SPARK_SUBMIT_PARAMETERS, new HashMap<>(), - false)); + false, + null)); + } + + @Test + void testStartJobRunResultIndex() { + StartJobRunResult response = new StartJobRunResult(); + when(emrServerless.startJobRun(any())).thenReturn(response); + + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); + emrServerlessClient.startJobRun( + new StartJobRequest( + QUERY, + EMRS_JOB_NAME, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + SPARK_SUBMIT_PARAMETERS, + new HashMap<>(), + false, + "foo")); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java index 783ce8466e..eb7d9634ec 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java @@ -20,10 +20,10 @@ void executionTimeout() { } private StartJobRequest onDemandJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), false); + return new StartJobRequest("", "", "", "", "", Map.of(), false, null); } private StartJobRequest streamingJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), true); + return new StartJobRequest("", "", "", "", "", Map.of(), true, null); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 8dbf60e170..4c04381f36 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -14,9 +15,12 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_PASSWORD; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_USERNAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -85,7 +89,8 @@ void testDispatchSelectQuery() { } }), tags, - false))) + false, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -114,7 +119,8 @@ void testDispatchSelectQuery() { } }), tags, - false)); + false, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -148,7 +154,8 @@ void 
testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { } }), tags, - false))) + false, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -178,7 +185,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { } }), tags, - false)); + false, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -210,7 +218,8 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { } }), tags, - false))) + false, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithNoAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -238,7 +247,8 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { } }), tags, - false)); + false, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -277,7 +287,8 @@ void testDispatchIndexQuery() { } })), tags, - true))) + true, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -307,7 +318,8 @@ void testDispatchIndexQuery() { } })), tags, - true)); + true, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -340,7 +352,8 @@ void testDispatchWithPPLQuery() { } }), tags, - false))) + false, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -369,7 +382,8 @@ void testDispatchWithPPLQuery() { } }), tags, - false)); + false, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -402,7 +416,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { } }), tags, - false))) + false, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -431,7 +446,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { } }), tags, - false)); + false, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -470,7 +486,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { } })), tags, - true))) + true, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); @@ -500,7 +517,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { } })), tags, - true)); + true, + any())); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); 
Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -579,7 +597,8 @@ void testCancelJob() { .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); String jobId = - sparkQueryDispatcher.cancelJob(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); + sparkQueryDispatcher.cancelJob( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); Assertions.assertEquals(EMR_JOB_ID, jobId); } @@ -596,7 +615,7 @@ void testGetQueryResponse() { .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); Assertions.assertEquals("PENDING", result.get("status")); verifyNoInteractions(jobExecutionResponseReader); } @@ -613,17 +632,29 @@ void testGetQueryResponseWithSuccess() { when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); - queryResult.put("data", "result"); - when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) + Map<String, String> resultMap = new HashMap<>(); + resultMap.put(STATUS_FIELD, "SUCCESS"); + resultMap.put(ERROR_FIELD, ""); + queryResult.put(DATA_FIELD, resultMap); + when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(queryResult); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); verify(emrServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); - verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); - Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); - Assertions.assertEquals("result", result.get("data")); - Assertions.assertEquals("SUCCESS", result.get("status")); + verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null); + Assertions.assertEquals( + new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet()); + JSONObject dataJson = new JSONObject(); + dataJson.put(ERROR_FIELD, ""); + dataJson.put(STATUS_FIELD, "SUCCESS"); + // JSONObject.similar() compares two JSON objects by content, ignoring attribute order. + // An equals() check on their serialized strings compares character by character, so + // attribute order matters. We need similar() here.
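For readers unfamiliar with org.json, a quick illustration of the similar()-versus-equals() distinction the comment above relies on (the assertion in the hunk continues below; the key/value pairs here are arbitrary):

    import org.json.JSONObject;

    public class SimilarVsEquals {
      public static void main(String[] args) {
        JSONObject a = new JSONObject().put("status", "SUCCESS").put("error", "");
        JSONObject b = new JSONObject().put("error", "").put("status", "SUCCESS");
        // similar() compares content and ignores attribute order: prints true
        System.out.println(a.similar(b));
        // The serialized forms may list attributes in different orders, so a
        // string comparison of the two objects is brittle.
        System.out.println(a.toString().equals(b.toString()));
      }
    }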
+ Assertions.assertTrue(dataJson.similar(result.get(DATA_FIELD))); + Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); } @Test @@ -738,14 +769,12 @@ private String constructExpectedSparkSubmitParameterString( authParamConfigBuilder.append(authParams.get(key)); authParamConfigBuilder.append(" "); } - return " --class org.opensearch.sql.FlintJob --conf" + return " --class org.apache.spark.sql.FlintJob --conf" + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + " --conf" + " spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" + " --conf" - + " spark.jars=s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar,s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar" - + " --conf" - + " spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT" + + " spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT,org.opensearch:opensearch-spark-sql-application_2.12:0.1.0-SNAPSHOT" + " --conf" + " spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots" + " --conf" @@ -766,8 +795,9 @@ private String constructExpectedSparkSubmitParameterString( + " spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" - + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog " - + authParamConfigBuilder; + + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " + + authParamConfigBuilder + + " --conf spark.flint.datasource.name=my_glue "; } private String withStructuredStreaming(String parameters) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java index 17305fb905..7d7ebd42b3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java @@ -45,7 +45,23 @@ public void testGetResultFromOpensearchIndex() { new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID)); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - assertFalse(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID).isEmpty()); + assertFalse( + jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null).isEmpty()); + } + + @Test + public void testGetResultFromCustomIndex() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); + Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID)); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + assertFalse( + jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, 
"foo").isEmpty()); } @Test @@ -58,7 +74,7 @@ public void testInvalidSearchResponse() { RuntimeException exception = assertThrows( RuntimeException.class, - () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)); Assertions.assertEquals( "Fetching result from " + SPARK_RESPONSE_BUFFER_INDEX_NAME @@ -73,6 +89,6 @@ public void testSearchFailure() { JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); assertThrows( RuntimeException.class, - () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 9e4cd75165..21a213c7c2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -63,7 +63,7 @@ public void setUp() { public void testDoExecute() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null); + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); @@ -88,7 +88,8 @@ public void testDoExecuteWithSuccessResponse() { schema, Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), - tupleValue(ImmutableMap.of("name", "Smith", "age", 30)))); + tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), + null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java index 5ba5627665..711db75efb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -29,7 +29,8 @@ void formatAsyncQueryResponse() { schema, Arrays.asList( tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), - tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30)))); + tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), + null); AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); assertEquals( "{\"status\":\"success\",\"schema\":[{\"name\":\"firstname\",\"type\":\"string\"}," @@ -37,4 +38,11 @@ void formatAsyncQueryResponse() { + "[[\"John\",20],[\"Smith\",30]],\"total\":2,\"size\":2}", formatter.format(response)); } + + @Test + void formatAsyncQueryError() { + AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo"); + AsyncQueryResultResponseFormatter 
formatter = new AsyncQueryResultResponseFormatter(COMPACT); + assertEquals("{\"status\":\"FAILED\",\"error\":\"foo\"}", formatter.format(response)); + } }
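
Taken together, these test updates trace the new result-index plumbing end to end: AsyncQueryJobMetadata and JobExecutionResponseReader.getResultFromOpensearchIndex() now accept an optional result index, and the tests exercise both the null case (read from the default Spark response buffer index) and a custom index ("foo" in testGetResultFromCustomIndex). As a reading aid only, here is a hedged sketch of the fallback the tests imply; the helper name and the constant's value are invented for illustration and do not appear in this diff:

```java
// Hypothetical sketch, not code from this PR: a null resultIndex routes reads to the
// default Spark response buffer index, while a non-null value routes them to the
// caller-supplied index.
public class ResultIndexFallbackSketch {

  // Placeholder value; the real constant is SPARK_RESPONSE_BUFFER_INDEX_NAME in the plugin.
  private static final String DEFAULT_RESULT_INDEX = "spark-response-buffer-index";

  static String resolveResultIndex(String resultIndex) {
    return resultIndex == null ? DEFAULT_RESULT_INDEX : resultIndex;
  }

  public static void main(String[] args) {
    System.out.println(resolveResultIndex(null));  // default buffer index
    System.out.println(resolveResultIndex("foo")); // custom index, as in the new test
  }
}
```

The new formatAsyncQueryError test likewise pins down the failure payload: just a status and an error message ({"status":"FAILED","error":"foo"}), with schema and datarows omitted.
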