Skip to content

Commit

Permalink
Enable PPL lang and add datsource to async query API
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vmmusings committed Oct 2, 2023
1 parent ae10857 commit 26dab6c
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 151 deletions.
1 change: 1 addition & 0 deletions docs/user/interfaces/asyncqueryinterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Sample Request::
curl --location 'http://localhost:9200/_plugins/_async_query' \
--header 'Content-Type: application/json' \
--data '{
"datasource" : "my_glue",
"lang" : "sql",
"query" : "select * from my_glue.default.http_logs limit 10"
}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public CreateAsyncQueryResponse createAsyncQuery(
new DispatchQueryRequest(
sparkExecutionEngineConfig.getApplicationId(),
createAsyncQueryRequest.getQuery(),
createAsyncQueryRequest.getDatasource(),
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
clusterName.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class SparkQueryDispatcher {
public static final String TABLE_TAG_KEY = "table";
public static final String CLUSTER_NAME_TAG_KEY = "cluster";

private EMRServerlessClient EMRServerlessClient;
private EMRServerlessClient emrServerlessClient;

private DataSourceService dataSourceService;

Expand All @@ -57,12 +57,12 @@ public class SparkQueryDispatcher {
private JobExecutionResponseReader jobExecutionResponseReader;

public String dispatch(DispatchQueryRequest dispatchQueryRequest) {
return EMRServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest));
return emrServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest));
}

// TODO : Fetch from Result Index and then make call to EMR Serverless.
public JSONObject getQueryResponse(String applicationId, String queryId) {
GetJobRunResult getJobRunResult = EMRServerlessClient.getJobRunResult(applicationId, queryId);
GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult(applicationId, queryId);
JSONObject result = new JSONObject();
if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) {
result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId);
Expand All @@ -72,20 +72,23 @@ public JSONObject getQueryResponse(String applicationId, String queryId) {
}

public String cancelJob(String applicationId, String jobId) {
CancelJobRunResult cancelJobRunResult = EMRServerlessClient.cancelJobRun(applicationId, jobId);
CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(applicationId, jobId);
return cancelJobRunResult.getJobRunId();
}

// we currently don't support index queries in PPL language.
// so we are treating all of them as non-index queries which don't require any kind of query
// parsing.
private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) {
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery()))
return getStartJobRequestForIndexRequest(dispatchQueryRequest);
else {
return getStartJobRequestForNonIndexQueries(dispatchQueryRequest);
}
} else {
return getStartJobRequestForNonIndexQueries(dispatchQueryRequest);
}
throw new UnsupportedOperationException(
String.format("UnSupported Lang type:: %s", dispatchQueryRequest.getLangType()));
}

private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) {
Expand Down Expand Up @@ -133,27 +136,17 @@ private String constructSparkParameters(String datasourceName) {
private StartJobRequest getStartJobRequestForNonIndexQueries(
DispatchQueryRequest dispatchQueryRequest) {
StartJobRequest startJobRequest;
FullyQualifiedTableName fullyQualifiedTableName =
SQLQueryUtils.extractFullyQualifiedTableName(dispatchQueryRequest.getQuery());
if (fullyQualifiedTableName.getDatasourceName() == null) {
throw new UnsupportedOperationException("Missing datasource in the query syntax.");
}
dataSourceUserAuthorizationHelper.authorizeDataSource(
this.dataSourceService.getRawDataSourceMetadata(
fullyQualifiedTableName.getDatasourceName()));
String jobName =
dispatchQueryRequest.getClusterName()
+ ":"
+ fullyQualifiedTableName.getFullyQualifiedName();
Map<String, String> tags =
getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName);
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()));
String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
constructSparkParameters(fullyQualifiedTableName.getDatasourceName()),
constructSparkParameters(dispatchQueryRequest.getDatasource()),
tags);
return startJobRequest;
}
Expand All @@ -163,46 +156,54 @@ private StartJobRequest getStartJobRequestForIndexRequest(
StartJobRequest startJobRequest;
IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
if (fullyQualifiedTableName.getDatasourceName() == null) {
throw new UnsupportedOperationException("Queries without a datasource are not supported");
}
dataSourceUserAuthorizationHelper.authorizeDataSource(
this.dataSourceService.getRawDataSourceMetadata(
fullyQualifiedTableName.getDatasourceName()));
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()));
String jobName =
getJobNameForIndexQuery(dispatchQueryRequest, indexDetails, fullyQualifiedTableName);
Map<String, String> tags =
getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName);
Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
tags.put(INDEX_TAG_KEY, indexDetails.getIndexName());
tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName());
tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName());
startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
constructSparkParameters(fullyQualifiedTableName.getDatasourceName()),
constructSparkParameters(dispatchQueryRequest.getDatasource()),
tags);
return startJobRequest;
}

