Integration with REPL Spark job (#2327)
* add InteractiveSession and SessionManager

Signed-off-by: Peng Huo <penghuo@gmail.com>

* add statement

Signed-off-by: Peng Huo <penghuo@gmail.com>

* add statement

Signed-off-by: Peng Huo <penghuo@gmail.com>

* fix format

Signed-off-by: Peng Huo <penghuo@gmail.com>

* snapshot

Signed-off-by: Peng Huo <penghuo@gmail.com>

* address comments

Signed-off-by: Peng Huo <penghuo@gmail.com>

* update

Signed-off-by: Peng Huo <penghuo@gmail.com>

* Update REST and Transport interface

Signed-off-by: Peng Huo <penghuo@gmail.com>

* Revert on transport layer

Signed-off-by: Peng Huo <penghuo@gmail.com>

* format code

Signed-off-by: Peng Huo <penghuo@gmail.com>

* add API doc

Signed-off-by: Peng Huo <penghuo@gmail.com>

* modify api

Signed-off-by: Peng Huo <penghuo@gmail.com>

* create query_execution_request index on demand

Signed-off-by: Peng Huo <penghuo@gmail.com>

* add REPL spark parameters

Signed-off-by: Peng Huo <penghuo@gmail.com>

* Add IT

Signed-off-by: Peng Huo <penghuo@gmail.com>

* format code

Signed-off-by: Peng Huo <penghuo@gmail.com>

* bind request index to datasource

Signed-off-by: Peng Huo <penghuo@gmail.com>

* fix bug when fetch query result

Signed-off-by: Peng Huo <penghuo@gmail.com>

* revert entrypoint class

Signed-off-by: Peng Huo <penghuo@gmail.com>

* update mapping

Signed-off-by: Peng Huo <penghuo@gmail.com>

---------

Signed-off-by: Peng Huo <penghuo@gmail.com>
penghuo authored Oct 20, 2023
1 parent f835112 commit 7b4156e
Showing 22 changed files with 810 additions and 157 deletions.
@@ -7,7 +7,6 @@

 import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
 import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;

 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
 import com.amazonaws.services.emrserverless.AWSEMRServerless;
@@ -321,9 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             new FlintIndexMetadataReaderImpl(client),
             client,
             new SessionManager(
-                new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client),
-                emrServerlessClient,
-                pluginSettings));
+                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
1 change: 1 addition & 0 deletions spark/build.gradle
@@ -68,6 +68,7 @@ dependencies {
         because 'allows tests to run from IDEs that bundle older version of launcher'
     }
     testImplementation("org.opensearch.test:framework:${opensearch_version}")
+    testImplementation project(':opensearch')
 }

 test {
@@ -12,6 +12,7 @@
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI;
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.*;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;

 import java.net.URI;
 import java.net.URISyntaxException;
@@ -39,7 +40,7 @@ public class SparkSubmitParameters {

   public static class Builder {

-    private final String className;
+    private String className;
     private final Map<String, String> config;
     private String extraParameters;

@@ -70,6 +71,11 @@ public static Builder builder() {
       return new Builder();
     }

+    public Builder className(String className) {
+      this.className = className;
+      return this;
+    }
+
     public Builder dataSource(DataSourceMetadata metadata) {
       if (DataSourceType.S3GLUE.equals(metadata.getConnector())) {
         String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN);
@@ -141,6 +147,12 @@ public Builder extraParameters(String params) {
       return this;
     }

+    public Builder sessionExecution(String sessionId, String datasourceName) {
+      config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+      config.put(FLINT_JOB_SESSION_ID, sessionId);
+      return this;
+    }
+
     public SparkSubmitParameters build() {
       return new SparkSubmitParameters(className, config, extraParameters);
     }
@@ -87,4 +87,8 @@ public class SparkConstants {
   public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER =
       "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider";
   public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/";
+
+  public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex";
+  public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId";
+  public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL";
 }
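
Taken together, the new className()/sessionExecution() builder hooks and the constants above configure a REPL job. A hedged usage sketch (repo classes assumed on the classpath; the session id and datasource name are invented, and the exact rendering done by toString() is an assumption):

    // Illustrative only: configures a Flint REPL session job the way the dispatcher does below.
    String sparkSubmitArgs =
        SparkSubmitParameters.Builder.builder()
            .className(FLINT_SESSION_CLASS_NAME)       // org.apache.spark.sql.FlintREPL
            .sessionExecution("c2Vzc2lvbklk", "mys3")  // sets spark.flint.job.requestIndex and .sessionId
            .build()
            .toString();                               // presumably rendered as spark-submit --conf pairs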
@@ -7,6 +7,7 @@

 import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;

 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
@@ -96,12 +97,19 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata)
       return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result();
     }

-    // either empty json when the result is not available or data with status
-    // Fetch from Result Index
-    JSONObject result =
-        jobExecutionResponseReader.getResultFromOpensearchIndex(
-            asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
-
+    JSONObject result;
+    if (asyncQueryJobMetadata.getSessionId() == null) {
+      // either empty json when the result is not available or data with status
+      // Fetch from Result Index
+      result =
+          jobExecutionResponseReader.getResultFromOpensearchIndex(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    } else {
+      // when session enabled, jobId in asyncQueryJobMetadata is actually queryId.
+      result =
+          jobExecutionResponseReader.getResultWithQueryId(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    }
     // if result index document has a status, we are gonna use the status directly; otherwise, we
     // will use emr-s job status.
     // That a job is successful does not mean there is no error in execution. For example, even if
@@ -230,22 +238,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
     dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata);
     String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
     Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
-    StartJobRequest startJobRequest =
-        new StartJobRequest(
-            dispatchQueryRequest.getQuery(),
-            jobName,
-            dispatchQueryRequest.getApplicationId(),
-            dispatchQueryRequest.getExecutionRoleARN(),
-            SparkSubmitParameters.Builder.builder()
-                .dataSource(
-                    dataSourceService.getRawDataSourceMetadata(
-                        dispatchQueryRequest.getDatasource()))
-                .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
-                .build()
-                .toString(),
-            tags,
-            false,
-            dataSourceMetadata.getResultIndex());
+
     if (sessionManager.isEnabled()) {
       Session session;
       if (dispatchQueryRequest.getSessionId() != null) {
@@ -260,7 +253,19 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
         // create session if not exist
         session =
             sessionManager.createSession(
-                new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName()));
+                new CreateSessionRequest(
+                    jobName,
+                    dispatchQueryRequest.getApplicationId(),
+                    dispatchQueryRequest.getExecutionRoleARN(),
+                    SparkSubmitParameters.Builder.builder()
+                        .className(FLINT_SESSION_CLASS_NAME)
+                        .dataSource(
+                            dataSourceService.getRawDataSourceMetadata(
+                                dispatchQueryRequest.getDatasource()))
+                        .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()),
+                    tags,
+                    dataSourceMetadata.getResultIndex(),
+                    dataSourceMetadata.getName()));
       }
       StatementId statementId =
           session.submit(
@@ -272,6 +277,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
           dataSourceMetadata.getResultIndex(),
           session.getSessionId().getSessionId());
     } else {
+      StartJobRequest startJobRequest =
+          new StartJobRequest(
+              dispatchQueryRequest.getQuery(),
+              jobName,
+              dispatchQueryRequest.getApplicationId(),
+              dispatchQueryRequest.getExecutionRoleARN(),
+              SparkSubmitParameters.Builder.builder()
+                  .dataSource(
+                      dataSourceService.getRawDataSourceMetadata(
+                          dispatchQueryRequest.getDatasource()))
+                  .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
+                  .build()
+                  .toString(),
+              tags,
+              false,
+              dataSourceMetadata.getResultIndex());
       String jobId = emrServerlessClient.startJobRun(startJobRequest);
       return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null);
     }
@@ -5,11 +5,30 @@

 package org.opensearch.sql.spark.execution.session;

+import java.util.Map;
 import lombok.Data;
+import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.StartJobRequest;

 @Data
 public class CreateSessionRequest {
-  private final StartJobRequest startJobRequest;
+  private final String jobName;
+  private final String applicationId;
+  private final String executionRoleArn;
+  private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder;
+  private final Map<String, String> tags;
+  private final String resultIndex;
   private final String datasourceName;
+
+  public StartJobRequest getStartJobRequest() {
+    return new StartJobRequest(
+        "select 1",
+        jobName,
+        applicationId,
+        executionRoleArn,
+        sparkSubmitParametersBuilder.build().toString(),
+        tags,
+        false,
+        resultIndex);
+  }
 }
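
A note on the hardcoded "select 1": with the REPL entry point, real statements presumably arrive through the per-datasource request index, so the query handed to StartJobRequest acts as a placeholder. A hedged construction sketch (Lombok's @Data supplies the all-final-fields constructor; every argument value below is invented):

    CreateSessionRequest request =
        new CreateSessionRequest(
            "mycluster:non-index-query",                 // jobName
            "00fd775baqpu4g0p",                          // EMR Serverless applicationId (invented)
            "arn:aws:iam::123456789012:role/emr-job",    // executionRoleArn (invented)
            SparkSubmitParameters.Builder.builder().className(FLINT_SESSION_CLASS_NAME),
            Map.of("datasource", "mys3"),                // tags
            null,                                        // resultIndex (may be null)
            "mys3");                                     // datasourceName
    StartJobRequest startJobRequest = request.getStartJobRequest(); // carries the "select 1" placeholder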
@@ -42,13 +42,17 @@ public class InteractiveSession implements Session {
   @Override
   public void open(CreateSessionRequest createSessionRequest) {
     try {
+      // append session id;
+      createSessionRequest
+          .getSparkSubmitParametersBuilder()
+          .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
       String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
       String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();

       sessionModel =
           initInteractiveSession(
               applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
-      createSession(stateStore).apply(sessionModel);
+      createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel);
     } catch (VersionConflictEngineException e) {
       String errorMsg = "session already exist. " + sessionId;
       LOG.error(errorMsg);
@@ -59,7 +63,8 @@ public void open(CreateSessionRequest createSessionRequest) {
   /** todo. StatementSweeper will delete doc. */
   @Override
   public void close() {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -69,7 +74,8 @@ public void close() {

   /** Submit statement. If submit successfully, Statement in waiting state. */
   public StatementId submit(QueryRequest request) {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -84,6 +90,7 @@ public StatementId submit(QueryRequest request) {
               .stateStore(stateStore)
               .statementId(statementId)
               .langType(LangType.SQL)
+              .datasourceName(sessionModel.getDatasourceName())
               .query(request.getQuery())
               .queryId(statementId.getId())
               .build();
@@ -103,7 +110,7 @@ public StatementId submit(QueryRequest request) {

   @Override
   public Optional<Statement> get(StatementId stID) {
-    return StateStore.getStatement(stateStore)
+    return StateStore.getStatement(stateStore, sessionModel.getDatasourceName())
         .apply(stID.getId())
         .map(
             model ->
@@ -5,15 +5,32 @@

 package org.opensearch.sql.spark.execution.session;

+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
 import lombok.Data;
 import org.apache.commons.lang3.RandomStringUtils;

 @Data
 public class SessionId {
+  public static final int PREFIX_LEN = 10;
+
   private final String sessionId;

-  public static SessionId newSessionId() {
-    return new SessionId(RandomStringUtils.randomAlphanumeric(16));
+  public static SessionId newSessionId(String datasourceName) {
+    return new SessionId(encode(datasourceName));
+  }
+
+  public String getDataSourceName() {
+    return decode(sessionId);
+  }
+
+  private static String decode(String sessionId) {
+    return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN);
+  }
+
+  private static String encode(String datasourceName) {
+    String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName;
+    return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
   }

   @Override
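
For intuition about the id format, here is a self-contained sketch of the same encode/decode scheme (SecureRandom stands in for commons-lang3's RandomStringUtils; otherwise this mirrors the methods above):

    import java.nio.charset.StandardCharsets;
    import java.security.SecureRandom;
    import java.util.Base64;

    public class SessionIdCodecDemo {
      static final int PREFIX_LEN = 10;
      static final String ALPHANUM =
          "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

      // Base64(10 random chars + datasource name): the datasource survives the round trip.
      static String encode(String datasourceName) {
        SecureRandom rnd = new SecureRandom();
        StringBuilder prefix = new StringBuilder();
        for (int i = 0; i < PREFIX_LEN; i++) {
          prefix.append(ALPHANUM.charAt(rnd.nextInt(ALPHANUM.length())));
        }
        return Base64.getEncoder()
            .encodeToString((prefix + datasourceName).getBytes(StandardCharsets.UTF_8));
      }

      // Strip the fixed-length random prefix to recover the datasource name.
      static String decode(String sessionId) {
        return new String(Base64.getDecoder().decode(sessionId), StandardCharsets.UTF_8)
            .substring(PREFIX_LEN);
      }

      public static void main(String[] args) {
        String sid = encode("mys3");
        System.out.println(sid + " -> " + decode(sid)); // ... -> mys3
      }
    }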
@@ -28,7 +28,7 @@ public class SessionManager {
   public Session createSession(CreateSessionRequest request) {
     InteractiveSession session =
         InteractiveSession.builder()
-            .sessionId(newSessionId())
+            .sessionId(newSessionId(request.getDatasourceName()))
             .stateStore(stateStore)
             .serverlessClient(emrServerlessClient)
             .build();
@@ -37,7 +37,8 @@ public Session createSession(CreateSessionRequest request) {
   }

   public Optional<Session> getSession(SessionId sid) {
-    Optional<SessionModel> model = StateStore.getSession(stateStore).apply(sid.getSessionId());
+    Optional<SessionModel> model =
+        StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId());
     if (model.isPresent()) {
       InteractiveSession session =
           InteractiveSession.builder()
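
An end-to-end sketch of the session lifecycle these classes add (class and method names come from this diff; the client, settings, and request objects, plus the behavioral comments, are assumptions based on the commit message):

    // Illustrative flow only.
    SessionManager sessionManager =
        new SessionManager(
            new StateStore(client, clusterService), emrServerlessClient, pluginSettings);

    // Starts a FlintREPL job on EMR Serverless and persists the session state.
    Session session = sessionManager.createSession(createSessionRequest);

    // Statements land in the per-datasource request index, starting in waiting state.
    StatementId statementId = session.submit(queryRequest);
    Optional<Statement> statement = session.get(statementId);

    // SessionId embeds the datasource name, so a bare id is enough to locate the right index.
    Optional<Session> restored = sessionManager.getSession(session.getSessionId());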
@@ -8,6 +8,7 @@
 import com.google.common.collect.ImmutableList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.stream.Collectors;
 import lombok.Getter;
@@ -32,8 +33,10 @@ public enum SessionState {
           .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));

   public static SessionState fromString(String key) {
-    if (STATES.containsKey(key)) {
-      return STATES.get(key);
+    for (SessionState ss : SessionState.values()) {
+      if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) {
+        return ss;
+      }
     }
     throw new IllegalArgumentException("Invalid session state: " + key);
   }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

package org.opensearch.sql.spark.execution.session;

import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.Locale;
import lombok.Getter;

@Getter
Expand All @@ -20,13 +18,11 @@ public enum SessionType {
this.sessionType = sessionType;
}

private static Map<String, SessionType> TYPES =
Arrays.stream(SessionType.values())
.collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));

public static SessionType fromString(String key) {
if (TYPES.containsKey(key)) {
return TYPES.get(key);
for (SessionType sType : SessionType.values()) {
if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) {
return sType;
}
}
throw new IllegalArgumentException("Invalid session type: " + key);
}
Expand Down
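
Both enums now do a linear scan over values() instead of consulting a cached lowercase map; note the key is still expected in lowercase. A minimal standalone version of the pattern (the enum itself is invented for illustration):

    import java.util.Locale;

    public enum Color {
      RED("red"),
      GREEN("green");

      private final String label;

      Color(String label) {
        this.label = label;
      }

      // Same shape as SessionState/SessionType.fromString: scan the values,
      // compare each lowercased label to the key, throw if nothing matches.
      public static Color fromString(String key) {
        for (Color c : Color.values()) {
          if (c.label.toLowerCase(Locale.ROOT).equals(key)) {
            return c;
          }
        }
        throw new IllegalArgumentException("Invalid color: " + key);
      }
    }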