From cc6b5a67a127f2cd7909edd7d8a016525f137de6 Mon Sep 17 00:00:00 2001
From: Kaituo Li
Date: Mon, 2 Oct 2023 15:13:35 -0700
Subject: [PATCH] Add customized result index in data source etc

This PR:
- Introduces a `spark.flint.datasource.name` parameter for data source specification.
- Enhances data source creation to allow a custom result index, falling back to the default index when none is configured.
- Includes error details, sourced from the result index, in the async query result response.
- Migrates to `org.apache.spark.sql.FlintJob` following updates in OpenSearch-Spark.
- Populates query status from the result index rather than the EMR-S job status, to handle edge cases where the job succeeds but the query or the result-index mapping fails.

Testing done:
1. Manual testing confirming that async queries still work both with and without a custom result index.
2. Added new unit tests.

Signed-off-by: Kaituo Li
---
 .../datasource/model/DataSourceMetadata.java | 9 ++-
 .../sql/analysis/AnalyzerTestBase.java | 3 +-
 .../datasource/DataSourceTableScanTest.java | 3 +-
 .../utils/XContentParserUtils.java | 10 ++-
 .../resources/datasources-index-mapping.yml | 2 +
 .../service/DataSourceServiceImplTest.java | 6 +-
 .../utils/XContentParserUtilsTest.java | 2 +-
 .../sql/datasource/DataSourceAPIsIT.java | 15 ++--
 .../sql/ppl/InformationSchemaCommandIT.java | 3 +-
 .../ppl/PrometheusDataSourceCommandsIT.java | 3 +-
 .../sql/ppl/ShowDataSourcesCommandIT.java | 3 +-
 .../AsyncQueryExecutorServiceImpl.java | 30 ++++++--
 .../model/AsyncQueryExecutionResponse.java | 1 +
 .../model/AsyncQueryJobMetadata.java | 8 +-
 .../asyncquery/model/AsyncQueryResult.java | 11 ++-
 .../client/EmrServerlessClientImplEMR.java | 7 +-
 .../sql/spark/client/StartJobRequest.java | 1 +
 .../spark/data/constants/SparkConstants.java | 16 +++-
 .../dispatcher/SparkQueryDispatcher.java | 59 +++++++++++----
 .../response/JobExecutionResponseReader.java | 16 ++--
 .../TransportGetAsyncQueryResultAction.java | 3 +-
 .../AsyncQueryResultResponseFormatter.java | 6 ++
 .../resources/job-metadata-index-mapping.yml | 5 ++
 .../AsyncQueryExecutorServiceImplTest.java | 15 ++--
 ...yncQueryJobMetadataStorageServiceTest.java | 12 +--
 .../client/EmrServerlessClientImplTest.java | 20 ++++-
 .../dispatcher/SparkQueryDispatcherTest.java | 75 ++++++++++++++-----
 ...AsyncQueryExecutionResponseReaderTest.java | 22 +++++-
 ...ransportGetAsyncQueryResultActionTest.java | 5 +-
 ...AsyncQueryResultResponseFormatterTest.java | 10 ++-
 30 files changed, 287 insertions(+), 94 deletions(-)

diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java index a61f5a7a20..866e9cadef 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java @@ -42,16 +42,20 @@ public class DataSourceMetadata { @JsonProperty private Map<String, String> properties; + @JsonProperty private String resultIndex; + public DataSourceMetadata( String name, DataSourceType connector, List<String> allowedRoles, - Map<String, String> properties) { + Map<String, String> properties, + String resultIndex) { this.name = name; this.connector = connector; this.description = StringUtils.EMPTY; this.properties = properties; this.allowedRoles = allowedRoles; + this.resultIndex = resultIndex; } public DataSourceMetadata() { @@ -69,6 +73,7 @@ public static DataSourceMetadata defaultOpenSearchDataSourceMetadata() { DEFAULT_DATASOURCE_NAME, DataSourceType.OPENSEARCH, Collections.emptyList(), - ImmutableMap.of()); + ImmutableMap.of(), +
null); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index a16d57673e..508567582b 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -199,7 +199,8 @@ public Set getDataSourceMetadata(boolean isDefaultDataSource ds.getName(), ds.getConnectorType(), Collections.emptyList(), - ImmutableMap.of())) + ImmutableMap.of(), + null)) .collect(Collectors.toSet()); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java index 28851f2454..069cb0eada 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScanTest.java @@ -64,7 +64,8 @@ void testIterator() { dataSource.getName(), dataSource.getConnectorType(), Collections.emptyList(), - ImmutableMap.of())) + ImmutableMap.of(), + null)) .collect(Collectors.toSet()); when(dataSourceService.getDataSourceMetadata(true)).thenReturn(dataSourceMetadata); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java index 53f0054dc2..261f13870a 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java @@ -32,6 +32,8 @@ public class XContentParserUtils { public static final String PROPERTIES_FIELD = "properties"; public static final String ALLOWED_ROLES_FIELD = "allowedRoles"; + public static final String RESULT_INDEX_FIELD = "resultIndex"; + /** * Convert xcontent parser to DataSourceMetadata. 
* @@ -45,6 +47,7 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr DataSourceType connector = null; List allowedRoles = new ArrayList<>(); Map properties = new HashMap<>(); + String resultIndex = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -73,6 +76,9 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr properties.put(key, value); } break; + case RESULT_INDEX_FIELD: + resultIndex = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -80,7 +86,8 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr if (name == null || connector == null) { throw new IllegalArgumentException("name and connector are required fields."); } - return new DataSourceMetadata(name, description, connector, allowedRoles, properties); + return new DataSourceMetadata( + name, description, connector, allowedRoles, properties, resultIndex); } /** @@ -122,6 +129,7 @@ public static XContentBuilder convertToXContent(DataSourceMetadata metadata) thr builder.field(entry.getKey(), entry.getValue()); } builder.endObject(); + builder.field(RESULT_INDEX_FIELD, metadata.getResultIndex()); builder.endObject(); return builder; } diff --git a/datasources/src/main/resources/datasources-index-mapping.yml b/datasources/src/main/resources/datasources-index-mapping.yml index cb600ae825..0206a97886 100644 --- a/datasources/src/main/resources/datasources-index-mapping.yml +++ b/datasources/src/main/resources/datasources-index-mapping.yml @@ -14,4 +14,6 @@ properties: keyword: type: keyword connector: + type: keyword + resultIndex: type: keyword \ No newline at end of file diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index eb28495541..c8312e6013 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -344,7 +344,8 @@ void testRemovalOfAuthorizationInfo() { "testDS", DataSourceType.PROMETHEUS, Collections.singletonList("prometheus_access"), - properties); + properties, + null); when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); @@ -398,7 +399,8 @@ void testGetRawDataSourceMetadata() { "testDS", DataSourceType.PROMETHEUS, Collections.singletonList("prometheus_access"), - properties); + properties, + null); when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java index eab0b8e168..d134293456 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java @@ -31,7 +31,7 @@ public void testConvertToXContent() { XContentBuilder contentBuilder = XContentParserUtils.convertToXContent(dataSourceMetadata); String contentString = BytesReference.bytes(contentBuilder).utf8ToString(); 
Assertions.assertEquals( - "{\"name\":\"testDS\",\"description\":\"\",\"connector\":\"PROMETHEUS\",\"allowedRoles\":[\"prometheus_access\"],\"properties\":{\"prometheus.uri\":\"https://localhost:9090\"}}", + "{\"name\":\"testDS\",\"description\":\"\",\"connector\":\"PROMETHEUS\",\"allowedRoles\":[\"prometheus_access\"],\"properties\":{\"prometheus.uri\":\"https://localhost:9090\"},\"resultIndex\":null}", contentString); } diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 0b69a459a1..275f3ade48 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -62,7 +62,8 @@ public void createDataSourceAPITest() { "prometheus.auth.username", "username", "prometheus.auth.password", - "password")); + "password"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); @@ -92,7 +93,8 @@ public void updateDataSourceAPITest() { "update_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. @@ -104,7 +106,8 @@ public void updateDataSourceAPITest() { "update_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://randomtest.com:9090")); + ImmutableMap.of("prometheus.uri", "https://randomtest.com:9090"), + null); Request updateRequest = getUpdateDataSourceRequest(updateDSM); Response updateResponse = client().performRequest(updateRequest); Assert.assertEquals(200, updateResponse.getStatusLine().getStatusCode()); @@ -137,7 +140,8 @@ public void deleteDataSourceTest() { "delete_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. @@ -175,7 +179,8 @@ public void getAllDataSourceTest() { "get_all_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "https://localhost:9090")); + ImmutableMap.of("prometheus.uri", "https://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); client().performRequest(createRequest); // Datasource is not immediately created. so introducing a sleep of 2s. 
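As an illustrative aside (not part of the patch itself): a minimal sketch of how a caller could supply a custom result index through the new trailing constructor argument, which the integration tests above pass as null. The index name query_execution_result_my_prometheus is invented for this example; getResultIndex() is the accessor the patch uses elsewhere.

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;

public class CustomResultIndexExample {
  public static void main(String[] args) {
    // The fifth constructor argument is the new resultIndex; passing null would
    // fall back to the default .query_execution_result index.
    DataSourceMetadata createDSM =
        new DataSourceMetadata(
            "my_prometheus",
            DataSourceType.PROMETHEUS,
            ImmutableList.of(),
            ImmutableMap.of("prometheus.uri", "http://localhost:9090"),
            "query_execution_result_my_prometheus");
    System.out.println(createDSM.getResultIndex());
  }
}
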
diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java index cf7cfcdb39..7b694ce222 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/InformationSchemaCommandIT.java @@ -46,7 +46,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java index 8d72f02e29..b81b7f9517 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusDataSourceCommandsIT.java @@ -58,7 +58,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java index c9c4854212..2180048563 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ShowDataSourcesCommandIT.java @@ -46,7 +46,8 @@ protected void init() throws InterruptedException, IOException { "my_prometheus", DataSourceType.PROMETHEUS, ImmutableList.of(), - ImmutableMap.of("prometheus.uri", "http://localhost:9090")); + ImmutableMap.of("prometheus.uri", "http://localhost:9090"), + null); Request createRequest = getCreateDataSourceRequest(createDSM); Response response = client().performRequest(createRequest); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index a86aa82695..0e4ce72d37 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -7,6 +7,8 @@ import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.JobRunState; import java.security.AccessController; @@ -15,6 +17,7 @@ import java.util.List; import java.util.Optional; import lombok.AllArgsConstructor; 
+import org.apache.commons.lang3.tuple.Pair; import org.json.JSONObject; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; @@ -64,17 +67,22 @@ public CreateAsyncQueryResponse createAsyncQuery( SparkExecutionEngineConfig.toSparkExecutionEngineConfig( sparkExecutionEngineConfigString)); ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); - String jobId = + String query = createAsyncQueryRequest.getQuery(); + Pair jobIdResultIndexPair = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( sparkExecutionEngineConfig.getApplicationId(), - createAsyncQueryRequest.getQuery(), + query, createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), clusterName.value())); + asyncQueryJobMetadataStorageService.storeJobMetadata( - new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); - return new CreateAsyncQueryResponse(jobId); + new AsyncQueryJobMetadata( + jobIdResultIndexPair.getLeft(), + sparkExecutionEngineConfig.getApplicationId(), + jobIdResultIndexPair.getRight())); + return new CreateAsyncQueryResponse(jobIdResultIndexPair.getLeft()); } @Override @@ -85,8 +93,10 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { if (jobMetadata.isPresent()) { JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse( - jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId()); - if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) { + jobMetadata.get().getApplicationId(), + jobMetadata.get().getJobId(), + jobMetadata.get().getResultIndex()); + if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = new DefaultSparkSqlFunctionResponseHandle(jsonObject); List result = new ArrayList<>(); @@ -94,9 +104,13 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { result.add(sparkSqlFunctionResponseHandle.next()); } return new AsyncQueryExecutionResponse( - JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result); + JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null); } else { - return new AsyncQueryExecutionResponse(jsonObject.getString("status"), null, null); + return new AsyncQueryExecutionResponse( + jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()), + null, + null, + jsonObject.optString(ERROR_FIELD, "")); } } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index 84dcc490ba..d2e54af004 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -18,4 +18,5 @@ public class AsyncQueryExecutionResponse { private final String status; private final ExecutionEngine.Schema schema; private final List results; + private final String error; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 60ec53987e..3e4e801105 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -26,6 +26,7 @@ public class AsyncQueryJobMetadata { private String jobId; private String applicationId; + private String resultIndex; @Override public String toString() { @@ -44,6 +45,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) builder.startObject(); builder.field("jobId", metadata.getJobId()); builder.field("applicationId", metadata.getApplicationId()); + builder.field("resultIndex", metadata.getResultIndex()); builder.endObject(); return builder; } @@ -77,6 +79,7 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOExceptio public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { String jobId = null; String applicationId = null; + String resultIndex = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -88,6 +91,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "applicationId": applicationId = parser.textOrNull(); break; + case "resultIndex": + resultIndex = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -95,6 +101,6 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(jobId, applicationId); + return new AsyncQueryJobMetadata(jobId, applicationId, resultIndex); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java index 6d6bce8fbc..c229aa3920 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java @@ -11,19 +11,26 @@ public class AsyncQueryResult extends QueryResult { @Getter private final String status; + @Getter private final String error; public AsyncQueryResult( String status, ExecutionEngine.Schema schema, Collection exprValues, - Cursor cursor) { + Cursor cursor, + String error) { super(schema, exprValues, cursor); this.status = status; + this.error = error; } public AsyncQueryResult( - String status, ExecutionEngine.Schema schema, Collection exprValues) { + String status, + ExecutionEngine.Schema schema, + Collection exprValues, + String error) { super(schema, exprValues); this.status = status; + this.error = error; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java index 83e570ece2..e7f349c896 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java @@ -34,6 +34,10 @@ public EmrServerlessClientImplEMR(AWSEMRServerless emrServerless) { @Override public String startJobRun(StartJobRequest startJobRequest) { + String resultIndex = + startJobRequest.getResultIndex() == null + ? 
SPARK_RESPONSE_BUFFER_INDEX_NAME + : startJobRequest.getResultIndex(); StartJobRunRequest request = new StartJobRunRequest() .withName(startJobRequest.getJobName()) @@ -45,8 +49,7 @@ public String startJobRun(StartJobRequest startJobRequest) { .withSparkSubmit( new SparkSubmit() .withEntryPoint(SPARK_SQL_APPLICATION_JAR) - .withEntryPointArguments( - startJobRequest.getQuery(), SPARK_RESPONSE_BUFFER_INDEX_NAME) + .withEntryPointArguments(startJobRequest.getQuery(), resultIndex) .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams()))); StartJobRunResult startJobRunResult = AccessController.doPrivileged( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index 94689c7030..26f0437acf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -20,4 +20,5 @@ public class StartJobRequest { private final String executionRoleArn; private final String sparkSubmitParams; private final Map tags; + private final String resultIndex; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 21db8b9478..8618511b86 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -8,13 +8,22 @@ public class SparkConstants { public static final String EMR = "emr"; public static final String STEP_ID_FIELD = "stepId.keyword"; + + public static final String JOB_ID_FIELD = "jobRunId"; + + public static final String STATUS_FIELD = "status"; + + public static final String DATA_FIELD = "data"; + + public static final String ERROR_FIELD = "error"; + // TODO should be replaced with mvn jar. public static final String SPARK_SQL_APPLICATION_JAR = - "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; + "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job-assembly-0.1.0-SNAPSHOT.jar"; public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = - "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; // TODO should be replaced with mvn jar. 
public static final String GLUE_CATALOG_HIVE_JAR = "s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar"; @@ -26,7 +35,7 @@ public class SparkConstants { public static final String FLINT_DEFAULT_SCHEME = "http"; public static final String FLINT_DEFAULT_AUTH = "-1"; public static final String FLINT_DEFAULT_REGION = "us-west-2"; - public static final String DEFAULT_CLASS_NAME = "org.opensearch.sql.FlintJob"; + public static final String DEFAULT_CLASS_NAME = "org.apache.spark.sql.FlintJob"; public static final String S3_AWS_CREDENTIALS_PROVIDER_KEY = "spark.hadoop.fs.s3.customAWSCredentialsProvider"; public static final String DRIVER_ENV_ASSUME_ROLE_ARN_KEY = @@ -49,6 +58,7 @@ public class SparkConstants { public static final String FLINT_INDEX_STORE_AWSREGION_KEY = "spark.datasource.flint.region"; public static final String FLINT_CREDENTIALS_PROVIDER_KEY = "spark.datasource.flint.customAWSCredentialsProvider"; + public static final String FLINT_DATA_SOURCE_KEY = "spark.flint.datasource.name"; public static final String SPARK_SQL_EXTENSIONS_KEY = "spark.sql.extensions"; public static final String HIVE_METASTORE_CLASS_KEY = "spark.hadoop.hive.metastore.client.factory.class"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 904d199663..4e3c332f45 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -5,8 +5,11 @@ package org.opensearch.sql.spark.dispatcher; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DATA_SOURCE_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; @@ -14,6 +17,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -23,6 +27,7 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; @@ -56,18 +61,39 @@ public class SparkQueryDispatcher { private JobExecutionResponseReader jobExecutionResponseReader; - public String dispatch(DispatchQueryRequest dispatchQueryRequest) { - return EMRServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); + public Pair 
dispatch(DispatchQueryRequest dispatchQueryRequest) { + StartJobRequest startJobRequest = getStartJobRequest(dispatchQueryRequest); + return Pair.of( + EMRServerlessClient.startJobRun(startJobRequest), startJobRequest.getResultIndex()); } - // TODO : Fetch from Result Index and then make call to EMR Serverless. - public JSONObject getQueryResponse(String applicationId, String queryId) { + public JSONObject getQueryResponse(String applicationId, String queryId, String resultIndex) { GetJobRunResult getJobRunResult = EMRServerlessClient.getJobRunResult(applicationId, queryId); - JSONObject result = new JSONObject(); - if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { - result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId); + String jobState = getJobRunResult.getJobRun().getState(); + JSONObject result = + (jobState.equals(JobRunState.SUCCESS.toString())) + ? jobExecutionResponseReader.getResultFromOpensearchIndex(queryId, resultIndex) + : new JSONObject(); + // If the result index document has a status, use it directly; otherwise, fall back to + // the EMR-S job status. + // A successful job does not guarantee error-free execution: for example, even if the + // result index mapping is incorrect, we still write the query result and let the job + // finish. + if (result.has(DATA_FIELD)) { + JSONObject items = result.getJSONObject(DATA_FIELD); + + // If items contains STATUS_FIELD, use it; otherwise, use jobState. + String status = items.optString(STATUS_FIELD, jobState); + result.put(STATUS_FIELD, status); + + // If items contains ERROR_FIELD, use it; otherwise, set an empty string. + String error = items.optString(ERROR_FIELD, ""); + result.put(ERROR_FIELD, error); + } else { + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, "Spark failed to write back results."); } - result.put("status", getJobRunResult.getJobRun().getState()); + return result; } @@ -127,6 +153,7 @@ private String constructSparkParameters(String datasourceName) { s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_AWSREGION_KEY, region); s3GlueSparkSubmitParameters.addParameter( "spark.sql.catalog."
+ datasourceName, FLINT_DELEGATE_CATALOG); + s3GlueSparkSubmitParameters.addParameter(FLINT_DATA_SOURCE_KEY, datasourceName); return s3GlueSparkSubmitParameters.toString(); } @@ -138,9 +165,10 @@ private StartJobRequest getStartJobRequestForNonIndexQueries( if (fullyQualifiedTableName.getDatasourceName() == null) { throw new UnsupportedOperationException("Missing datasource in the query syntax."); } - dataSourceUserAuthorizationHelper.authorizeDataSource( + DataSourceMetadata dataSource = this.dataSourceService.getRawDataSourceMetadata( - fullyQualifiedTableName.getDatasourceName())); + fullyQualifiedTableName.getDatasourceName()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSource); String jobName = dispatchQueryRequest.getClusterName() + ":" @@ -154,7 +182,8 @@ private StartJobRequest getStartJobRequestForNonIndexQueries( dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), - tags); + tags, + dataSource.getResultIndex()); return startJobRequest; } @@ -166,9 +195,10 @@ private StartJobRequest getStartJobRequestForIndexRequest( if (fullyQualifiedTableName.getDatasourceName() == null) { throw new UnsupportedOperationException("Queries without a datasource are not supported"); } - dataSourceUserAuthorizationHelper.authorizeDataSource( + DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata( - fullyQualifiedTableName.getDatasourceName())); + fullyQualifiedTableName.getDatasourceName()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = getJobNameForIndexQuery(dispatchQueryRequest, indexDetails, fullyQualifiedTableName); Map tags = @@ -181,7 +211,8 @@ private StartJobRequest getStartJobRequestForIndexRequest( dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), - tags); + tags, + dataSourceMetadata.getResultIndex()); return startJobRequest; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index 8abb7cd11f..5da0ef44fe 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -5,8 +5,9 @@ package org.opensearch.sql.spark.response; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.JOB_ID_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -33,13 +34,14 @@ public JobExecutionResponseReader(Client client) { this.client = client; } - public JSONObject getResultFromOpensearchIndex(String jobId) { - return searchInSparkIndex(QueryBuilders.termQuery(STEP_ID_FIELD, jobId)); + public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) { + return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } - private JSONObject searchInSparkIndex(QueryBuilder query) { + private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest 
searchRequest = new SearchRequest(); - searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME); + String searchResultIndex = resultIndex == null ? SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; + searchRequest.indices(searchResultIndex); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchRequest.source(searchSourceBuilder); @@ -53,13 +55,13 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { if (searchResponse.status().getStatus() != 200) { throw new RuntimeException( "Fetching result from " - + SPARK_RESPONSE_BUFFER_INDEX_NAME + + searchResultIndex + " index failed with status : " + searchResponse.status()); } else { JSONObject data = new JSONObject(); for (SearchHit searchHit : searchResponse.getHits().getHits()) { - data.put("data", searchHit.getSourceAsMap()); + data.put(DATA_FIELD, searchHit.getSourceAsMap()); } return data; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index c23706b184..5c784cf04c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -61,7 +61,8 @@ protected void doExecute( asyncQueryExecutionResponse.getStatus(), asyncQueryExecutionResponse.getSchema(), asyncQueryExecutionResponse.getResults(), - Cursor.None)); + Cursor.None, + asyncQueryExecutionResponse.getError())); listener.onResponse(new GetAsyncQueryResultActionResponse(responseContent)); } catch (Exception e) { listener.onFailure(e); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java index c9eb5bbf59..3a2a5b110f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -11,6 +11,7 @@ import lombok.Builder; import lombok.Getter; import lombok.RequiredArgsConstructor; +import org.opensearch.core.common.Strings; import org.opensearch.sql.protocol.response.QueryResult; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; @@ -53,6 +54,10 @@ public Object buildJsonObject(AsyncQueryResult response) { json.datarows(fetchDataRows(response)); } json.status(response.getStatus()); + if (!Strings.isEmpty(response.getError())) { + json.error(response.getError()); + } + return json.build(); } @@ -79,6 +84,7 @@ public static class JsonResponse { private Integer total; private Integer size; + private final String error; } @RequiredArgsConstructor diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml index ec2c83a4df..3a39b989a2 100644 --- a/spark/src/main/resources/job-metadata-index-mapping.yml +++ b/spark/src/main/resources/job-metadata-index-mapping.yml @@ -14,6 +14,11 @@ properties: keyword: type: keyword applicationId: + type: text + fields: + keyword: + type: keyword + resultIndex: type: text fields: keyword: diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java 
b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 1ff2493e6d..d2c7604a57 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Optional; +import org.apache.commons.lang3.tuple.Pair; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -61,11 +62,11 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(EMR_JOB_ID); + .thenReturn(Pair.of(EMR_JOB_ID, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata(new AsyncQueryJobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p")); + .storeJobMetadata(new AsyncQueryJobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p", null)); verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); verify(settings, times(1)).getSettingValue(Settings.Key.CLUSTER_NAME); verify(sparkQueryDispatcher, times(1)) @@ -102,10 +103,10 @@ void testGetAsyncQueryResultsWithInProgressJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null))); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); - when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID, null)) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -119,10 +120,10 @@ void testGetAsyncQueryResultsWithInProgressJob() { @Test void testGetAsyncQueryResultsWithSuccessJob() throws IOException { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null))); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); - when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID, null)) .thenReturn(jobResult); AsyncQueryExecutorServiceImpl jobExecutorService = @@ -179,7 +180,7 @@ void testCancelJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null))); when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID); String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); diff --git 
a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index fe9da12ef0..3a721b6825 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -66,7 +66,7 @@ public void testStoreJobMetadata() { Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); @@ -83,7 +83,7 @@ public void testStoreJobMetadataWithOutCreatingIndex() { Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); @@ -105,7 +105,7 @@ public void testStoreJobMetadataWithException() { .thenThrow(new RuntimeException("error while indexing")); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -129,7 +129,7 @@ public void testStoreJobMetadataWithIndexCreationFailed() { .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX)); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -157,7 +157,7 @@ public void testStoreJobMetadataFailedWithNotFoundResponse() { Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -183,7 +183,7 @@ public void testGetJobMetadata() { new SearchHits( new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); Optional jobMetadataOptional = diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 0765b90534..a1ded1683f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ 
b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -44,7 +44,25 @@ void testStartJobRun() { EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, SPARK_SUBMIT_PARAMETERS, - new HashMap<>())); + new HashMap<>(), + null)); + } + + @Test + void testStartJobRunResultIndex() { + StartJobRunResult response = new StartJobRunResult(); + when(emrServerless.startJobRun(any())).thenReturn(response); + + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); + emrServerlessClient.startJobRun( + new StartJobRequest( + QUERY, + EMRS_JOB_NAME, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + SPARK_SUBMIT_PARAMETERS, + new HashMap<>(), + "foo")); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d83505fde0..8708bf387a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -14,6 +15,9 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -68,15 +72,22 @@ void testDispatchSelectQuery() { EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), - tags))) + tags, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); String jobId = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + sparkQueryDispatcher + .dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)) + .getLeft(); verify(EMRServerlessClient, times(1)) .startJobRun( new StartJobRequest( @@ -85,7 +96,8 @@ void testDispatchSelectQuery() { EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), - tags)); + tags, + any())); Assertions.assertEquals(EMR_JOB_ID, jobId); } @@ -113,15 +125,22 @@ void testDispatchIndexQuery() { EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), - tags))) + tags, + any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); String jobId 
= - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + sparkQueryDispatcher + .dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)) + .getLeft(); verify(EMRServerlessClient, times(1)) .startJobRun( new StartJobRequest( @@ -130,7 +149,8 @@ void testDispatchIndexQuery() { EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), - tags)); + tags, + any())); Assertions.assertEquals(EMR_JOB_ID, jobId); } @@ -330,7 +350,8 @@ void testGetQueryResponse() { jobExecutionResponseReader); when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); - JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + JSONObject result = + sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID, null); Assertions.assertEquals("PENDING", result.get("status")); verifyNoInteractions(jobExecutionResponseReader); } @@ -346,19 +367,32 @@ void testGetQueryResponseWithSuccess() { when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); - queryResult.put("data", "result"); - when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) + Map<String, String> resultMap = new HashMap<>(); + resultMap.put(STATUS_FIELD, "SUCCESS"); + resultMap.put(ERROR_FIELD, ""); + queryResult.put(DATA_FIELD, resultMap); + when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(queryResult); - JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + JSONObject result = + sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID, null); verify(EMRServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); - verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); - Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); - Assertions.assertEquals("result", result.get("data")); - Assertions.assertEquals("SUCCESS", result.get("status")); + verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null); + Assertions.assertEquals( + new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet()); + JSONObject dataJson = new JSONObject(); + dataJson.put(ERROR_FIELD, ""); + dataJson.put(STATUS_FIELD, "SUCCESS"); + // JSONObject.similar() treats two JSON objects as equal even when their + // attributes appear in a different order. + // equals() would compare the serialized strings character by character, so it + // is sensitive to attribute order. + // We need similar() here.
+ Assertions.assertTrue(dataJson.similar(result.get(DATA_FIELD))); + Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); } private String constructExpectedSparkSubmitParameterString() { - return " --class org.opensearch.sql.FlintJob --conf" + return " --class org.apache.spark.sql.FlintJob --conf" + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + " --conf" + " spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" @@ -386,7 +420,8 @@ private String constructExpectedSparkSubmitParameterString() { + " spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" - + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; + + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog " + + " --conf spark.flint.datasource.name=my_glue "; } private DataSourceMetadata constructMyGlueDataSourceMetadata() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java index 17305fb905..7d7ebd42b3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java @@ -45,7 +45,23 @@ public void testGetResultFromOpensearchIndex() { new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID)); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - assertFalse(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID).isEmpty()); + assertFalse( + jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null).isEmpty()); + } + + @Test + public void testGetResultFromCustomIndex() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); + Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID)); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + assertFalse( + jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, "foo").isEmpty()); } @Test @@ -58,7 +74,7 @@ public void testInvalidSearchResponse() { RuntimeException exception = assertThrows( RuntimeException.class, - () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)); Assertions.assertEquals( "Fetching result from " + SPARK_RESPONSE_BUFFER_INDEX_NAME @@ -73,6 +89,6 @@ public void testSearchFailure() { JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); assertThrows( RuntimeException.class, - () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + () -> 
jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 9e4cd75165..21a213c7c2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -63,7 +63,7 @@ public void setUp() { public void testDoExecute() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null); + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); @@ -88,7 +88,8 @@ public void testDoExecuteWithSuccessResponse() { schema, Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), - tupleValue(ImmutableMap.of("name", "Smith", "age", 30)))); + tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), + null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java index 5ba5627665..711db75efb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -29,7 +29,8 @@ void formatAsyncQueryResponse() { schema, Arrays.asList( tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), - tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30)))); + tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), + null); AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); assertEquals( "{\"status\":\"success\",\"schema\":[{\"name\":\"firstname\",\"type\":\"string\"}," @@ -37,4 +38,11 @@ void formatAsyncQueryResponse() { + "[[\"John\",20],[\"Smith\",30]],\"total\":2,\"size\":2}", formatter.format(response)); } + + @Test + void formatAsyncQueryError() { + AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo"); + AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); + assertEquals("{\"status\":\"FAILED\",\"error\":\"foo\"}", formatter.format(response)); + } }
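
A minimal, standalone sketch of the status-resolution rule that the new SparkQueryDispatcher.getQueryResponse() applies above. Assumptions: only org.json on the classpath, and the literal "data", "status", and "error" keys standing in for DATA_FIELD, STATUS_FIELD, and ERROR_FIELD from SparkConstants.

import org.json.JSONObject;

public class StatusResolutionSketch {
  // Prefer the status/error that Spark wrote to the result index; fall back to the
  // EMR-S job state when no result document exists.
  static JSONObject resolve(JSONObject result, String emrsJobState) {
    if (result.has("data")) {
      JSONObject items = result.getJSONObject("data");
      result.put("status", items.optString("status", emrsJobState));
      result.put("error", items.optString("error", ""));
    } else {
      result.put("status", emrsJobState);
      result.put("error", "Spark failed to write back results.");
    }
    return result;
  }

  public static void main(String[] args) {
    // The EMR-S job run succeeded, but the query itself failed and recorded an
    // error: the result-index status wins over the job state.
    JSONObject fromIndex =
        new JSONObject()
            .put("data", new JSONObject().put("status", "FAILED").put("error", "bad mapping"));
    System.out.println(resolve(fromIndex, "SUCCESS"));

    // No result document was written: report the job state and a generic error.
    System.out.println(resolve(new JSONObject(), "CANCELLED"));
  }
}

This is why a SUCCESS job run no longer guarantees a SUCCESS query status: when present, the document Spark writes to the result index is authoritative.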