private static Map<String, String> getDefaultTagsForJobSubmission(
DispatchQueryRequest dispatchQueryRequest, FullyQualifiedTableName fullyQualifiedTableName) {
DispatchQueryRequest dispatchQueryRequest) {
Map<String, String> tags = new HashMap<>();
tags.put(CLUSTER_NAME_TAG_KEY, dispatchQueryRequest.getClusterName());
tags.put(DATASOURCE_TAG_KEY, fullyQualifiedTableName.getDatasourceName());
tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName());
tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName());
tags.put(DATASOURCE_TAG_KEY, dispatchQueryRequest.getDatasource());
return tags;
}

// Our queries work with datasource name and without datasource name.
// Inorder to have a constant jobName in both the scenarios,
// we are adding data source name from dispatcher request to the jobName.
private static String getJobNameForIndexQuery(
DispatchQueryRequest dispatchQueryRequest,
IndexDetails indexDetails,
FullyQualifiedTableName fullyQualifiedTableName) {
return dispatchQueryRequest.getClusterName()
+ ":"
+ fullyQualifiedTableName.getFullyQualifiedName()
+ "."
+ indexDetails.getIndexName();
if (fullyQualifiedTableName.getDatasourceName() == null) {
return dispatchQueryRequest.getClusterName()
+ ":"
+ dispatchQueryRequest.getDatasource()
+ "."
+ fullyQualifiedTableName.getFullyQualifiedName()
+ "."
+ indexDetails.getIndexName();
} else {
return dispatchQueryRequest.getClusterName()
+ ":"
+ fullyQualifiedTableName.getFullyQualifiedName()
+ "."
+ indexDetails.getIndexName();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
public class DispatchQueryRequest {
private final String applicationId;
private final String query;
private final String datasource;
private final LangType langType;
private final String executionRoleARN;
private final String clusterName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.HashSet;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -17,27 +18,44 @@
public class CreateAsyncQueryRequest {

private String query;
private String datasource;
private LangType lang;

public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser)
throws IOException {
String query = null;
LangType lang = null;
String datasource = null;

Check warning on line 28 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L28

Added line #L28 was not covered by tests
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
HashSet<String> missingFields = new HashSet<>();

Check warning on line 30 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L30

Added line #L30 was not covered by tests
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (fieldName.equals("query")) {
query = parser.textOrNull();
if (query == null) {
missingFields.add("query");

Check warning on line 37 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L37

Added line #L37 was not covered by tests
}
} else if (fieldName.equals("lang")) {
lang = LangType.fromString(parser.textOrNull());
String langString = parser.textOrNull();

Check warning on line 40 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L40

Added line #L40 was not covered by tests
if (langString == null) {
missingFields.add("lang");

Check warning on line 42 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L42

Added line #L42 was not covered by tests
}
lang = LangType.fromString(langString);

Check warning on line 44 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L44

Added line #L44 was not covered by tests
} else if (fieldName.equals("datasource")) {
datasource = parser.textOrNull();

Check warning on line 46 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L46

Added line #L46 was not covered by tests
if (datasource == null) {
missingFields.add("datasource");

Check warning on line 48 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L48

Added line #L48 was not covered by tests
}
} else {
throw new IllegalArgumentException("Unknown field: " + fieldName);
}
}
if (lang == null || query == null) {
throw new IllegalArgumentException("lang and query are required fields.");

if (missingFields.size() > 0) {
throw new IllegalArgumentException(
String.format("Missing %s fields in the query request", missingFields));

Check warning on line 57 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L56-L57

Added lines #L56 - L57 were not covered by tests
}
return new CreateAsyncQueryRequest(query, lang);
return new CreateAsyncQueryRequest(query, datasource, lang);

Check warning on line 59 in spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java

View check run for this annotation

Codecov / codecov/patch

spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java#L59

Added line #L59 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ void testCreateAsyncQuery() {
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
CreateAsyncQueryRequest createAsyncQueryRequest =
new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", LangType.SQL);
new CreateAsyncQueryRequest(
"select * from my_glue.default.http_logs", "my_glue", LangType.SQL);
when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG))
.thenReturn(
"{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}");
Expand All @@ -58,6 +59,7 @@ void testCreateAsyncQuery() {
new DispatchQueryRequest(
"00fd775baqpu4g0p",
"select * from my_glue.default.http_logs",
"my_glue",
LangType.SQL,
"arn:aws:iam::270824043731:role/emr-job-execution-role",
TEST_CLUSTER_NAME)))
Expand All @@ -73,6 +75,7 @@ void testCreateAsyncQuery() {
new DispatchQueryRequest(
"00fd775baqpu4g0p",
"select * from my_glue.default.http_logs",
"my_glue",
LangType.SQL,
"arn:aws:iam::270824043731:role/emr-job-execution-role",
TEST_CLUSTER_NAME));
Expand Down
Loading

0 comments on commit 26dab6c

Please sign in to comment.