From 78acb1a43bd82f1036d8bcf17bf653564bc0557d Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Thu, 12 Oct 2023 13:34:21 -0700 Subject: [PATCH 01/20] add InteractiveSession and SessionManager Signed-off-by: Peng Huo --- spark/build.gradle | 39 +++- .../session/CreateSessionRequest.java | 15 ++ .../execution/session/InteractiveSession.java | 62 +++++ .../sql/spark/execution/session/Session.java | 19 ++ .../spark/execution/session/SessionId.java | 23 ++ .../execution/session/SessionManager.java | 51 +++++ .../spark/execution/session/SessionModel.java | 143 ++++++++++++ .../spark/execution/session/SessionState.java | 36 +++ .../spark/execution/session/SessionType.java | 33 +++ .../statestore/SessionStateStore.java | 87 ++++++++ .../session/InteractiveSessionTest.java | 211 ++++++++++++++++++ .../execution/session/SessionManagerTest.java | 38 ++++ .../execution/session/SessionStateTest.java | 20 ++ .../execution/session/SessionTypeTest.java | 20 ++ .../statestore/SessionStateStoreTest.java | 42 ++++ 15 files changed, 834 insertions(+), 5 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java diff --git a/spark/build.gradle b/spark/build.gradle index c06b5b6ecf..c2c925ecaf 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -52,15 +52,38 @@ dependencies { api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' - testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation(platform("org.junit:junit-bom:5.6.2")) + + testImplementation('org.junit.jupiter:junit-jupiter') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' - testImplementation 'junit:junit:4.13.1' - testImplementation "org.opensearch.test:framework:${opensearch_version}" + + testCompileOnly('junit:junit:4.13.1') { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + testRuntimeOnly("org.junit.vintage:junit-vintage-engine") { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + 
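+    // Note: JUnit 4 sources compile against junit:junit above but execute on the JUnit
+    // Platform through the vintage engine (see the junit4 task below).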
testRuntimeOnly("org.junit.platform:junit-platform-launcher") { + because 'allows tests to run from IDEs that bundle older version of launcher' + } + testImplementation("org.opensearch.test:framework:${opensearch_version}") } test { - useJUnitPlatform() + useJUnitPlatform { + includeEngines("junit-jupiter") + } + testLogging { + events "failed" + exceptionFormat "full" + } +} +task junit4(type: Test) { + useJUnitPlatform { + includeEngines("junit-vintage") + } + systemProperty 'tests.security.manager', 'false' testLogging { events "failed" exceptionFormat "full" @@ -68,6 +91,8 @@ test { } jacocoTestReport { + dependsOn test, junit4 + executionData test, junit4 reports { html.enabled true xml.enabled true @@ -78,9 +103,10 @@ jacocoTestReport { })) } } -test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { + dependsOn test, junit4 + executionData test, junit4 violationRules { rule { element = 'CLASS' @@ -92,6 +118,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.asyncquery.exceptions.*', 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', + // ignore because XContext IOException + 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java new file mode 100644 index 0000000000..17e3346248 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.opensearch.sql.spark.client.StartJobRequest; + +@Data +public class CreateSessionRequest { + private final StartJobRequest startJobRequest; + private final String datasourceName; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java new file mode 100644 index 0000000000..2898f4b87b --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; + +import java.util.Optional; +import lombok.Builder; +import lombok.Getter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Interactive session. + * + *
ENTRY_STATE: not_started + */ +@Getter +@Builder +public class InteractiveSession implements Session { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final SessionStateStore sessionStateStore; + private final EMRServerlessClient serverlessClient; + private final CreateSessionRequest createSessionRequest; + + private SessionModel sessionModel; + + @Override + public void open() { + try { + String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); + String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); + + sessionModel = + initInteractiveSession( + applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); + sessionStateStore.create(sessionModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "session already exist. " + sessionId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + @Override + public void close() { + Optional model = sessionStateStore.get(sessionModel.getSessionId()); + if (model.isEmpty()) { + throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + } else { + serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java new file mode 100644 index 0000000000..449a9af538 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +/** Session define the statement execution context. Each session is binding to one Spark Job. */ +public interface Session { + /** open session. */ + void open(); + + /** close session. 
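+   *
+   * <p>A minimal lifecycle sketch (hypothetical caller code, not part of this change):
+   *
+   * <pre>{@code
+   * Session session = sessionManager.createSession(createSessionRequest);
+   * // ... use the session ...
+   * session.close(); // cancels the underlying EMR-S job run
+   * }</pre>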
*/ + void close(); + + SessionModel getSessionModel(); + + SessionId getSessionId(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java new file mode 100644 index 0000000000..a2847cde18 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class SessionId { + private final String sessionId; + + public static SessionId newSessionId() { + return new SessionId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "sessionId=" + sessionId; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java new file mode 100644 index 0000000000..2166c91568 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Singleton Class + * + *
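+ * <p>Entry point for creating a new {@code Session} and for looking an existing one up by id.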
todo. add Session cache and Session sweeper. + */ +@RequiredArgsConstructor +public class SessionManager { + private final SessionStateStore stateStore; + private final EMRServerlessClient emrServerlessClient; + + public Session createSession(CreateSessionRequest request) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .createSessionRequest(request) + .build(); + session.open(); + return session; + } + + public Optional getSession(SessionId sid) { + Optional model = stateStore.get(sid); + if (model.isPresent()) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sid) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .sessionModel(model.get()) + .build(); + return Optional.ofNullable(session); + } + return Optional.empty(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java new file mode 100644 index 0000000000..656f0ec8ce --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; + +/** Session data in flint.ql.sessions index. 
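+ *
+ * <p>Shape of a stored session document (field names match {@code toXContent}; the values
+ * below are illustrative only):
+ *
+ * <pre>{@code
+ * {
+ *   "version": "1.0",
+ *   "type": "session",
+ *   "sessionType": "interactive",
+ *   "sessionId": "A1b2C3d4E5",
+ *   "state": "not_started",
+ *   "dataSourceName": "mys3",
+ *   "applicationId": "appId",
+ *   "jobId": "jobId",
+ *   "lastUpdateTime": 1697142861000,
+ *   "error": "unknown"
+ * }
+ * }</pre>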
*/ +@Data +@Builder +public class SessionModel implements ToXContentObject { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String SESSION_TYPE = "sessionType"; + public static final String SESSION_ID = "sessionId"; + public static final String SESSION_STATE = "state"; + public static final String DATASOURCE_NAME = "dataSourceName"; + public static final String LAST_UPDATE_TIME = "lastUpdateTime"; + public static final String APPLICATION_ID = "applicationId"; + public static final String JOB_ID = "jobId"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String SESSION_DOC_TYPE = "session"; + + private final String version; + private final SessionType sessionType; + private final SessionId sessionId; + private final SessionState sessionState; + private final String applicationId; + private final String jobId; + private final String datasourceName; + private final String error; + private final long lastUpdateTime; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, SESSION_DOC_TYPE) + .field(SESSION_TYPE, sessionType.getSessionType()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(SESSION_STATE, sessionState.getSessionState()) + .field(DATASOURCE_NAME, datasourceName) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LAST_UPDATE_TIME, lastUpdateTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(copy.sessionState) + .datasourceName(copy.datasourceName) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + SessionModelBuilder builder = new SessionModelBuilder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case SESSION_TYPE: + builder.sessionType(SessionType.fromString(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case SESSION_STATE: + builder.sessionState(SessionState.fromString(parser.text())); + break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; + case ERROR: + builder.error(parser.text()); + break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LAST_UPDATE_TIME: + builder.lastUpdateTime(parser.longValue()); + break; + case TYPE: + // do nothing. 
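+          // "type" is the constant document-type discriminator ("session");
+          // there is nothing to copy onto the builder.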
+ break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static SessionModel initInteractiveSession( + String applicationId, String jobId, SessionId sid, String datasourceName) { + return builder() + .version("1.0") + .sessionType(INTERACTIVE) + .sessionId(sid) + .sessionState(NOT_STARTED) + .datasourceName(datasourceName) + .applicationId(applicationId) + .jobId(jobId) + .error(UNKNOWN) + .lastUpdateTime(System.currentTimeMillis()) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java new file mode 100644 index 0000000000..509d5105e9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionState { + NOT_STARTED("not_started"), + RUNNING("running"), + DEAD("dead"), + FAIL("fail"); + + private final String sessionState; + + SessionState(String sessionState) { + this.sessionState = sessionState; + } + + private static Map STATES = + Arrays.stream(SessionState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid session state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java new file mode 100644 index 0000000000..dd179a1dc5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionType { + INTERACTIVE("interactive"); + + private final String sessionType; + + SessionType(String sessionType) { + this.sessionType = sessionType; + } + + private static Map TYPES = + Arrays.stream(SessionType.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionType fromString(String key) { + if (TYPES.containsKey(key)) { + return TYPES.get(key); + } + throw new IllegalArgumentException("Invalid session type: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java new file mode 100644 index 0000000000..6ddce55360 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import 
org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@RequiredArgsConstructor +public class SessionStateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + public SessionModel create(SessionModel session) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(session.getSessionId().getSessionId()) + .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(session.getSeqNo()) + .setIfPrimaryTerm(session.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", session.getSessionId()); + return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. 
id: %s, error: %s", + session.getSessionId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public Optional get(SessionId sid) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + SessionModel.fromXContent( + parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java new file mode 100644 index 0000000000..3ff547157c --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +/** mock-maker-inline does not work with OpenSearchTestCase. 
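+ *
+ * <p>These tests therefore run against a real single-node cluster and use the hand-written
+ * {@code TestEMRServerlessClient} fake below instead of Mockito inline mocks.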
*/ +public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private TestEMRServerlessClient emrsClient; + private StartJobRequest startJobRequest; + private SessionStateStore stateStore; + + @Before + public void setup() { + emrsClient = new TestEMRServerlessClient(); + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openCloseSession() { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(SessionId.newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) + .build(); + + // open session + TestSession testSession = testSession(session, stateStore); + testSession.open().assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + emrsClient.startJobRunCalled(1); + + // close session + testSession.close(); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void openSessionFailedConflict() { + SessionId sessionId = new SessionId("duplicate-session-id"); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) + .build(); + session.open(); + + InteractiveSession duplicateSession = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) + .build(); + IllegalStateException exception = + assertThrows(IllegalStateException.class, duplicateSession::open); + assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage()); + } + + @Test + public void closeNotExistSession() { + SessionId sessionId = SessionId.newSessionId(); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) + .build(); + session.open(); + + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); + assertEquals("session not exist. 
" + sessionId, exception.getMessage()); + emrsClient.cancelJobRunCalled(0); + } + + @Test + public void sessionManagerCreateSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + TestSession testSession = testSession(session, stateStore); + testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + } + + @Test + public void sessionManagerGetSession() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + Session session = + sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + Optional managerSession = sessionManager.getSession(session.getSessionId()); + assertTrue(managerSession.isPresent()); + assertEquals(session.getSessionId(), managerSession.get().getSessionId()); + } + + @Test + public void sessionManagerGetSessionNotExist() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + + Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); + assertTrue(managerSession.isEmpty()); + } + + @RequiredArgsConstructor + static class TestSession { + private final Session session; + private final SessionStateStore stateStore; + + public static TestSession testSession(Session session, SessionStateStore stateStore) { + return new TestSession(session, stateStore); + } + + public TestSession assertSessionState(SessionState expected) { + assertEquals(expected, session.getSessionModel().getSessionState()); + + Optional sessionStoreState = + stateStore.get(session.getSessionModel().getSessionId()); + assertTrue(sessionStoreState.isPresent()); + assertEquals(expected, sessionStoreState.get().getSessionState()); + + return this; + } + + public TestSession assertAppId(String expected) { + assertEquals(expected, session.getSessionModel().getApplicationId()); + return this; + } + + public TestSession assertJobId(String expected) { + assertEquals(expected, session.getSessionModel().getJobId()); + return this; + } + + public TestSession open() { + session.open(); + return this; + } + + public TestSession close() { + session.close(); + return this; + } + } + + static class TestEMRServerlessClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + return null; + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return null; + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java new file mode 100644 index 0000000000..d35105f787 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.After; 
+import org.junit.Before; +import org.mockito.MockMakers; +import org.mockito.MockSettings; +import org.mockito.Mockito; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +class SessionManagerTest extends OpenSearchSingleNodeTestCase { + private static final String indexName = "mockindex"; + + // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. + private static final MockSettings mockSettings = + Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); + + private SessionStateStore stateStore; + + @Before + public void setup() { + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java new file mode 100644 index 0000000000..a987c80d59 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionStateTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionState.fromString("invalid")); + assertEquals("Invalid session state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java new file mode 100644 index 0000000000..a2ab43e709 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionTypeTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionType.fromString("invalid")); + assertEquals("Invalid session type: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java new file mode 100644 index 0000000000..9c779555d7 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import 
org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@ExtendWith(MockitoExtension.class) +class SessionStateStoreTest { + @Mock(answer = RETURNS_DEEP_STUBS) + private Client client; + + @Mock private IndexResponse indexResponse; + + @Test + public void createWithException() { + when(client.index(any()).actionGet()).thenReturn(indexResponse); + doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); + SessionModel sessionModel = + SessionModel.initInteractiveSession( + "appId", "jobId", SessionId.newSessionId(), "datasource"); + SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); + + assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); + } +} From b7b3d77153b903a09b5149f88a240fe73e166b76 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 13 Oct 2023 14:51:58 -0700 Subject: [PATCH 02/20] add statement Signed-off-by: Peng Huo --- spark/build.gradle | 5 +- .../execution/session/InteractiveSession.java | 67 +++- .../sql/spark/execution/session/Session.java | 21 ++ .../execution/session/SessionManager.java | 10 +- .../spark/execution/session/SessionModel.java | 30 +- .../execution/statement/QueryRequest.java | 15 + .../spark/execution/statement/Statement.java | 82 +++++ .../execution/statement/StatementId.java | 23 ++ .../execution/statement/StatementModel.java | 170 ++++++++++ .../execution/statement/StatementState.java | 37 +++ .../statestore/SessionStateStore.java | 87 ----- .../execution/statestore/StateModel.java | 30 ++ .../execution/statestore/StateStore.java | 149 +++++++++ .../session/InteractiveSessionTest.java | 27 +- .../execution/session/SessionManagerTest.java | 8 +- .../statement/StatementStateTest.java | 20 ++ .../execution/statement/StatementTest.java | 296 ++++++++++++++++++ .../statestore/SessionStateStoreTest.java | 42 --- 18 files changed, 958 insertions(+), 161 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java diff --git a/spark/build.gradle b/spark/build.gradle index c2c925ecaf..d8bb08657d 100644 --- a/spark/build.gradle +++ 
b/spark/build.gradle @@ -119,8 +119,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', // ignore because XContext IOException - 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', - 'org.opensearch.sql.spark.execution.session.SessionModel' + 'org.opensearch.sql.spark.execution.statestore.StateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel', + 'org.opensearch.sql.spark.execution.statement.StatementModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 2898f4b87b..dd6cb1160d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,6 +6,9 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -14,7 +17,11 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; /** * Interactive session. @@ -27,7 +34,7 @@ public class InteractiveSession implements Session { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; - private final SessionStateStore sessionStateStore; + private final StateStore stateStore; private final EMRServerlessClient serverlessClient; private final CreateSessionRequest createSessionRequest; @@ -42,7 +49,7 @@ public void open() { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - sessionStateStore.create(sessionModel); + createSession(stateStore).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -50,13 +57,63 @@ public void open() { } } + /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = sessionStateStore.get(sessionModel.getSessionId()); + Optional model = getSession(stateStore).apply(sessionModel.getId()); if (model.isEmpty()) { - throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); } } + + /** Submit statement. If submit successfully, Statement in waiting state. 
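+   *
+   * <p>Submission is only accepted while the session is in running state; otherwise an
+   * {@link IllegalStateException} is thrown. A newly submitted statement starts in waiting
+   * state.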
*/ + public StatementId submit(QueryRequest request) { + Optional model = getSession(stateStore).apply(sessionModel.getId()); + if (model.isEmpty()) { + throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); + } else { + sessionModel = model.get(); + if (sessionModel.getSessionState() == SessionState.RUNNING) { + StatementId statementId = newStatementId(); + Statement st = + Statement.builder() + .sessionId(sessionId) + .stateStore(stateStore) + .statementId(statementId) + .langType(LangType.SQL) + .query(request.getQuery()) + .queryId(statementId.getId()) + .build(); + st.open(); + return statementId; + } else { + String errMsg = + String.format( + "can't submit statement, session should in running state, " + + "current session state is: %s", + sessionModel.getSessionState().getSessionState()); + LOG.debug(errMsg); + throw new IllegalStateException(errMsg); + } + } + } + + @Override + public Optional get(StatementId stID) { + return StateStore.getStatement(stateStore) + .apply(stID.getId()) + .map( + model -> + Statement.builder() + .sessionId(sessionId) + .statementId(model.getStatementId()) + .langType(model.getLangType()) + .query(model.getQuery()) + .queryId(model.getQueryId()) + .stateStore(stateStore) + .statementModel(model) + .build()); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index 449a9af538..752055e119 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -5,6 +5,11 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Optional; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; + /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ @@ -13,6 +18,22 @@ public interface Session { /** close session. */ void close(); + /** + * submit {@link QueryRequest}. + * + * @param request {@link QueryRequest} + * @return {@link StatementId} + */ + StatementId submit(QueryRequest request); + + /** + * get {@link Statement}. 
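+   * Returns an empty {@code Optional} when no statement with the given id exists.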
+ * + * @param stID {@link StatementId} + * @return {@link Statement} + */ + Optional get(StatementId stID); + SessionModel getSessionModel(); SessionId getSessionId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 2166c91568..5ed510f367 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; /** * Singleton Class @@ -19,14 +19,14 @@ */ @RequiredArgsConstructor public class SessionManager { - private final SessionStateStore stateStore; + private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .createSessionRequest(request) .build(); @@ -35,12 +35,12 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = stateStore.get(sid); + Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .sessionModel(model.get()) .build(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java index 656f0ec8ce..806cdb083e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -12,16 +12,16 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.statestore.StateModel; /** Session data in flint.ql.sessions index. 
*/ @Data @Builder -public class SessionModel implements ToXContentObject { +public class SessionModel extends StateModel { public static final String VERSION = "version"; public static final String TYPE = "type"; public static final String SESSION_TYPE = "sessionType"; @@ -73,6 +73,27 @@ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { .sessionId(new SessionId(copy.sessionId.getSessionId())) .sessionState(copy.sessionState) .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static SessionModel copyWithState( + SessionModel copy, SessionState state, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(state) + .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) .seqNo(seqNo) .primaryTerm(primaryTerm) .build(); @@ -140,4 +161,9 @@ public static SessionModel initInteractiveSession( .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) .build(); } + + @Override + public String getId() { + return sessionId.getSessionId(); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java new file mode 100644 index 0000000000..10061404ca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.opensearch.sql.spark.rest.model.LangType; + +@Data +public class QueryRequest { + private final LangType langType; + private final String query; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java new file mode 100644 index 0000000000..4c54393379 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement represent query to execute in session. One statement map to one session. 
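+ *
+ * <p>Statement states: waiting, running, success, failed, cancelled (see {@code StatementState}).
+ * A statement is created in waiting state and persisted with optimistic concurrency
+ * (seqNo/primaryTerm), so a conflicting update fails with a version conflict rather than being
+ * silently overwritten.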
*/ +@Getter +@Builder +public class Statement { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final StatementId statementId; + private final LangType langType; + private final String query; + private final String queryId; + private final StateStore stateStore; + + @Setter private StatementModel statementModel; + + /** Open a statement. */ + public void open() { + try { + statementModel = submitStatement(sessionId, statementId, langType, query, queryId); + statementModel = createStatement(stateStore).apply(statementModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "statement already exist. " + statementId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + /** Cancel a statement. */ + public void cancel() { + if (statementModel.getStatementState().equals(StatementState.RUNNING)) { + String errorMsg = + String.format("can't cancel statement in waiting state. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + try { + this.statementModel = + updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); + } catch (DocumentMissingException e) { + String errorMsg = + String.format("cancel statement failed. no statement found. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } catch (VersionConflictEngineException e) { + this.statementModel = + getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); + String errorMsg = + String.format( + "cancel statement failed. current statementState: %s " + "statement: %s.", + this.statementModel.getStatementState(), statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + public StatementState getStatementState() { + return statementModel.getStatementState(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java new file mode 100644 index 0000000000..4baff71493 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class StatementId { + private final String id; + + public static StatementId newStatementId() { + return new StatementId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "statementId=" + id; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java new file mode 100644 index 0000000000..b57868964e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -0,0 +1,170 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.XContentBuilder; +import 
org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateModel; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement data in flint.ql.sessions index. */ +@Data +@Builder +public class StatementModel extends StateModel { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String STATEMENT_STATE = "state"; + public static final String STATEMENT_ID = "statementId"; + public static final String SESSION_ID = "sessionId"; + public static final String LANG = "lang"; + public static final String QUERY = "query"; + public static final String QUERY_ID = "queryId"; + public static final String SUBMIT_TIME = "submitTime"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String STATEMENT_DOC_TYPE = "statement"; + + private final String version; + private final StatementState statementState; + private final StatementId statementId; + private final SessionId sessionId; + private final LangType langType; + private final String query; + private final String queryId; + private final long submitTime; + private final String error; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, STATEMENT_DOC_TYPE) + .field(STATEMENT_STATE, statementState.getState()) + .field(STATEMENT_ID, statementId.getId()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(LANG, langType.getText()) + .field(QUERY, query) + .field(QUERY_ID, queryId) + .field(SUBMIT_TIME, submitTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(copy.statementState) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static StatementModel copyWithState( + StatementModel copy, StatementState state, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(state) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static StatementModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + StatementModel.StatementModelBuilder builder = StatementModel.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case TYPE: + // do nothing + break; + case STATEMENT_STATE: + builder.statementState(StatementState.fromString(parser.text())); + break; + case STATEMENT_ID: + builder.statementId(new 
StatementId(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case LANG: + builder.langType(LangType.fromString(parser.text())); + break; + case QUERY: + builder.query(parser.text()); + break; + case QUERY_ID: + builder.queryId(parser.text()); + break; + case SUBMIT_TIME: + builder.submitTime(parser.longValue()); + break; + case ERROR: + builder.error(parser.text()); + break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static StatementModel submitStatement( + SessionId sid, StatementId statementId, LangType langType, String query, String queryId) { + return builder() + .version("1.0") + .statementState(WAITING) + .statementId(statementId) + .sessionId(sid) + .langType(langType) + .query(query) + .queryId(queryId) + .submitTime(System.currentTimeMillis()) + .error(UNKNOWN) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } + + @Override + public String getId() { + return statementId.getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java new file mode 100644 index 0000000000..87ad6b11ae --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum StatementState { + WAITING("waiting"), + RUNNING("running"), + SUCCESS("success"), + FAILED("failed"), + CANCELLED("cancelled"); + + private final String state; + + StatementState(String state) { + this.state = state; + } + + private static Map STATES = + Arrays.stream(StatementState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static StatementState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid statement state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java deleted file mode 100644 index 6ddce55360..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import 
org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@RequiredArgsConstructor -public class SessionStateStore { - private static final Logger LOG = LogManager.getLogger(); - - private final String indexName; - private final Client client; - - public SessionModel create(SessionModel session) { - try { - IndexRequest indexRequest = - new IndexRequest(indexName) - .id(session.getSessionId().getSessionId()) - .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .setIfSeqNo(session.getSeqNo()) - .setIfPrimaryTerm(session.getPrimaryTerm()) - .create(true) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", session.getSessionId()); - return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - session.getSessionId(), - indexResponse.getResult().getLowercase())); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public Optional get(SessionId sid) { - try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - SessionModel.fromXContent( - parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { - return Optional.empty(); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java new file mode 100644 index 0000000000..b5bf31a6ba --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; + +public abstract class StateModel implements ToXContentObject { + + public abstract String getId(); + + public abstract long getSeqNo(); + + public abstract long getPrimaryTerm(); + + public interface CopyBuilder { + T of(T copy, long seqNo, long primaryTerm); + } + + public interface StateCopyBuilder { + T of(T copy, S state, long seqNo, long primaryTerm); + } + + public interface FromXContent { + T fromXContent(XContentParser parser, long seqNo, long primaryTerm); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java new file mode 100644 index 0000000000..bd72b17353 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -0,0 +1,149 @@ +/* + 
* Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; + +@RequiredArgsConstructor +public class StateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + protected T create(T st, StateModel.CopyBuilder builder) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(st.getId()) + .source(st.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(st.getSeqNo()) + .setIfPrimaryTerm(st.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. 
id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Optional get(String sid, StateModel.FromXContent builder) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected T updateState( + T st, S state, StateModel.StateCopyBuilder builder) { + try { + T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); + UpdateRequest updateRequest = + new UpdateRequest() + .index(indexName) + .id(model.getId()) + .setIfSeqNo(model.getSeqNo()) + .setIfPrimaryTerm(model.getPrimaryTerm()) + .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .fetchSource(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. id: %s, error: %s", + st.getId(), + updateResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** Helper Functions */ + public static Function createStatement(StateStore stateStore) { + return (st) -> stateStore.create(st, StatementModel::copy); + } + + public static Function> getStatement(StateStore stateStore) { + return (docId) -> stateStore.get(docId, StatementModel::fromXContent); + } + + public static BiFunction updateStatementState( + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); + } + + public static Function createSession(StateStore stateStore) { + return (session) -> stateStore.create(session, SessionModel::of); + } + + public static Function> getSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent); + } + + public static BiFunction updateSessionState( + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 3ff547157c..0652a2d032 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import 
com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -20,7 +21,7 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. */ @@ -30,13 +31,13 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } @@ -50,7 +51,7 @@ public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(SessionId.newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); @@ -71,7 +72,7 @@ public void openSessionFailedConflict() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); @@ -80,7 +81,7 @@ public void openSessionFailedConflict() { InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); @@ -95,16 +96,16 @@ public void closeNotExistSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); session.open(); - client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); - assertEquals("session not exist. " + sessionId, exception.getMessage()); + assertEquals("session does not exist. 
" + sessionId, exception.getMessage()); emrsClient.cancelJobRunCalled(0); } @@ -140,9 +141,9 @@ public void sessionManagerGetSessionNotExist() { @RequiredArgsConstructor static class TestSession { private final Session session; - private final SessionStateStore stateStore; + private final StateStore stateStore; - public static TestSession testSession(Session session, SessionStateStore stateStore) { + public static TestSession testSession(Session session, StateStore stateStore) { return new TestSession(session, stateStore); } @@ -150,7 +151,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - stateStore.get(session.getSessionModel().getSessionId()); + getSession(stateStore).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -178,7 +179,7 @@ public TestSession close() { } } - static class TestEMRServerlessClient implements EMRServerlessClient { + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; private int cancelJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index d35105f787..6028de6cfb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,15 +5,13 @@ package org.opensearch.sql.spark.execution.session; -import static org.junit.jupiter.api.Assertions.*; - import org.junit.After; import org.junit.Before; import org.mockito.MockMakers; import org.mockito.MockSettings; import org.mockito.Mockito; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; class SessionManagerTest extends OpenSearchSingleNodeTestCase { @@ -23,11 +21,11 @@ class SessionManagerTest extends OpenSearchSingleNodeTestCase { private static final MockSettings mockSettings = Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java new file mode 100644 index 0000000000..b7af1123ba --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class StatementStateTest { + @Test + public void invalidStatementState() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> StatementState.fromString("invalid")); + Assertions.assertEquals("Invalid statement 
state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java new file mode 100644 index 0000000000..b0bc84219b --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -0,0 +1,296 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; +import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +public class StatementTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private StartJobRequest startJobRequest; + private StateStore stateStore; + private InteractiveSessionTest.TestEMRServerlessClient emrsClient = + new InteractiveSessionTest.TestEMRServerlessClient(); + + @Before + public void setup() { + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new StateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openThenCancelStatement() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + + @Test + public void openFailedBecauseConflict() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + 
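+    // Context for the calls below: the first open() succeeds and indexes the
+    // statement document; the duplicate open() later in this test must fail
+    // because StateStore.create() indexes with create(true) and the model's
+    // seqNo/primaryTerm, so a second create under the same doc id is rejected
+    // by the engine as a version conflict. Statement.open() (not shown in this
+    // hunk) is assumed to surface that conflict as the IllegalStateException
+    // asserted at the end of the test.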
st.open(); + + // open statement with same statement id + Statement dupSt = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); + assertEquals("statement already exist. statementId=statementId", exception.getMessage()); + } + + @Test + public void cancelNotExistStatement() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + client().delete(new DeleteRequest(indexName, stId.getId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("cancel statement failed. no statement found. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedBecauseOfConflict() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + StatementModel running = + updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); + + assertEquals(StatementState.CANCELLED, running.getStatementState()); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format( + "cancel statement failed. current statementState: CANCELLED " + "statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelRunningStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // update to running state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.RUNNING, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in waiting state. 
statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void submitStatementInRunningSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInStartingSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should in running state, current session state is:" + + " not_started", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInDeletedSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // other's delete session + client() + .delete(new DeleteRequest(indexName, session.getSessionId().getSessionId())) + .actionGet(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); + } + + @Test + public void getStatementSuccess() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + + Optional statement = session.get(statementId); + assertTrue(statement.isPresent()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(statementId, statement.get().getStatementId()); + } + + @Test + public void getStatementNotExist() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + Optional statement = session.get(StatementId.newStatementId()); + assertFalse(statement.isPresent()); + } + + @RequiredArgsConstructor + static class TestStatement { + private final Statement st; + private final StateStore stateStore; + + public static TestStatement testStatement(Statement st, StateStore stateStore) { + return new TestStatement(st, stateStore); + } + + public TestStatement assertSessionState(StatementState expected) { + assertEquals(expected, st.getStatementModel().getStatementState()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementState()); + + return this; + } + + public TestStatement assertStatementId(StatementId expected) { + assertEquals(expected, st.getStatementModel().getStatementId()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, 
model.get().getStatementId()); + return this; + } + + public TestStatement open() { + st.open(); + return this; + } + + public TestStatement cancel() { + st.cancel(); + return this; + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java deleted file mode 100644 index 9c779555d7..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import static org.junit.Assert.assertThrows; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.when; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.client.Client; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@ExtendWith(MockitoExtension.class) -class SessionStateStoreTest { - @Mock(answer = RETURNS_DEEP_STUBS) - private Client client; - - @Mock private IndexResponse indexResponse; - - @Test - public void createWithException() { - when(client.index(any()).actionGet()).thenReturn(indexResponse); - doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); - SessionModel sessionModel = - SessionModel.initInteractiveSession( - "appId", "jobId", SessionId.newSessionId(), "datasource"); - SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); - - assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); - } -} From aab6651f3cd425752480349c7941799c769de217 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 13 Oct 2023 15:46:30 -0700 Subject: [PATCH 03/20] add statement Signed-off-by: Peng Huo --- .../service/DataSourceServiceImpl.java | 9 ++- .../service/DataSourceServiceImplTest.java | 67 ++++++++++++++++++- .../sql/datasource/DataSourceAPIsIT.java | 8 +++ .../opensearch-sql.release-notes-2.11.0.0.md | 55 +++++++++++++++ .../execution/session/InteractiveSession.java | 4 +- .../sql/spark/execution/session/Session.java | 2 +- .../execution/session/SessionManager.java | 3 +- .../session/InteractiveSessionTest.java | 16 ++--- .../execution/session/SessionManagerTest.java | 7 -- 9 files changed, 144 insertions(+), 27 deletions(-) create mode 100644 release-notes/opensearch-sql.release-notes-2.11.0.0.md diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index d6c1907f84..25e8006d66 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -34,6 +34,8 @@ public class DataSourceServiceImpl implements DataSourceService { private static String DATASOURCE_NAME_REGEX = "[@*A-Za-z]+?[*a-zA-Z_\\-0-9]*"; + public static final Set CONFIDENTIAL_AUTH_KEYS = + Set.of("auth.username", 
"auth.password", "auth.access_key", "auth.secret_key"); private final DataSourceLoaderCache dataSourceLoaderCache; @@ -159,7 +161,12 @@ private void removeAuthInfo(Set dataSourceMetadataSet) { private void removeAuthInfo(DataSourceMetadata dataSourceMetadata) { HashMap safeProperties = new HashMap<>(dataSourceMetadata.getProperties()); - safeProperties.entrySet().removeIf(entry -> entry.getKey().contains("auth")); + safeProperties + .entrySet() + .removeIf( + entry -> + CONFIDENTIAL_AUTH_KEYS.stream() + .anyMatch(confidentialKey -> entry.getKey().endsWith(confidentialKey))); dataSourceMetadata.setProperties(safeProperties); } } diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index c8312e6013..6164d8b73f 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -233,7 +233,7 @@ void testGetDataSourceMetadataSet() { assertEquals(1, dataSourceMetadataSet.size()); DataSourceMetadata dataSourceMetadata = dataSourceMetadataSet.iterator().next(); assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.uri")); - assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); assertFalse( @@ -352,11 +352,72 @@ void testRemovalOfAuthorizationInfo() { DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testDS"); assertEquals("testDS", dataSourceMetadata1.getName()); assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); - assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.password")); } + @Test + void testRemovalOfAuthorizationInfoForAccessKeyAndSecretKye() { + HashMap properties = new HashMap<>(); + properties.put("prometheus.uri", "https://localhost:9090"); + properties.put("prometheus.auth.type", "awssigv4"); + properties.put("prometheus.auth.access_key", "access_key"); + properties.put("prometheus.auth.secret_key", "secret_key"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testDS", + DataSourceType.PROMETHEUS, + Collections.singletonList("prometheus_access"), + properties, + null); + when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testDS"); + assertEquals("testDS", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.access_key")); + assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.secret_key")); + } + + @Test + void 
testRemovalOfAuthorizationInfoForGlueWithRoleARN() { + HashMap properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put("glue.auth.role_arn", "role_arn"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "basicauth"); + properties.put("glue.indexstore.opensearch.auth.username", "username"); + properties.put("glue.indexstore.opensearch.auth.password", "password"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testGlue", + DataSourceType.S3GLUE, + Collections.singletonList("glue_access"), + properties, + null); + when(dataSourceMetadataStorage.getDataSourceMetadata("testGlue")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testGlue"); + assertEquals("testGlue", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.S3GLUE, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.auth.role_arn")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.indexstore.opensearch.uri")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.indexstore.opensearch.auth")); + assertFalse( + dataSourceMetadata1 + .getProperties() + .containsKey("glue.indexstore.opensearch.auth.username")); + assertFalse( + dataSourceMetadata1 + .getProperties() + .containsKey("glue.indexstore.opensearch.auth.password")); + } + @Test void testGetDataSourceMetadataForNonExistingDataSource() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")).thenReturn(Optional.empty()); @@ -381,7 +442,7 @@ void testGetDataSourceMetadataForSpecificDataSourceName() { "testDS", DataSourceType.PROMETHEUS, Collections.emptyList(), properties))); DataSourceMetadata dataSourceMetadata = this.dataSourceService.getDataSourceMetadata("testDS"); assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.uri")); - assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); verify(dataSourceMetadataStorage, times(1)).getDataSourceMetadata("testDS"); diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 087629a1f1..8623b9fa6f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -85,6 +85,10 @@ public void createDataSourceAPITest() { new Gson().fromJson(getResponseString, DataSourceMetadata.class); Assert.assertEquals( "https://localhost:9090", dataSourceMetadata.getProperties().get("prometheus.uri")); + Assert.assertEquals( + "basicauth", dataSourceMetadata.getProperties().get("prometheus.auth.type")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.username")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.password")); Assert.assertEquals("Prometheus Creation for Integ test", dataSourceMetadata.getDescription()); } @@ -239,6 +243,10 @@ public void issue2196() { 
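+    // How the masking asserted below falls out of CONFIDENTIAL_AUTH_KEYS: a
+    // property is removed only when its key ends with a confidential suffix,
+    // so for the properties used in this test the filter resolves as
+    // (worked example, keys as used in this file):
+    //   prometheus.uri           -> kept
+    //   prometheus.auth.type     -> kept ("auth.type" is not a confidential key)
+    //   prometheus.auth.username -> removed (ends with "auth.username")
+    //   prometheus.auth.password -> removed (ends with "auth.password")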
new Gson().fromJson(getResponseString, DataSourceMetadata.class); Assert.assertEquals( "https://localhost:9090", dataSourceMetadata.getProperties().get("prometheus.uri")); + Assert.assertEquals( + "basicauth", dataSourceMetadata.getProperties().get("prometheus.auth.type")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.username")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.password")); Assert.assertEquals("Prometheus Creation for Integ test", dataSourceMetadata.getDescription()); } } diff --git a/release-notes/opensearch-sql.release-notes-2.11.0.0.md b/release-notes/opensearch-sql.release-notes-2.11.0.0.md new file mode 100644 index 0000000000..a560d5c8dd --- /dev/null +++ b/release-notes/opensearch-sql.release-notes-2.11.0.0.md @@ -0,0 +1,55 @@ +Compatible with OpenSearch and OpenSearch Dashboards Version 2.11.0 + +### Features + +### Enhancements +* Enable PPL lang and add datasource to async query API in https://github.com/opensearch-project/sql/pull/2195 +* Refactor Flint Auth in https://github.com/opensearch-project/sql/pull/2201 +* Add conf for spark structured streaming job in https://github.com/opensearch-project/sql/pull/2203 +* Submit long running job only when auto_refresh = false in https://github.com/opensearch-project/sql/pull/2209 +* Bug Fix, handle DESC TABLE response in https://github.com/opensearch-project/sql/pull/2213 +* Drop Index Implementation in https://github.com/opensearch-project/sql/pull/2217 +* Enable PPL Queries in https://github.com/opensearch-project/sql/pull/2223 +* Read extra Spark submit parameters from cluster settings in https://github.com/opensearch-project/sql/pull/2236 +* Spark Execution Engine Config Refactor in https://github.com/opensearch-project/sql/pull/2266 +* Provide auth.type and auth.role_arn paramters in GET Datasource API response. in https://github.com/opensearch-project/sql/pull/2283 +* Add support for `date_nanos` and tests. (#337) in https://github.com/opensearch-project/sql/pull/2020 +* Applied formatting improvements to Antlr files based on spotless changes (#2017) by @MitchellGale in https://github.com/opensearch-project/sql/pull/2023 +* Revert "Guarantee datasource read api is strong consistent read (#1815)" in https://github.com/opensearch-project/sql/pull/2031 +* Add _primary preference only for segment replication enabled indices in https://github.com/opensearch-project/sql/pull/2045 +* Changed allowlist config to denylist ip config for datasource uri hosts in https://github.com/opensearch-project/sql/pull/2058 + +### Bug Fixes +* fix broken link for connectors doc in https://github.com/opensearch-project/sql/pull/2199 +* Fix response codes returned by JSON formatting them in https://github.com/opensearch-project/sql/pull/2200 +* Bug fix, datasource API should be case sensitive in https://github.com/opensearch-project/sql/pull/2202 +* Minor fix in dropping covering index in https://github.com/opensearch-project/sql/pull/2240 +* Fix Unit tests for FlintIndexReader in https://github.com/opensearch-project/sql/pull/2242 +* Bug Fix , delete OpenSearch index when DROP INDEX in https://github.com/opensearch-project/sql/pull/2252 +* Correctly Set query status in https://github.com/opensearch-project/sql/pull/2232 +* Exclude generated files from spotless in https://github.com/opensearch-project/sql/pull/2024 +* Fix mockito core conflict. in https://github.com/opensearch-project/sql/pull/2131 +* Fix `ASCII` function and groom UT for text functions. 
(#301) in https://github.com/opensearch-project/sql/pull/2029 +* Fixed response codes For Requests With security exception. in https://github.com/opensearch-project/sql/pull/2036 + +### Documentation +* Datasource description in https://github.com/opensearch-project/sql/pull/2138 +* Add documentation for S3GlueConnector. in https://github.com/opensearch-project/sql/pull/2234 + +### Infrastructure +* bump aws-encryption-sdk-java to 1.71 in https://github.com/opensearch-project/sql/pull/2057 +* Run IT tests with security plugin (#335) #1986 by @MitchellGale in https://github.com/opensearch-project/sql/pull/2022 + +### Refactoring +* Merging Async Query APIs feature branch into main. in https://github.com/opensearch-project/sql/pull/2163 +* Removed Domain Validation in https://github.com/opensearch-project/sql/pull/2136 +* Check for existence of security plugin in https://github.com/opensearch-project/sql/pull/2069 +* Always use snapshot version for security plugin download in https://github.com/opensearch-project/sql/pull/2061 +* Add customized result index in data source etc in https://github.com/opensearch-project/sql/pull/2220 + +### Security +* bump okhttp to 4.10.0 (#2043) by @joshuali925 in https://github.com/opensearch-project/sql/pull/2044 +* bump okio to 3.4.0 by @joshuali925 in https://github.com/opensearch-project/sql/pull/2047 + +--- +**Full Changelog**: https://github.com/opensearch-project/sql/compare/2.3.0.0...v.2.11.0.0 \ No newline at end of file diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index dd6cb1160d..101cc7f5f1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -36,12 +36,10 @@ public class InteractiveSession implements Session { private final SessionId sessionId; private final StateStore stateStore; private final EMRServerlessClient serverlessClient; - private final CreateSessionRequest createSessionRequest; - private SessionModel sessionModel; @Override - public void open() { + public void open(CreateSessionRequest createSessionRequest) { try { String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index 752055e119..4d919d5e2e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -13,7 +13,7 @@ /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ - void open(); + void open(CreateSessionRequest createSessionRequest); /** close session. 
*/ void close(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 5ed510f367..217af80caf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -28,9 +28,8 @@ public Session createSession(CreateSessionRequest request) { .sessionId(newSessionId()) .stateStore(stateStore) .serverlessClient(emrServerlessClient) - .createSessionRequest(request) .build(); - session.open(); + session.open(request); return session; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 0652a2d032..1cb2ee08f1 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -53,12 +53,11 @@ public void openCloseSession() { .sessionId(SessionId.newSessionId()) .stateStore(stateStore) .serverlessClient(emrsClient) - .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); // open session TestSession testSession = testSession(session, stateStore); - testSession.open().assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + testSession.open(new CreateSessionRequest(startJobRequest, "datasource")).assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); emrsClient.startJobRunCalled(1); // close session @@ -74,19 +73,17 @@ public void openSessionFailedConflict() { .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) - .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); - session.open(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) - .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); IllegalStateException exception = - assertThrows(IllegalStateException.class, duplicateSession::open); + assertThrows(IllegalStateException.class, () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); assertEquals("session already exist. 
sessionId=duplicate-session-id", exception.getMessage()); } @@ -98,9 +95,8 @@ public void closeNotExistSession() { .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) - .createSessionRequest(new CreateSessionRequest(startJobRequest, "datasource")) .build(); - session.open(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -168,8 +164,8 @@ public TestSession assertJobId(String expected) { return this; } - public TestSession open() { - session.open(); + public TestSession open(CreateSessionRequest req) { + session.open(req); return this; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 6028de6cfb..95b85613be 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -7,9 +7,6 @@ import org.junit.After; import org.junit.Before; -import org.mockito.MockMakers; -import org.mockito.MockSettings; -import org.mockito.Mockito; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -17,10 +14,6 @@ class SessionManagerTest extends OpenSearchSingleNodeTestCase { private static final String indexName = "mockindex"; - // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. - private static final MockSettings mockSettings = - Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); - private StateStore stateStore; @Before From ff02f28366b8ce8362bd7712a1e763e069e2eade Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 13 Oct 2023 16:17:53 -0700 Subject: [PATCH 04/20] fix format Signed-off-by: Peng Huo --- .../execution/session/InteractiveSessionTest.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 1cb2ee08f1..488252d05a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -57,7 +57,11 @@ public void openCloseSession() { // open session TestSession testSession = testSession(session, stateStore); - testSession.open(new CreateSessionRequest(startJobRequest, "datasource")).assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + testSession + .open(new CreateSessionRequest(startJobRequest, "datasource")) + .assertSessionState(NOT_STARTED) + .assertAppId("appId") + .assertJobId("jobId"); emrsClient.startJobRunCalled(1); // close session @@ -83,7 +87,9 @@ public void openSessionFailedConflict() { .serverlessClient(emrsClient) .build(); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); + assertThrows( + IllegalStateException.class, + () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); assertEquals("session already exist. 
sessionId=duplicate-session-id", exception.getMessage()); } From ca59586dccf57d90a4d629eaebf63c6d2b36848c Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 16 Oct 2023 16:36:48 -0700 Subject: [PATCH 05/20] snapshot Signed-off-by: Peng Huo --- .../sql/common/setting/Settings.java | 10 +- .../setting/OpenSearchSettings.java | 14 ++ .../org/opensearch/sql/plugin/SQLPlugin.java | 9 +- .../AsyncQueryExecutorServiceImpl.java | 16 +- .../model/AsyncQueryExecutionResponse.java | 1 + .../model/AsyncQueryJobMetadata.java | 11 +- .../asyncquery/model/AsyncQueryResult.java | 9 +- .../spark/data/constants/SparkConstants.java | 1 + .../dispatcher/SparkQueryDispatcher.java | 71 ++++++- .../model/DispatchQueryRequest.java | 3 + .../model/DispatchQueryResponse.java | 1 + .../execution/session/SessionManager.java | 7 + .../rest/RestAsyncQueryManagementAction.java | 9 +- .../rest/model/CreateAsyncQueryRequest.java | 1 + .../TransportGetAsyncQueryResultAction.java | 3 +- .../AsyncQueryResultResponseFormatter.java | 4 + .../model/CancelAsyncQueryActionRequest.java | 2 + .../GetAsyncQueryResultActionRequest.java | 2 + .../AsyncQueryExecutorServiceImplTest.java | 4 +- .../sql/spark/constants/TestConstants.java | 2 + .../dispatcher/SparkQueryDispatcherTest.java | 194 +++++++++++++++++- .../session/InteractiveSessionTest.java | 9 +- .../execution/session/SessionManagerTest.java | 51 +++-- .../execution/statement/StatementTest.java | 11 +- ...portCancelAsyncQueryRequestActionTest.java | 7 +- ...ransportGetAsyncQueryResultActionTest.java | 29 ++- ...AsyncQueryResultResponseFormatterTest.java | 23 ++- 27 files changed, 448 insertions(+), 56 deletions(-) diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 8daf0e9bf6..89d046b3d9 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -5,6 +5,8 @@ package org.opensearch.sql.common.setting; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; + import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -36,7 +38,8 @@ public enum Key { METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"), - CLUSTER_NAME("cluster.name"); + CLUSTER_NAME("cluster.name"), + SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled"); @Getter private final String keyValue; @@ -60,4 +63,9 @@ public static Optional of(String keyValue) { public abstract T getSettingValue(Key key); public abstract List getSettings(); + + /** Helper class */ + public static boolean isSparkExecutionSessionEnabled(Settings settings) { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 76bda07607..ec5bc7dfc0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -135,6 +135,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + 
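+  // Declared Dynamic + NodeScope, so operators can toggle interactive session
+  // support at runtime through the cluster settings API with no node restart.
+  // A sketch of such a call (request shape only, the value is illustrative):
+  //
+  //   PUT _cluster/settings
+  //   {"transient": {"plugins.query.executionengine.spark.session.enabled": "false"}}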
public static final Setting SPARK_EXECUTION_SESSION_ENABLED_SETTING = + Setting.boolSetting( + Key.SPARK_EXECUTION_SESSION_ENABLED.getKeyValue(), + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -205,6 +212,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.SPARK_EXECUTION_ENGINE_CONFIG, SPARK_EXECUTION_ENGINE_CONFIG, new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_SESSION_ENABLED, + SPARK_EXECUTION_SESSION_ENABLED_SETTING, + new Updater(Key.SPARK_EXECUTION_SESSION_ENABLED)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -270,6 +283,7 @@ public static List> pluginSettings() { .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) .add(SPARK_EXECUTION_ENGINE_CONFIG) + .add(SPARK_EXECUTION_SESSION_ENABLED_SETTING) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f3fd043b63..a9a35f6318 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -99,6 +100,8 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; @@ -318,7 +321,11 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new DataSourceUserAuthorizationHelperImpl(client), jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), - client); + client, + new SessionManager( + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), + emrServerlessClient, + pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 13db103f4b..7234170a97 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -65,13 +65,15 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), 
sparkExecutionEngineConfig.getClusterName(), - sparkExecutionEngineConfig.getSparkSubmitParameters())); + sparkExecutionEngineConfig.getSparkSubmitParameters(), + createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), dispatchQueryResponse.isDropIndexQuery(), - dispatchQueryResponse.getResultIndex())); + dispatchQueryResponse.getResultIndex(), + dispatchQueryResponse.getSessionId())); return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); } @@ -81,6 +83,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { + String sessionId = jobMetadata.get().getSessionId(); JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = @@ -90,13 +93,18 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { result.add(sparkSqlFunctionResponseHandle.next()); } return new AsyncQueryExecutionResponse( - JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null); + JobRunState.SUCCESS.toString(), + sparkSqlFunctionResponseHandle.schema(), + result, + null, + sessionId); } else { return new AsyncQueryExecutionResponse( jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()), null, null, - jsonObject.optString(ERROR_FIELD, "")); + jsonObject.optString(ERROR_FIELD, ""), + sessionId); } } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index d2e54af004..e5d9cffd5f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -19,4 +19,5 @@ public class AsyncQueryExecutionResponse { private final ExecutionEngine.Schema schema; private final List results; private final String error; + private final String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index b470ef989f..b80fefa173 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -30,12 +30,15 @@ public class AsyncQueryJobMetadata { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + // optional sessionId. 
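+  // Stays null for queries submitted outside an interactive session (the
+  // three-argument constructor below sets it to null explicitly) and is
+  // serialized by convertToXContent like the other fields, so a session-backed
+  // metadata document looks roughly like (values illustrative):
+  //   {"applicationId": "application-abc", "jobId": "job-123",
+  //    "isDropIndexQuery": false, "resultIndex": null, "sessionId": "session-xyz"}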
+  private String sessionId;
 
   public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) {
     this.applicationId = applicationId;
     this.jobId = jobId;
     this.isDropIndexQuery = false;
     this.resultIndex = resultIndex;
+    this.sessionId = null;
   }
 
   @Override
@@ -57,6 +60,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata)
     builder.field("applicationId", metadata.getApplicationId());
     builder.field("isDropIndexQuery", metadata.isDropIndexQuery());
     builder.field("resultIndex", metadata.getResultIndex());
+    builder.field("sessionId", metadata.getSessionId());
     builder.endObject();
     return builder;
   }
@@ -92,6 +96,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
     String applicationId = null;
     boolean isDropIndexQuery = false;
     String resultIndex = null;
+    String sessionId = null;
     ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
     while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
       String fieldName = parser.currentName();
@@ -109,6 +114,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
         case "resultIndex":
           resultIndex = parser.textOrNull();
           break;
+        case "sessionId":
+          sessionId = parser.textOrNull();
+          break;
         default:
           throw new IllegalArgumentException("Unknown field: " + fieldName);
       }
@@ -116,6 +124,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
     if (jobId == null || applicationId == null) {
       throw new IllegalArgumentException("jobId and applicationId are required fields.");
     }
-    return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery, resultIndex);
+    return new AsyncQueryJobMetadata(
+        applicationId, jobId, isDropIndexQuery, resultIndex, sessionId);
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
index c229aa3920..7fda8aefd8 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
@@ -12,25 +12,30 @@ public class AsyncQueryResult extends QueryResult {
 
   @Getter private final String status;
   @Getter private final String error;
+  @Getter private final String sessionId;
 
   public AsyncQueryResult(
       String status,
       ExecutionEngine.Schema schema,
       Collection<ExprValue> exprValues,
       Cursor cursor,
-      String error) {
+      String error,
+      String sessionId) {
     super(schema, exprValues, cursor);
     this.status = status;
     this.error = error;
+    this.sessionId = sessionId;
   }
 
   public AsyncQueryResult(
       String status,
       ExecutionEngine.Schema schema,
       Collection<ExprValue> exprValues,
-      String error) {
+      String error,
+      String sessionId) {
     super(schema, exprValues);
     this.status = status;
     this.error = error;
+    this.sessionId = sessionId;
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
index 284afcc0a9..1b248eb15d 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
@@ -21,6 +21,7 @@ public class SparkConstants {
   public static final String SPARK_SQL_APPLICATION_JAR =
       "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar";
   public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result";
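Both constants name dot-prefixed (hidden, by OpenSearch convention) indices: the response buffer receives query results written back by the Spark job, while the new request buffer backs the StateStore that keeps session and statement documents. A wiring sketch, not part of the diff, mirroring the SQLPlugin hunk above:

    StateStore stateStore = new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client);
    SessionManager sessionManager =
        new SessionManager(stateStore, emrServerlessClient, pluginSettings);
    // sessionManager.isEnabled() consults SPARK_EXECUTION_SESSION_ENABLED before the
    // dispatcher creates or resolves any session.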
".query_execution_result"; + public static final String SPARK_REQUEST_BUFFER_INDEX_NAME = ".query_execution_request"; // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 347e154885..194a075edf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -16,6 +16,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import lombok.AllArgsConstructor; import lombok.Getter; @@ -39,6 +40,13 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -69,6 +77,8 @@ public class SparkQueryDispatcher { private Client client; + private SessionManager sessionManager; + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { return handleSQLQuery(dispatchQueryRequest); @@ -124,10 +134,28 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - CancelJobRunResult cancelJobRunResult = - emrServerlessClient.cancelJobRun( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - return cancelJobRunResult.getJobRunId(); + if (sessionManager.isEnabled() && asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + statement.get().cancel(); + return statementId.getId(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. 
" + sessionId); + } + } else { + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return cancelJobRunResult.getJobRunId(); + } } private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -173,7 +201,7 @@ private DispatchQueryResponse handleIndexQuery( indexDetails.getAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -198,8 +226,35 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ tags, false, dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { + Session session; + if (dispatchQueryRequest.getSessionId() != null) { + // get session from request + SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); + Optional createdSession = sessionManager.getSession(sessionId); + if (createdSession.isEmpty()) { + throw new IllegalArgumentException("no session found. " + sessionId); + } + session = createdSession.get(); + } else { + // create session if not exist + session = + sessionManager.createSession( + new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + } + StatementId statementId = + session.submit( + new QueryRequest( + dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + return new DispatchQueryResponse( + statementId.getId(), + false, + dataSourceMetadata.getResultIndex(), + session.getSessionId().getSessionId()); + } else { + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); + } } private DispatchQueryResponse handleDropIndexQuery( @@ -229,7 +284,7 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex()); + new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 823a4570ce..6aa28227a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -23,4 +23,7 @@ public class DispatchQueryRequest { /** Optional extra Spark submit parameters to include in final request */ private String extraSparkSubmitParams; + + /** Optional sessionId. 
*/ + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 9ee5f156f2..893446c617 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -9,4 +9,5 @@ public class DispatchQueryResponse { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 217af80caf..c34be7015f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -5,10 +5,12 @@ package org.opensearch.sql.spark.execution.session; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -21,6 +23,7 @@ public class SessionManager { private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; + private final Settings settings; public Session createSession(CreateSessionRequest request) { InteractiveSession session = @@ -47,4 +50,8 @@ public Optional getSession(SessionId sid) { } return Optional.empty(); } + + public boolean isEnabled() { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 741501cd18..518b554792 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -43,6 +43,8 @@ public class RestAsyncQueryManagementAction extends BaseRestHandler { public static final String ASYNC_QUERY_ACTIONS = "async_query_actions"; public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query"; + public static final String PARAMS_SESSION_ID = "sessionId"; + private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class); @Override @@ -112,6 +114,7 @@ private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClie throws IOException { CreateAsyncQueryRequest submitJobRequest = CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); + submitJobRequest.setSessionId(restRequest.param(PARAMS_SESSION_ID, null)); return restChannel -> Scheduler.schedule( nodeClient, @@ -140,13 +143,14 @@ public void onFailure(Exception e) { private RestChannelConsumer executeGetAsyncQueryResultRequest( RestRequest restRequest, NodeClient nodeClient) { String queryId = restRequest.param("queryId"); + String sessionId = restRequest.param(PARAMS_SESSION_ID, null); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( 
TransportGetAsyncQueryResultAction.ACTION_TYPE, - new GetAsyncQueryResultActionRequest(queryId), + new GetAsyncQueryResultActionRequest(queryId, sessionId), new ActionListener<>() { @Override public void onResponse( @@ -181,13 +185,14 @@ private void handleException(Exception e, RestChannel restChannel) { private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) { String queryId = restRequest.param("queryId"); + String sessionId = restRequest.param(PARAMS_SESSION_ID, null); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( TransportCancelAsyncQueryRequestAction.ACTION_TYPE, - new CancelAsyncQueryActionRequest(queryId), + new CancelAsyncQueryActionRequest(queryId, sessionId), new ActionListener<>() { @Override public void onResponse( diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 8802630d9f..0f1abe4bb2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -18,6 +18,7 @@ public class CreateAsyncQueryRequest { private String query; private String datasource; private LangType lang; + private String sessionId; public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.query = Validate.notNull(query, "Query can't be null"); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index 5c784cf04c..72ddcd98e1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -62,7 +62,8 @@ protected void doExecute( asyncQueryExecutionResponse.getSchema(), asyncQueryExecutionResponse.getResults(), Cursor.None, - asyncQueryExecutionResponse.getError())); + asyncQueryExecutionResponse.getError(), + asyncQueryExecutionResponse.getSessionId())); listener.onResponse(new GetAsyncQueryResultActionResponse(responseContent)); } catch (Exception e) { listener.onFailure(e); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java index 3a2a5b110f..9f11f5499f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -57,6 +57,9 @@ public Object buildJsonObject(AsyncQueryResult response) { if (!Strings.isEmpty(response.getError())) { json.error(response.getError()); } + if (response.getSessionId() != null) { + json.sessionId(response.getSessionId()); + } return json.build(); } @@ -85,6 +88,7 @@ public static class JsonResponse { private Integer total; private Integer size; private final String error; + private final String sessionId; } @RequiredArgsConstructor diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index 0065b575ed..3db8f4009e 100644 --- 
a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -20,6 +20,8 @@ public class CancelAsyncQueryActionRequest extends ActionRequest { private String queryId; + private String sessionId; + /** Constructor of SubmitJobActionRequest from StreamInput. */ public CancelAsyncQueryActionRequest(StreamInput in) throws IOException { super(in); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java index 06faa75a26..1ccc8eea31 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java @@ -19,6 +19,8 @@ public class GetAsyncQueryResultActionRequest extends ActionRequest { @Getter private String queryId; + @Getter private String sessionId; + /** Constructor of GetJobQueryResultActionRequest from StreamInput. */ public GetAsyncQueryResultActionRequest(StreamInput in) throws IOException { super(in); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 01bccd9030..0d4e280b61 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -78,7 +78,7 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) @@ -107,7 +107,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { "--conf spark.dynamicAllocation.enabled=false", TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index abae0377a2..3a0d8fc56d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -16,4 +16,6 @@ public class TestConstants { public static final String EMRS_JOB_NAME = "job_name"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; + public static final String MOCK_SESSION_ID = "s-0123456"; + public static final String MOCK_STATEMENT_ID = "st-0123456"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index ab9761da36..52caa7e3c4 100644 --- 
a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -8,8 +8,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -18,6 +22,8 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_STATEMENT_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; @@ -34,6 +40,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; @@ -56,6 +63,11 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.flint.FlintIndexType; @@ -76,6 +88,12 @@ public class SparkQueryDispatcherTest { @Mock private FlintIndexMetadata flintIndexMetadata; + @Mock private SessionManager sessionManager; + + @Mock private Session session; + + @Mock private Statement statement; + private SparkQueryDispatcher sparkQueryDispatcher; @BeforeEach @@ -87,7 +105,8 @@ void setUp() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); } @Test @@ -261,6 +280,84 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { verifyNoInteractions(flintIndexMetadataReader); } + @Test + void testDispatchSelectQueryCreateNewSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(session).when(sessionManager).createSession(any()); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + 
when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).getSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryReuseSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, MOCK_SESSION_ID); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.of(session)) + .when(sessionManager) + .getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryInvalidSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid"); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.empty()).when(sessionManager).getSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals( + "no session found. 
" + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testDispatchSelectQueryFailedCreateSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doThrow(RuntimeException.class).when(sessionManager).createSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + Assertions.assertThrows( + RuntimeException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + } + @Test void testDispatchIndexQuery() { HashMap tags = new HashMap<>(); @@ -563,6 +660,73 @@ void testCancelJob() { Assertions.assertEquals(EMR_JOB_ID, jobId); } + @Test + void testCancelQueryWithSession() { + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doNothing().when(statement).cancel(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + verify(statement, times(1)).cancel(); + Assertions.assertEquals(MOCK_STATEMENT_ID, queryId); + } + + @Test + void testCancelQueryWithInvalidSession() { + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.empty()).when(sessionManager).getSession(new SessionId("invalid")); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(session); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithInvalidStatementId() { + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(statement); + Assertions.assertEquals( + "no statement found. 
" + new StatementId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithNoSessionId() { + doReturn(true).when(sessionManager).isEnabled(); + + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = + sparkQueryDispatcher.cancelJob( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + @Test void testGetQueryResponse() { when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) @@ -586,7 +750,8 @@ void testGetQueryResponseWithSuccess() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); JSONObject queryResult = new JSONObject(); Map resultMap = new HashMap<>(); resultMap.put(STATUS_FIELD, "SUCCESS"); @@ -623,14 +788,15 @@ void testGetQueryResponseOfDropIndex() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); String jobId = new SparkQueryDispatcher.DropIndexResult(JobRunState.SUCCESS.toString()).toJobId(); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null)); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null)); verify(jobExecutionResponseReader, times(0)) .getResultFromOpensearchIndex(anyString(), anyString()); Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); @@ -997,6 +1163,24 @@ private DispatchQueryRequest constructDispatchQueryRequest( langType, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - extraParameters); + extraParameters, + null); + } + + private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + null, + sessionId); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( + String queryId, String sessionId) { + return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 488252d05a..429c970365 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -114,7 +115,7 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); TestSession testSession = testSession(session, stateStore); @@ -123,7 +124,8 @@ public 
void sessionManagerCreateSession() {
 
   @Test
   public void sessionManagerGetSession() {
-    SessionManager sessionManager = new SessionManager(stateStore, emrsClient);
+    SessionManager sessionManager =
+        new SessionManager(stateStore, emrsClient, sessionSetting(false));
     Session session =
         sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource"));
 
@@ -134,7 +136,8 @@ public void sessionManagerGetSession() {
 
   @Test
   public void sessionManagerGetSessionNotExist() {
-    SessionManager sessionManager = new SessionManager(stateStore, emrsClient);
+    SessionManager sessionManager =
+        new SessionManager(stateStore, emrsClient, sessionSetting(false));
 
     Optional<Session> managerSession = sessionManager.getSession(new SessionId("no-exist"));
     assertTrue(managerSession.isEmpty());
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
index 95b85613be..4374bd4f11 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
@@ -5,25 +5,48 @@
 
 package org.opensearch.sql.spark.execution.session;
 
-import org.junit.After;
-import org.junit.Before;
-import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
-import org.opensearch.test.OpenSearchSingleNodeTestCase;
 
-class SessionManagerTest extends OpenSearchSingleNodeTestCase {
-  private static final String indexName = "mockindex";
+@ExtendWith(MockitoExtension.class)
+public class SessionManagerTest {
+  @Mock private StateStore stateStore;
+  @Mock private EMRServerlessClient emrClient;
 
-  private StateStore stateStore;
+  @Test
+  public void sessionEnable() {
+    Assertions.assertTrue(
+        new SessionManager(stateStore, emrClient, sessionSetting(true)).isEnabled());
+    Assertions.assertFalse(
+        new SessionManager(stateStore, emrClient, sessionSetting(false)).isEnabled());
+  }
 
-  @Before
-  public void setup() {
-    stateStore = new StateStore(indexName, client());
-    createIndex(indexName);
+  public static Settings sessionSetting(boolean enabled) {
+    Map<Settings.Key, Object> settings = new HashMap<>();
+    settings.put(Settings.Key.SPARK_EXECUTION_SESSION_ENABLED, enabled);
+    return settings(settings);
   }
 
-  @After
-  public void clean() {
-    client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+  public static Settings settings(Map<Settings.Key, Object> settings) {
+    return new Settings() {
+      @Override
+      public <T> T getSettingValue(Key key) {
+        return (T) settings.get(key);
+      }
+
+      @Override
+      public List<?> getSettings() {
+        return (List<?>) settings;
+      }
+    };
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index b0bc84219b..85151b1314 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -5,6 +5,7
@@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; @@ -184,7 +185,7 @@ public void cancelRunningStatementFailed() { @Test public void submitStatementInRunningSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running @@ -197,7 +198,7 @@ public void submitStatementInRunningSession() { @Test public void failToSubmitStatementInStartingSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); IllegalStateException exception = @@ -213,7 +214,7 @@ public void failToSubmitStatementInStartingSession() { @Test public void failToSubmitStatementInDeletedSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // other's delete session @@ -231,7 +232,7 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); @@ -246,7 +247,7 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 2ff76b9b57..b6a20405aa 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import java.util.HashSet; import org.junit.jupiter.api.Assertions; @@ -53,7 +54,8 @@ public void setUp() { @Test public void testDoExecute() { - CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + CancelAsyncQueryActionRequest request = + new CancelAsyncQueryActionRequest(EMR_JOB_ID, MOCK_SESSION_ID); when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); action.doExecute(task, request, 
actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); @@ -65,7 +67,8 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { - CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + CancelAsyncQueryActionRequest request = + new CancelAsyncQueryActionRequest(EMR_JOB_ID, MOCK_SESSION_ID); doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 21a213c7c2..61b7ea746f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -14,6 +14,7 @@ import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -61,9 +62,10 @@ public void setUp() { @Test public void testDoExecute() { - GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); + GetAsyncQueryResultActionRequest request = + new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null); + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); @@ -76,7 +78,8 @@ public void testDoExecute() { @Test public void testDoExecuteWithSuccessResponse() { - GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); + GetAsyncQueryResultActionRequest request = + new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); ExecutionEngine.Schema schema = new ExecutionEngine.Schema( ImmutableList.of( @@ -89,6 +92,7 @@ public void testDoExecuteWithSuccessResponse() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), + null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); @@ -126,7 +130,8 @@ public void testDoExecuteWithSuccessResponse() { @Test public void testDoExecuteWithException() { - GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("123"); + GetAsyncQueryResultActionRequest request = + new GetAsyncQueryResultActionRequest("123", MOCK_SESSION_ID); doThrow(new AsyncQueryNotFoundException("JobId 123 not found")) .when(jobExecutorService) .getAsyncQueryResults("123"); @@ -137,4 +142,20 @@ public void testDoExecuteWithException() { Assertions.assertTrue(exception instanceof RuntimeException); Assertions.assertEquals("JobId 123 not found", 
exception.getMessage()); } + + @Test + public void testDoExecuteWithSessionId() { + GetAsyncQueryResultActionRequest request = + new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, MOCK_SESSION_ID); + when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + " \"status\": \"IN_PROGRESS\",\n" + " \"sessionId\": \"s-0123456\"\n" + "}", + getAsyncQueryResultActionResponse.getResult()); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java index 711db75efb..a197c6cba1 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -5,6 +5,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.COMPACT; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -30,6 +31,7 @@ void formatAsyncQueryResponse() { Arrays.asList( tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), + null, null); AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); assertEquals( @@ -41,8 +43,27 @@ void formatAsyncQueryResponse() { @Test void formatAsyncQueryError() { - AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo"); + AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo", null); AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); assertEquals("{\"status\":\"FAILED\",\"error\":\"foo\"}", formatter.format(response)); } + + @Test + void formatAsyncQueryResponseWithSessionId() { + AsyncQueryResult response = + new AsyncQueryResult( + "success", + schema, + Arrays.asList( + tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), + tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), + null, + MOCK_SESSION_ID); + AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); + assertEquals( + "{\"status\":\"success\",\"schema\":[{\"name\":\"firstname\",\"type\":\"string\"}," + + "{\"name\":\"age\",\"type\":\"integer\"}],\"datarows\":" + + "[[\"John\",20],[\"Smith\",30]],\"total\":2,\"size\":2,\"sessionId\":\"s-0123456\"}", + formatter.format(response)); + } } From eafbeb4f2ae8ec9abf5525348ef34c847c2f27fa Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 16 Oct 2023 17:49:43 -0700 Subject: [PATCH 06/20] address comments Signed-off-by: Peng Huo --- .../execution/session/InteractiveSession.java | 9 ++- .../spark/execution/session/SessionState.java | 4 ++ 
.../spark/execution/statement/Statement.java  |  5 +-
 .../execution/statement/StatementModel.java   | 26 +++++++-
 .../execution/statement/StatementState.java   |  1 +
 .../execution/statement/StatementTest.java    | 66 ++++++++++++++++++-
 6 files changed, 104 insertions(+), 7 deletions(-)

diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
index 101cc7f5f1..e33ef4245a 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -6,6 +6,7 @@
 package org.opensearch.sql.spark.execution.session;
 
 import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession;
+import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE;
 import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
@@ -73,11 +74,13 @@ public StatementId submit(QueryRequest request) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
       sessionModel = model.get();
-      if (sessionModel.getSessionState() == SessionState.RUNNING) {
+      if (!END_STATE.contains(sessionModel.getSessionState())) {
         StatementId statementId = newStatementId();
         Statement st =
             Statement.builder()
                 .sessionId(sessionId)
+                .applicationId(sessionModel.getApplicationId())
+                .jobId(sessionModel.getJobId())
                 .stateStore(stateStore)
                 .statementId(statementId)
                 .langType(LangType.SQL)
@@ -89,7 +92,7 @@ public StatementId submit(QueryRequest request) {
       } else {
         String errMsg =
             String.format(
-                "can't submit statement, session should in running state, "
+                "can't submit statement, session should not be in end state, "
                     + "current session state is: %s",
                 sessionModel.getSessionState().getSessionState());
         LOG.debug(errMsg);
@@ -106,6 +109,8 @@ public Optional<Statement> get(StatementId stID) {
             model ->
                 Statement.builder()
                     .sessionId(sessionId)
+                    .applicationId(model.getApplicationId())
+                    .jobId(model.getJobId())
                     .statementId(model.getStatementId())
                     .langType(model.getLangType())
                     .query(model.getQuery())
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
index 509d5105e9..a4da957f12 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
@@ -5,7 +5,9 @@
 
 package org.opensearch.sql.spark.execution.session;
 
+import com.google.common.collect.ImmutableList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 import lombok.Getter;
@@ -17,6 +19,8 @@ public enum SessionState {
   DEAD("dead"),
   FAIL("fail");
 
+  public static List<SessionState> END_STATE = ImmutableList.of(DEAD, FAIL);
+
   private final String sessionState;
 
   SessionState(String sessionState) {
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
index 4c54393379..8fcedb5fca 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
+++
b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -28,6 +28,8 @@ public class Statement { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; + private final String applicationId; + private final String jobId; private final StatementId statementId; private final LangType langType; private final String query; @@ -39,7 +41,8 @@ public class Statement { /** Open a statement. */ public void open() { try { - statementModel = submitStatement(sessionId, statementId, langType, query, queryId); + statementModel = + submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); statementModel = createStatement(stateStore).apply(statementModel); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index b57868964e..c7f681c541 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -5,6 +5,8 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import java.io.IOException; @@ -40,6 +42,8 @@ public class StatementModel extends StateModel { private final StatementState statementState; private final StatementId statementId; private final SessionId sessionId; + private final String applicationId; + private final String jobId; private final LangType langType; private final String query; private final String queryId; @@ -58,6 +62,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(STATEMENT_STATE, statementState.getState()) .field(STATEMENT_ID, statementId.getId()) .field(SESSION_ID, sessionId.getSessionId()) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) .field(LANG, langType.getText()) .field(QUERY, query) .field(QUERY_ID, queryId) @@ -73,6 +79,8 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT .statementState(copy.statementState) .statementId(copy.statementId) .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) .langType(copy.langType) .query(copy.query) .queryId(copy.queryId) @@ -90,6 +98,8 @@ public static StatementModel copyWithState( .statementState(state) .statementId(copy.statementId) .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) .langType(copy.langType) .query(copy.query) .queryId(copy.queryId) @@ -124,6 +134,12 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon case SESSION_ID: builder.sessionId(new SessionId(parser.text())); break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; case LANG: builder.langType(LangType.fromString(parser.text())); break; @@ -147,12 +163,20 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon } public static StatementModel submitStatement( - SessionId sid, StatementId statementId, LangType langType, String query, String queryId) { + SessionId sid, + 
String applicationId, + String jobId, + StatementId statementId, + LangType langType, + String query, + String queryId) { return builder() .version("1.0") .statementState(WAITING) .statementId(statementId) .sessionId(sid) + .applicationId(applicationId) + .jobId(jobId) .langType(langType) .query(query) .queryId(queryId) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 87ad6b11ae..33f7f5e831 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -10,6 +10,7 @@ import java.util.stream.Collectors; import lombok.Getter; +/** {@link Statement} State. */ @Getter public enum StatementState { WAITING("waiting"), diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index b0bc84219b..331955e14e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -57,6 +57,8 @@ public void openThenCancelStatement() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -80,6 +82,8 @@ public void openFailedBecauseConflict() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -92,6 +96,8 @@ public void openFailedBecauseConflict() { Statement dupSt = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -108,6 +114,8 @@ public void cancelNotExistStatement() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -130,6 +138,8 @@ public void cancelFailedBecauseOfConflict() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -157,6 +167,8 @@ public void cancelRunningStatementFailed() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -195,21 +207,69 @@ public void submitStatementInRunningSession() { } @Test - public void failToSubmitStatementInStartingSession() { + public void submitStatementInNotStartedState() { Session session = new SessionManager(stateStore, emrsClient) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInDeadState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); + + 
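This test (and failToSubmitStatementInFailState below) exercises the END_STATE guard added to InteractiveSession.submit() earlier in this patch; condensed, the guard reads:

    // SessionState.END_STATE = ImmutableList.of(DEAD, FAIL)
    if (!END_STATE.contains(sessionModel.getSessionState())) {
      // build the Statement (sessionId, applicationId, jobId, ...) and open it
    } else {
      throw new IllegalStateException(
          String.format(
              "can't submit statement, session should not be in end state, "
                  + "current session state is: %s",
              sessionModel.getSessionState().getSessionState()));
    }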
IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " dead", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInFailState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); + IllegalStateException exception = assertThrows( IllegalStateException.class, () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); assertEquals( - "can't submit statement, session should in running state, current session state is:" - + " not_started", + "can't submit statement, session should not be in end state, current session state is:" + + " fail", exception.getMessage()); } + @Test + public void newStatementFieldAssert() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + Optional statement = session.get(statementId); + + assertTrue(statement.isPresent()); + assertEquals(session.getSessionId(), statement.get().getSessionId()); + assertEquals("appId", statement.get().getApplicationId()); + assertEquals("jobId", statement.get().getJobId()); + assertEquals(statementId, statement.get().getStatementId()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(LangType.SQL, statement.get().getLangType()); + assertEquals("select 1", statement.get().getQuery()); + } + @Test public void failToSubmitStatementInDeletedSession() { Session session = From df4b64a1aabb9cf13211640a22da02857b16120f Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 16 Oct 2023 21:31:50 -0700 Subject: [PATCH 07/20] update Signed-off-by: Peng Huo --- .../AsyncQueryExecutorServiceImpl.java | 2 +- .../dispatcher/SparkQueryDispatcher.java | 34 +++++++++++++++---- .../spark/execution/session/SessionId.java | 2 +- .../execution/statement/StatementId.java | 2 +- .../rest/model/CreateAsyncQueryResponse.java | 2 ++ .../execution/statement/StatementTest.java | 6 ++-- ...portCreateAsyncQueryRequestActionTest.java | 22 +++++++++++- 7 files changed, 56 insertions(+), 14 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 7234170a97..36ca6ea7c8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -74,7 +74,7 @@ public CreateAsyncQueryResponse createAsyncQuery( dispatchQueryResponse.isDropIndexQuery(), dispatchQueryResponse.getResultIndex(), dispatchQueryResponse.getSessionId())); - return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); + return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 194a075edf..d973b809b0 100644 --- 
a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -47,6 +47,7 @@
 import org.opensearch.sql.spark.execution.statement.QueryRequest;
 import org.opensearch.sql.spark.execution.statement.Statement;
 import org.opensearch.sql.spark.execution.statement.StatementId;
+import org.opensearch.sql.spark.execution.statement.StatementState;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexMetadataReader;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
@@ -121,13 +122,32 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata)
       String error = items.optString(ERROR_FIELD, "");
       result.put(ERROR_FIELD, error);
     } else {
-      // make call to EMR Serverless when related result index documents are not available
-      GetJobRunResult getJobRunResult =
-          emrServerlessClient.getJobRunResult(
-              asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-      String jobState = getJobRunResult.getJobRun().getState();
-      result.put(STATUS_FIELD, jobState);
-      result.put(ERROR_FIELD, "");
+      if (sessionManager.isEnabled() && asyncQueryJobMetadata.getSessionId() != null) {
+        SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId());
+        Optional<Session> session = sessionManager.getSession(sessionId);
+        if (session.isPresent()) {
+          // todo, statementId == jobId if statement running in session.
+          StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId());
+          Optional<Statement> statement = session.get().get(statementId);
+          if (statement.isPresent()) {
+            StatementState statementState = statement.get().getStatementState();
+            result.put(STATUS_FIELD, statementState.getState());
+            result.put(ERROR_FIELD, "");
+          } else {
+            throw new IllegalArgumentException("no statement found. " + statementId);
+          }
+        } else {
+          throw new IllegalArgumentException("no session found. 
" + sessionId); + } + } else { + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + } } return result; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index a2847cde18..f60e355c51 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -13,7 +13,7 @@ public class SessionId { private final String sessionId; public static SessionId newSessionId() { - return new SessionId(RandomStringUtils.random(10, true, true)); + return new SessionId(RandomStringUtils.random(16, true, true)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java index 4baff71493..8f2aaaf1cb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -13,7 +13,7 @@ public class StatementId { private final String id; public static StatementId newStatementId() { - return new StatementId(RandomStringUtils.random(10, true, true)); + return new StatementId(RandomStringUtils.random(16, true, true)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java index 8cfe57c2a6..2f918308c4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java @@ -12,4 +12,6 @@ @AllArgsConstructor public class CreateAsyncQueryResponse { private String queryId; + // optional sessionId + private String sessionId; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index a8fe41c0a1..214bcb8258 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -220,7 +220,7 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); @@ -238,7 +238,7 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); @@ -256,7 +256,7 @@ public void 
failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 8599e4b88e..32eec61ff3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import java.util.HashSet; import org.junit.jupiter.api.Assertions; @@ -61,7 +62,7 @@ public void testDoExecute() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) - .thenReturn(new CreateAsyncQueryResponse("123")); + .thenReturn(new CreateAsyncQueryResponse("123", null)); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); CreateAsyncQueryActionResponse createAsyncQueryActionResponse = @@ -70,6 +71,25 @@ public void testDoExecute() { "{\n" + " \"queryId\": \"123\"\n" + "}", createAsyncQueryActionResponse.getResult()); } + @Test + public void testDoExecuteWithSessionId() { + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "my_glue", LangType.SQL); + CreateAsyncQueryActionRequest request = + new CreateAsyncQueryActionRequest(createAsyncQueryRequest); + when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + .thenReturn(new CreateAsyncQueryResponse("123", MOCK_SESSION_ID)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + CreateAsyncQueryActionResponse createAsyncQueryActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + + " \"queryId\": \"123\",\n" + + " \"sessionId\": \"s-0123456\"\n" + + "}", createAsyncQueryActionResponse.getResult()); + } + @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = From cb500a874cce15e3af3b881e34c4529bb3a60cfd Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 07:49:46 -0700 Subject: [PATCH 08/20] Update REST and Transport interface Signed-off-by: Peng Huo --- .../sql/spark/rest/RestAsyncQueryManagementAction.java | 6 ++---- .../sql/spark/rest/model/CreateAsyncQueryRequest.java | 1 + .../transport/model/CancelAsyncQueryActionRequest.java | 2 -- .../transport/model/GetAsyncQueryResultActionRequest.java | 2 -- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 518b554792..1683ff7483 100644 
--- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -143,14 +143,13 @@ public void onFailure(Exception e) { private RestChannelConsumer executeGetAsyncQueryResultRequest( RestRequest restRequest, NodeClient nodeClient) { String queryId = restRequest.param("queryId"); - String sessionId = restRequest.param(PARAMS_SESSION_ID, null); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( TransportGetAsyncQueryResultAction.ACTION_TYPE, - new GetAsyncQueryResultActionRequest(queryId, sessionId), + new GetAsyncQueryResultActionRequest(queryId), new ActionListener<>() { @Override public void onResponse( @@ -185,14 +184,13 @@ private void handleException(Exception e, RestChannel restChannel) { private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) { String queryId = restRequest.param("queryId"); - String sessionId = restRequest.param(PARAMS_SESSION_ID, null); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( TransportCancelAsyncQueryRequestAction.ACTION_TYPE, - new CancelAsyncQueryActionRequest(queryId, sessionId), + new CancelAsyncQueryActionRequest(queryId), new ActionListener<>() { @Override public void onResponse( diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 0f1abe4bb2..ba2a0cf2ee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -18,6 +18,7 @@ public class CreateAsyncQueryRequest { private String query; private String datasource; private LangType lang; + // optional sessionId private String sessionId; public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index 3db8f4009e..0065b575ed 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -20,8 +20,6 @@ public class CancelAsyncQueryActionRequest extends ActionRequest { private String queryId; - private String sessionId; - /** Constructor of SubmitJobActionRequest from StreamInput. */ public CancelAsyncQueryActionRequest(StreamInput in) throws IOException { super(in); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java index 1ccc8eea31..06faa75a26 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java @@ -19,8 +19,6 @@ public class GetAsyncQueryResultActionRequest extends ActionRequest { @Getter private String queryId; - @Getter private String sessionId; - /** Constructor of GetJobQueryResultActionRequest from StreamInput. 
*/
  public GetAsyncQueryResultActionRequest(StreamInput in) throws IOException {
    super(in);

From 479373d1681518159b0c0b3ac974e7526352c96a Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Tue, 17 Oct 2023 09:06:46 -0700
Subject: [PATCH 09/20] Revert on transport layer

Signed-off-by: Peng Huo
---
 docs/user/admin/settings.rst                  | 36 +++++++++++++
 .../setting/OpenSearchSettings.java           |  2 +-
 .../asyncquery/model/AsyncQueryResult.java    |  9 +---
 .../dispatcher/SparkQueryDispatcher.java      |  4 +-
 .../TransportGetAsyncQueryResultAction.java   |  3 +-
 .../AsyncQueryResultResponseFormatter.java    |  4 --
 .../dispatcher/SparkQueryDispatcherTest.java  | 52 +++++++++++++++++--
 ...portCancelAsyncQueryRequestActionTest.java |  7 +--
 ...ransportGetAsyncQueryResultActionTest.java | 26 ++--------
 ...AsyncQueryResultResponseFormatterTest.java | 23 +-------
 10 files changed, 95 insertions(+), 71 deletions(-)

diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst
index b5da4e28e2..b6cdf238f2 100644
--- a/docs/user/admin/settings.rst
+++ b/docs/user/admin/settings.rst
@@ -311,3 +311,39 @@ SQL query::
     "status": 400
   }
 
+plugins.query.executionengine.spark.session.enabled
+===================================================
+
+Description
+-----------
+
+By default, the execution engine executes queries in job mode. You can enable session mode with this setting.
+
+1. The default value is false.
+2. This setting is node scope.
+3. This setting can be updated dynamically.
+
+You can update the setting with a new value like this.
+
+SQL query::
+
+    sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \
+    ... -d '{"transient":{"plugins.query.executionengine.spark.session.enabled":"false"}}'
+    {
+      "acknowledged": true,
+      "persistent": {},
+      "transient": {
+        "plugins": {
+          "query": {
+            "executionengine": {
+              "spark": {
+                "session": {
+                  "enabled": "false"
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
index ec5bc7dfc0..ecb35afafa 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
@@ -138,7 +138,7 @@ public class OpenSearchSettings extends Settings {
   public static final Setting<Boolean> SPARK_EXECUTION_SESSION_ENABLED_SETTING =
       Setting.boolSetting(
           Key.SPARK_EXECUTION_SESSION_ENABLED.getKeyValue(),
-          true,
+          false,
           Setting.Property.NodeScope,
           Setting.Property.Dynamic);
 
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
index 7fda8aefd8..c229aa3920 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java
@@ -12,30 +12,25 @@ public class AsyncQueryResult extends QueryResult {
 
   @Getter private final String status;
   @Getter private final String error;
-  @Getter private final String sessionId;
 
   public AsyncQueryResult(
       String status,
       ExecutionEngine.Schema schema,
       Collection<ExprValue> exprValues,
       Cursor cursor,
-      String error,
-      String sessionId) {
+      String error) {
     super(schema, exprValues, cursor);
     this.status = status;
     this.error = error;
-    this.sessionId = sessionId;
   }
 
   public AsyncQueryResult(
       String status,
       ExecutionEngine.Schema schema,
      Collection<ExprValue>
exprValues, - String error, - String sessionId) { + String error) { super(schema, exprValues); this.status = status; this.error = error; - this.sessionId = sessionId; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index d973b809b0..8d5ae10e91 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -122,7 +122,7 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) String error = items.optString(ERROR_FIELD, ""); result.put(ERROR_FIELD, error); } else { - if (sessionManager.isEnabled() && asyncQueryJobMetadata.getSessionId() != null) { + if (asyncQueryJobMetadata.getSessionId() != null) { SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); Optional session = sessionManager.getSession(sessionId); if (session.isPresent()) { @@ -154,7 +154,7 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - if (sessionManager.isEnabled() && asyncQueryJobMetadata.getSessionId() != null) { + if (asyncQueryJobMetadata.getSessionId() != null) { SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); Optional session = sessionManager.getSession(sessionId); if (session.isPresent()) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index 72ddcd98e1..5c784cf04c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -62,8 +62,7 @@ protected void doExecute( asyncQueryExecutionResponse.getSchema(), asyncQueryExecutionResponse.getResults(), Cursor.None, - asyncQueryExecutionResponse.getError(), - asyncQueryExecutionResponse.getSessionId())); + asyncQueryExecutionResponse.getError())); listener.onResponse(new GetAsyncQueryResultActionResponse(responseContent)); } catch (Exception e) { listener.onFailure(e); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java index 9f11f5499f..3a2a5b110f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -57,9 +57,6 @@ public Object buildJsonObject(AsyncQueryResult response) { if (!Strings.isEmpty(response.getError())) { json.error(response.getError()); } - if (response.getSessionId() != null) { - json.sessionId(response.getSessionId()); - } return json.build(); } @@ -88,7 +85,6 @@ public static class JsonResponse { private Integer total; private Integer size; private final String error; - private final String sessionId; } @RequiredArgsConstructor diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 52caa7e3c4..2d240cc7fc 100644 --- 
a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -68,6 +68,7 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.flint.FlintIndexType; @@ -662,7 +663,6 @@ void testCancelJob() { @Test void testCancelQueryWithSession() { - doReturn(true).when(sessionManager).isEnabled(); doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); doReturn(Optional.of(statement)).when(session).get(any()); doNothing().when(statement).cancel(); @@ -678,7 +678,6 @@ void testCancelQueryWithSession() { @Test void testCancelQueryWithInvalidSession() { - doReturn(true).when(sessionManager).isEnabled(); doReturn(Optional.empty()).when(sessionManager).getSession(new SessionId("invalid")); IllegalArgumentException exception = @@ -696,7 +695,6 @@ void testCancelQueryWithInvalidSession() { @Test void testCancelQueryWithInvalidStatementId() { - doReturn(true).when(sessionManager).isEnabled(); doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); IllegalArgumentException exception = @@ -714,8 +712,6 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { - doReturn(true).when(sessionManager).isEnabled(); - when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() @@ -741,6 +737,52 @@ void testGetQueryResponse() { Assertions.assertEquals("PENDING", result.get("status")); } + @Test + void testGetQueryResponseWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doReturn(StatementState.WAITING).when(statement).getStatementState(); + + doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), + any()); + JSONObject result = sparkQueryDispatcher + .getQueryResponse(asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals("waiting", result.get("status")); + } + + @Test + void testGetQueryResponseWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), + any()); + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkQueryDispatcher + .getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no session found. 
" + new SessionId(MOCK_SESSION_ID), exception.getMessage()); + } + + @Test + void testGetQueryResponseWithStatementNotExist() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.empty()).when(session).get(any()); + doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), + any()); + + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkQueryDispatcher + .getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); + } + @Test void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index b6a20405aa..2ff76b9b57 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -10,7 +10,6 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; -import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import java.util.HashSet; import org.junit.jupiter.api.Assertions; @@ -54,8 +53,7 @@ public void setUp() { @Test public void testDoExecute() { - CancelAsyncQueryActionRequest request = - new CancelAsyncQueryActionRequest(EMR_JOB_ID, MOCK_SESSION_ID); + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); @@ -67,8 +65,7 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { - CancelAsyncQueryActionRequest request = - new CancelAsyncQueryActionRequest(EMR_JOB_ID, MOCK_SESSION_ID); + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 61b7ea746f..34f10b0083 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ 
-62,8 +61,7 @@ public void setUp() { @Test public void testDoExecute() { - GetAsyncQueryResultActionRequest request = - new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); + GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); @@ -78,8 +76,7 @@ public void testDoExecute() { @Test public void testDoExecuteWithSuccessResponse() { - GetAsyncQueryResultActionRequest request = - new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); + GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); ExecutionEngine.Schema schema = new ExecutionEngine.Schema( ImmutableList.of( @@ -130,8 +127,7 @@ public void testDoExecuteWithSuccessResponse() { @Test public void testDoExecuteWithException() { - GetAsyncQueryResultActionRequest request = - new GetAsyncQueryResultActionRequest("123", MOCK_SESSION_ID); + GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("123"); doThrow(new AsyncQueryNotFoundException("JobId 123 not found")) .when(jobExecutorService) .getAsyncQueryResults("123"); @@ -142,20 +138,4 @@ public void testDoExecuteWithException() { Assertions.assertTrue(exception instanceof RuntimeException); Assertions.assertEquals("JobId 123 not found", exception.getMessage()); } - - @Test - public void testDoExecuteWithSessionId() { - GetAsyncQueryResultActionRequest request = - new GetAsyncQueryResultActionRequest("jobId", MOCK_SESSION_ID); - AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, MOCK_SESSION_ID); - when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); - action.doExecute(task, request, actionListener); - verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); - GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = - createJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals( - "{\n" + " \"status\": \"IN_PROGRESS\",\n" + " \"sessionId\": \"s-0123456\"\n" + "}", - getAsyncQueryResultActionResponse.getResult()); - } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java index a197c6cba1..711db75efb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -5,7 +5,6 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.COMPACT; -import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -31,7 +30,6 @@ void formatAsyncQueryResponse() { Arrays.asList( tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), - null, null); AsyncQueryResultResponseFormatter formatter = new 
AsyncQueryResultResponseFormatter(COMPACT); assertEquals( @@ -43,27 +41,8 @@ void formatAsyncQueryResponse() { @Test void formatAsyncQueryError() { - AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo", null); + AsyncQueryResult response = new AsyncQueryResult("FAILED", null, null, "foo"); AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); assertEquals("{\"status\":\"FAILED\",\"error\":\"foo\"}", formatter.format(response)); } - - @Test - void formatAsyncQueryResponseWithSessionId() { - AsyncQueryResult response = - new AsyncQueryResult( - "success", - schema, - Arrays.asList( - tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), - tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30))), - null, - MOCK_SESSION_ID); - AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); - assertEquals( - "{\"status\":\"success\",\"schema\":[{\"name\":\"firstname\",\"type\":\"string\"}," - + "{\"name\":\"age\",\"type\":\"integer\"}],\"datarows\":" - + "[[\"John\",20],[\"Smith\",30]],\"total\":2,\"size\":2,\"sessionId\":\"s-0123456\"}", - formatter.format(response)); - } } From a6394eed9a28d5c601838817f10455cb5fa6f9ec Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 09:37:53 -0700 Subject: [PATCH 10/20] format code Signed-off-by: Peng Huo --- .../AsyncQueryExecutorServiceImpl.java | 3 +- .../dispatcher/SparkQueryDispatcherTest.java | 40 +++++++++++-------- ...portCreateAsyncQueryRequestActionTest.java | 6 +-- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 36ca6ea7c8..7cba2757cc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -74,7 +74,8 @@ public CreateAsyncQueryResponse createAsyncQuery( dispatchQueryResponse.isDropIndexQuery(), dispatchQueryResponse.getResultIndex(), dispatchQueryResponse.getSessionId())); - return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); + return new CreateAsyncQueryResponse( + dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 9df0e5e0f6..58fe626dae 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -724,10 +724,12 @@ void testGetQueryResponseWithSession() { doReturn(Optional.of(statement)).when(session).get(any()); doReturn(StatementState.WAITING).when(statement).getStatementState(); - doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), - any()); - JSONObject result = sparkQueryDispatcher - .getQueryResponse(asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + JSONObject result = + sparkQueryDispatcher.getQueryResponse( + 
asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals("waiting", result.get("status")); @@ -736,12 +738,15 @@ void testGetQueryResponseWithSession() { @Test void testGetQueryResponseWithInvalidSession() { doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); - doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), - any()); - IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, - () -> sparkQueryDispatcher - .getQueryResponse( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals( @@ -752,13 +757,16 @@ void testGetQueryResponseWithInvalidSession() { void testGetQueryResponseWithStatementNotExist() { doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); doReturn(Optional.empty()).when(session).get(any()); - doReturn(new JSONObject()).when(jobExecutionResponseReader).getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), - any()); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); - IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, - () -> sparkQueryDispatcher - .getQueryResponse( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals( "no statement found. 
" + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 32eec61ff3..2e48597593 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -84,10 +84,8 @@ public void testDoExecuteWithSessionId() { CreateAsyncQueryActionResponse createAsyncQueryActionResponse = createJobActionResponseArgumentCaptor.getValue(); Assertions.assertEquals( - "{\n" + - " \"queryId\": \"123\",\n" + - " \"sessionId\": \"s-0123456\"\n" + - "}", createAsyncQueryActionResponse.getResult()); + "{\n" + " \"queryId\": \"123\",\n" + " \"sessionId\": \"s-0123456\"\n" + "}", + createAsyncQueryActionResponse.getResult()); } @Test From b0b3c4d7ea61c59cee52b9835f4b9290916145b7 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 09:53:14 -0700 Subject: [PATCH 11/20] add API doc Signed-off-by: Peng Huo --- docs/user/interfaces/asyncqueryinterface.rst | 43 ++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index a9fc77264c..5c64499c6f 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -62,6 +62,49 @@ Sample Response:: "queryId": "00fd796ut1a7eg0q" } +Execute query in session +------------------------ + +if plugins.query.executionengine.spark.session.enabled is set to true, session based execution is enabled. Under the hood, all queries submitted to the same session will be executed in the same SparkContext. Session is auto closed if not query submission in 10 minutes. + +Async query response include ``sessionId`` indicate the query is executed in session. + +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query' \ + --header 'Content-Type: application/json' \ + --data '{ + "datasource" : "my_glue", + "lang" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10" + }' + +Sample Response:: + + { + "queryId": "HlbM61kX6MDkAktO", + "sessionId": "1Giy65ZnzNlmsPAm" + } + +User could reuse the session by using ``sessionId`` query parameters. 
+ +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query?sessionId=1Giy65ZnzNlmsPAm' \ + --header 'Content-Type: application/json' \ + --data '{ + "datasource" : "my_glue", + "lang" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10" + }' + +Sample Response:: + + { + "queryId": "7GC4mHhftiTejvxN", + "sessionId": "1Giy65ZnzNlmsPAm" + } + Async Query Result API ====================================== From 75e472c4401ae89c2ba87db9f1c862a8d7514619 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 10:59:11 -0700 Subject: [PATCH 12/20] modify api Signed-off-by: Peng Huo --- docs/user/interfaces/asyncqueryinterface.rst | 5 +- .../rest/RestAsyncQueryManagementAction.java | 3 -- .../rest/model/CreateAsyncQueryRequest.java | 13 ++++- .../model/CreateAsyncQueryRequestTest.java | 52 +++++++++++++++++++ ...portCreateAsyncQueryRequestActionTest.java | 3 +- 5 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 5c64499c6f..3fbc16d15f 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -90,12 +90,13 @@ User could reuse the session by using ``sessionId`` query parameters. Sample Request:: - curl --location 'http://localhost:9200/_plugins/_async_query?sessionId=1Giy65ZnzNlmsPAm' \ + curl --location 'http://localhost:9200/_plugins/_async_query' \ --header 'Content-Type: application/json' \ --data '{ "datasource" : "my_glue", "lang" : "sql", - "query" : "select * from my_glue.default.http_logs limit 10" + "query" : "select * from my_glue.default.http_logs limit 10", + "sessionId" : "1Giy65ZnzNlmsPAm" }' Sample Response:: diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 1683ff7483..741501cd18 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -43,8 +43,6 @@ public class RestAsyncQueryManagementAction extends BaseRestHandler { public static final String ASYNC_QUERY_ACTIONS = "async_query_actions"; public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query"; - public static final String PARAMS_SESSION_ID = "sessionId"; - private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class); @Override @@ -114,7 +112,6 @@ private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClie throws IOException { CreateAsyncQueryRequest submitJobRequest = CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); - submitJobRequest.setSessionId(restRequest.param(PARAMS_SESSION_ID, null)); return restChannel -> Scheduler.schedule( nodeClient, diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index ba2a0cf2ee..6acf6bc9a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.rest.model; import static 
org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_ID; import java.io.IOException; import lombok.Data; @@ -27,11 +28,19 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.lang = Validate.notNull(lang, "lang can't be null"); } + public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) { + this.query = Validate.notNull(query, "Query can't be null"); + this.datasource = Validate.notNull(datasource, "Datasource can't be null"); + this.lang = Validate.notNull(lang, "lang can't be null"); + this.sessionId = sessionId; + } + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; String datasource = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -43,10 +52,12 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) lang = LangType.fromString(langString); } else if (fieldName.equals("datasource")) { datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateAsyncQueryRequest(query, datasource, lang); + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java new file mode 100644 index 0000000000..dd634d6055 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import java.io.IOException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +public class CreateAsyncQueryRequestTest { + + @Test + public void fromXContent() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("my_glue", queryRequest.getDatasource()); + Assertions.assertEquals(LangType.SQL, queryRequest.getLang()); + Assertions.assertEquals("select 1", queryRequest.getQuery()); + } + + @Test + public void fromXContentWithSessionId() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\",\n" + + " \"sessionId\": \"00fdjevgkf12s00q\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId()); + } + + private XContentParser xContentParser(String request) throws IOException { + return 
XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, request); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 2e48597593..36060d3850 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -74,7 +74,8 @@ public void testDoExecute() { @Test public void testDoExecuteWithSessionId() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "my_glue", LangType.SQL); + new CreateAsyncQueryRequest( + "source = my_glue.default.alb_logs", "my_glue", LangType.SQL, MOCK_SESSION_ID); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) From 62771112ceade4942bf280cf7d0fe93546706cdc Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 13:35:32 -0700 Subject: [PATCH 13/20] create query_execution_request index on demand Signed-off-by: Peng Huo --- .../org/opensearch/sql/plugin/SQLPlugin.java | 2 +- .../execution/statestore/StateStore.java | 56 +++++++++++++++++++ .../query_execution_request_mapping.yml | 38 +++++++++++++ .../query_execution_request_settings.yml | 11 ++++ .../session/InteractiveSessionTest.java | 14 +++-- .../execution/statement/StatementTest.java | 16 +++--- 6 files changed, 123 insertions(+), 14 deletions(-) create mode 100644 spark/src/main/resources/query_execution_request_mapping.yml create mode 100644 spark/src/main/resources/query_execution_request_settings.yml diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index a9a35f6318..009651025d 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -323,7 +323,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new FlintIndexMetadataReaderImpl(client), client, new SessionManager( - new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService), emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index bd72b17353..8b5251e2a5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -6,14 +6,19 @@ package org.opensearch.sql.spark.execution.statestore; import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.Locale; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import lombok.RequiredArgsConstructor; +import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import 
org.opensearch.action.admin.indices.create.CreateIndexResponse;
 import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.action.index.IndexRequest;
@@ -22,6 +27,9 @@
 import org.opensearch.action.update.UpdateRequest;
 import org.opensearch.action.update.UpdateResponse;
 import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.action.ActionFuture;
+import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.common.xcontent.XContentType;
@@ -35,13 +43,22 @@
 @RequiredArgsConstructor
 public class StateStore {
+  public static Function<String, String> SETTINGS_FILE_NAME =
+      indexName -> String.format(
+          "%s_settings.yml", indexName.substring(indexName.indexOf('.') + 1));
+  public static Function<String, String> MAPPING_FILE_NAME =
+      indexName -> String.format(
+          "%s_mapping.yml", indexName.substring(indexName.indexOf('.') + 1));
+
   private static final Logger LOG = LogManager.getLogger();
 
   private final String indexName;
   private final Client client;
+  private final ClusterService clusterService;
 
   protected <T extends StateModel> T create(T st, StateModel.CopyBuilder<T> builder) {
     try {
+      if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+        createIndex();
+      }
       IndexRequest indexRequest =
           new IndexRequest(indexName)
               .id(st.getId())
@@ -69,6 +86,10 @@
   protected <T extends StateModel> Optional<T> get(String sid, StateModel.FromXContent<T> builder) {
     try {
+      if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+        createIndex();
+        return Optional.empty();
+      }
       GetRequest getRequest = new GetRequest().index(indexName).id(sid);
       GetResponse getResponse = client.get(getRequest).actionGet();
       if (getResponse.isExists()) {
@@ -120,6 +141,41 @@ protected <T extends StateModel> T updateState(
     }
   }
 
+  private void createIndex() {
+    try {
+      CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName);
+      createIndexRequest
+          .mapping(loadConfigFromResource(MAPPING_FILE_NAME), XContentType.YAML)
+          .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML);
+      ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest);
+      }
+      CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet();
+      if (createIndexResponse.isAcknowledged()) {
+        LOG.info("Index: {} creation Acknowledged", indexName);
+      } else {
+        throw new RuntimeException("Index creation is not acknowledged.");
+      }
+    } catch (Throwable e) {
+      throw new RuntimeException(
+          "Internal server error while creating "
+              + indexName
+              + " index:: "
+              + e.getMessage());
+    }
+  }
+
+  private String loadConfigFromResource(Function<String, String> indexToResource)
+      throws IOException {
+    InputStream fileStream =
+        StateStore.class
+            .getClassLoader()
+            .getResourceAsStream(indexToResource.apply(indexName));
+    return IOUtils.toString(fileStream, StandardCharsets.UTF_8);
+  }
+
   /** Helper Functions */
   public static Function<StatementModel, StatementModel> createStatement(StateStore stateStore) {
     return (st) -> stateStore.create(st, StatementModel::copy);
diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml
new file mode 100644
index 0000000000..135910466e
--- /dev/null
+++ b/spark/src/main/resources/query_execution_request_mapping.yml
@@ -0,0 +1,38 @@
+---
+##
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+##
+
+# Schema file for the .query_execution_request index
+# Also "dynamic" is set to "false" so that other fields can be added.
+dynamic: false
+properties:
+  type:
+    type: keyword
+  state:
+    type: keyword
+  statementId:
+    type: keyword
+  applicationId:
+    type: keyword
+  sessionId:
+    type: keyword
+  error:
+    type: text
+  lang:
+    type: keyword
+  query:
+    type: text
+  dataSourceName:
+    type: keyword
+  submitTime:
+    type: date
+    format: strict_date_time||epoch_millis
+  jobId:
+    type: keyword
+  lastUpdateTime:
+    type: date
+    format: strict_date_time||epoch_millis
+  queryId:
+    type: keyword
diff --git a/spark/src/main/resources/query_execution_request_settings.yml b/spark/src/main/resources/query_execution_request_settings.yml
new file mode 100644
index 0000000000..da2bf07bf1
--- /dev/null
+++ 
b/spark/src/main/resources/query_execution_request_mapping.yml @@ -0,0 +1,38 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .ql-job-metadata index +# Also "dynamic" is set to "false" so that other fields can be added. +dynamic: false +properties: + type: + type: keyword + state: + type: keyword + statementId: + type: keyword + applicationId: + type: keyword + sessionId: + type: keyword + error: + type: text + lang: + type: keyword + query: + type: text + dataSourceName: + type: keyword + submitTime: + type: date + format: strict_date_time||epoch_millis + jobId: + type: keyword + lastUpdateTime: + type: date + format: strict_date_time||epoch_millis + queryId: + type: keyword diff --git a/spark/src/main/resources/query_execution_request_settings.yml b/spark/src/main/resources/query_execution_request_settings.yml new file mode 100644 index 0000000000..da2bf07bf1 --- /dev/null +++ b/spark/src/main/resources/query_execution_request_settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .ql-job-metadata index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 429c970365..77151914af 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -22,13 +22,14 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.data.constants.SparkConstants; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. 
*/ -public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { +public class InteractiveSessionTest extends OpenSearchIntegTestCase { - private static final String indexName = "mockindex"; + private static final String indexName = SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; @@ -38,13 +39,14 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + stateStore = new StateStore(indexName, client(), clusterService()); } @After public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 214bcb8258..b58b1ef811 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -22,6 +22,7 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.data.constants.SparkConstants; import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; @@ -30,11 +31,11 @@ import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.rest.model.LangType; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; -public class StatementTest extends OpenSearchSingleNodeTestCase { +public class StatementTest extends OpenSearchIntegTestCase { - private static final String indexName = "mockindex"; + private static final String indexName = SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; private StartJobRequest startJobRequest; private StateStore stateStore; @@ -44,13 +45,14 @@ public class StatementTest extends OpenSearchSingleNodeTestCase { @Before public void setup() { startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + stateStore = new StateStore(indexName, client(), clusterService()); } @After public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } } @Test @@ -125,7 +127,7 @@ public void cancelNotExistStatement() { .build(); st.open(); - client().delete(new DeleteRequest(indexName, stId.getId())); + client().delete(new DeleteRequest(indexName, stId.getId())).actionGet(); IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); assertEquals( From 
5bc133ae5a614322adcd47002e6acb6ce88463da Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 14:43:36 -0700 Subject: [PATCH 14/20] add REPL spark parameters Signed-off-by: Peng Huo --- .../model/SparkSubmitParameters.java | 13 ++++- .../spark/data/constants/SparkConstants.java | 4 ++ .../dispatcher/SparkQueryDispatcher.java | 48 ++++++++++++------- .../session/CreateSessionRequest.java | 21 +++++++- .../execution/session/InteractiveSession.java | 4 ++ .../spark/execution/session/SessionState.java | 7 ++- .../spark/execution/session/SessionType.java | 14 ++---- .../execution/statement/StatementState.java | 7 ++- .../execution/statestore/StateStore.java | 20 ++++---- .../session/InteractiveSessionTest.java | 27 +++++++---- .../execution/statement/StatementTest.java | 18 +++---- 11 files changed, 122 insertions(+), 61 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 0609d8903c..e8a1d8bb06 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -39,7 +39,7 @@ public class SparkSubmitParameters { public static class Builder { - private final String className; + private String className; private final Map config; private String extraParameters; @@ -70,6 +70,11 @@ public static Builder builder() { return new Builder(); } + public Builder className(String className) { + this.className = className; + return this; + } + public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); @@ -141,6 +146,12 @@ public Builder extraParameters(String params) { return this; } + public Builder sessionExecution(String sessionId) { + config.put(FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME); + config.put(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 1b248eb15d..85ce3c4989 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -87,4 +87,8 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; + + public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; + public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 8d5ae10e91..b09e0ef9e2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,6 +7,7 @@ import static 
org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; @@ -230,22 +231,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - false, - dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { Session session; if (dispatchQueryRequest.getSessionId() != null) { @@ -260,7 +246,19 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // create session if not exist session = sessionManager.createSession( - new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + new CreateSessionRequest( + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .className(FLINT_SESSION_CLASS_NAME) + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + tags, + dataSourceMetadata.getResultIndex(), + dataSourceMetadata.getName())); } StatementId statementId = session.submit( @@ -272,6 +270,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); } else { + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 17e3346248..ca2b2b4867 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -5,11 +5,30 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Map; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import 
org.opensearch.sql.spark.client.StartJobRequest; @Data public class CreateSessionRequest { - private final StartJobRequest startJobRequest; + private final String jobName; + private final String applicationId; + private final String executionRoleArn; + private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; + private final Map tags; + private final String resultIndex; private final String datasourceName; + + public StartJobRequest getStartJobRequest() { + return new StartJobRequest( + "select 1", + jobName, + applicationId, + executionRoleArn, + sparkSubmitParametersBuilder.build().toString(), + tags, + false, + resultIndex); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index e33ef4245a..d7be5619fd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -42,6 +42,10 @@ public class InteractiveSession implements Session { @Override public void open(CreateSessionRequest createSessionRequest) { try { + // append session id; + createSessionRequest + .getSparkSubmitParametersBuilder() + .sessionExecution(sessionId.getSessionId()); String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index a4da957f12..bd5d14c603 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -32,8 +33,10 @@ public enum SessionState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static SessionState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (SessionState ss : SessionState.values()) { + if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid session state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java index dd179a1dc5..10b9ce7bd5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -5,9 +5,7 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Arrays; -import java.util.Map; -import java.util.stream.Collectors; +import java.util.Locale; import lombok.Getter; @Getter @@ -20,13 +18,11 @@ public enum SessionType { this.sessionType = sessionType; } - private static Map TYPES = - Arrays.stream(SessionType.values()) - .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); - public static SessionType fromString(String key) { - if (TYPES.containsKey(key)) { - return TYPES.get(key); + for (SessionType sType : SessionType.values()) { + if 
(sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) { + return sType; + } } throw new IllegalArgumentException("Invalid session type: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 33f7f5e831..48978ff8f9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statement; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -30,8 +31,10 @@ public enum StatementState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static StatementState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (StatementState ss : StatementState.values()) { + if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid statement state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index 8b5251e2a5..8df7f1cc49 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -43,10 +43,11 @@ @RequiredArgsConstructor public class StateStore { - public static Function SETTINGS_FILE_NAME = indexName -> String.format( - "%s_settings.yml", indexName.substring(indexName.indexOf('.') + 1)); - public static Function MAPPING_FILE_NAME = indexName -> String.format( - "%s_mapping.yml", indexName.substring(indexName.indexOf('.') + 1)); + public static Function SETTINGS_FILE_NAME = + indexName -> + String.format("%s_settings.yml", indexName.substring(indexName.indexOf('.') + 1)); + public static Function MAPPING_FILE_NAME = + indexName -> String.format("%s_mapping.yml", indexName.substring(indexName.indexOf('.') + 1)); private static final Logger LOG = LogManager.getLogger(); @@ -149,7 +150,7 @@ private void createIndex() { .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML); ActionFuture createIndexResponseActionFuture; try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { + client.threadPool().getThreadContext().stashContext()) { createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); } CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); @@ -160,19 +161,14 @@ private void createIndex() { } } catch (Throwable e) { throw new RuntimeException( - "Internal server error while creating" - + indexName - + " index:: " - + e.getMessage()); + "Internal server error while creating" + indexName + " index:: " + e.getMessage()); } } private String loadConfigFromResource(Function indexToResource) throws IOException { InputStream fileStream = - StateStore.class - .getClassLoader() - .getResourceAsStream(indexToResource.apply(indexName)); + StateStore.class.getClassLoader().getResourceAsStream(indexToResource.apply(indexName)); return IOUtils.toString(fileStream, StandardCharsets.UTF_8); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java 
b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 77151914af..04be8bc49d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -12,6 +12,7 @@ import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -20,6 +21,7 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.data.constants.SparkConstants; @@ -61,7 +63,7 @@ public void openCloseSession() { // open session TestSession testSession = testSession(session, stateStore); testSession - .open(new CreateSessionRequest(startJobRequest, "datasource")) + .open(createSessionRequest()) .assertSessionState(NOT_STARTED) .assertAppId("appId") .assertJobId("jobId"); @@ -81,7 +83,7 @@ public void openSessionFailedConflict() { .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(new CreateSessionRequest(startJobRequest, "datasource")); + session.open(createSessionRequest()); InteractiveSession duplicateSession = InteractiveSession.builder() @@ -91,8 +93,7 @@ public void openSessionFailedConflict() { .build(); IllegalStateException exception = assertThrows( - IllegalStateException.class, - () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); + IllegalStateException.class, () -> duplicateSession.open(createSessionRequest())); assertEquals("session already exist. 
sessionId=duplicate-session-id", exception.getMessage()); } @@ -105,7 +106,7 @@ public void closeNotExistSession() { .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(new CreateSessionRequest(startJobRequest, "datasource")); + session.open(createSessionRequest()); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -118,7 +119,7 @@ public void closeNotExistSession() { public void sessionManagerCreateSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); TestSession testSession = testSession(session, stateStore); testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); @@ -128,8 +129,7 @@ public void sessionManagerCreateSession() { public void sessionManagerGetSession() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Session session = - sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); assertTrue(managerSession.isPresent()); @@ -186,6 +186,17 @@ public TestSession close() { } } + public static CreateSessionRequest createSessionRequest() { + return new CreateSessionRequest( + "jobName", + "appId", + "arn", + SparkSubmitParameters.Builder.builder(), + ImmutableMap.of(), + "resultIndex", + "datasource"); + } + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index b58b1ef811..e0f07d3445 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; @@ -23,7 +24,6 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.data.constants.SparkConstants; -import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -200,7 +200,7 @@ public void cancelRunningStatementFailed() { public void submitStatementInRunningSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); @@ -213,7 +213,7 @@ public void submitStatementInRunningSession() { public void 
submitStatementInNotStartedState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); assertFalse(statementId.getId().isEmpty()); @@ -223,7 +223,7 @@ public void submitStatementInNotStartedState() { public void failToSubmitStatementInDeadState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); @@ -241,7 +241,7 @@ public void failToSubmitStatementInDeadState() { public void failToSubmitStatementInFailState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); @@ -259,7 +259,7 @@ public void failToSubmitStatementInFailState() { public void newStatementFieldAssert() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -277,7 +277,7 @@ public void newStatementFieldAssert() { public void failToSubmitStatementInDeletedSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // other's delete session client() @@ -295,7 +295,7 @@ public void failToSubmitStatementInDeletedSession() { public void getStatementSuccess() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); @@ -310,7 +310,7 @@ public void getStatementSuccess() { public void getStatementNotExist() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); From 3ccd276a9bb7cba54b1bc41dcf9498d37f745d6c Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 23:25:16 -0700 Subject: [PATCH 15/20] Add IT Signed-off-by: Peng Huo --- .../sql/common/setting/Settings.java | 5 - spark/build.gradle | 1 + .../execution/statestore/StateStore.java | 2 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 349 ++++++++++++++++++ 4 files changed, 351 insertions(+), 6 deletions(-) create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java 
b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 89d046b3d9..e2dffe61e1 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -63,9 +63,4 @@ public static Optional of(String keyValue) { public abstract T getSettingValue(Key key); public abstract List getSettings(); - - /** Helper class */ - public static boolean isSparkExecutionSessionEnabled(Settings settings) { - return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); - } } diff --git a/spark/build.gradle b/spark/build.gradle index 15f1e200e0..8f4388495e 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -68,6 +68,7 @@ dependencies { because 'allows tests to run from IDEs that bundle older version of launcher' } testImplementation("org.opensearch.test:framework:${opensearch_version}") + testImplementation project(':opensearch') } test { diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index 8df7f1cc49..498f51fa28 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -91,7 +91,7 @@ protected Optional get(String sid, StateModel.FromXCon createIndex(); return Optional.empty(); } - GetRequest getRequest = new GetRequest().index(indexName).id(sid); + GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true); GetResponse getResponse = client.get(getRequest).actionGet(); if (getResponse.isExists()) { XContentParser parser = diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java new file mode 100644 index 0000000000..1512c28a53 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -0,0 +1,349 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; +import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import 
com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import lombok.Getter; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.test.OpenSearchIntegTestCase; + + +public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { + public static final String DATASOURCE = "mys3"; + + private ClusterService clusterService; + private org.opensearch.sql.common.setting.Settings pluginSettings; + private NodeClient client; + private DataSourceServiceImpl dataSourceService; + private StateStore stateStore; + private ClusterSettings clusterSettings; + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(TestSettingPlugin.class); + } + + public static class TestSettingPlugin extends Plugin { + @Override + public List> getSettings() { + return OpenSearchSettings.pluginSettings(); + } + } + + @Before + public void setup() { + clusterService = clusterService(); + clusterSettings = clusterService.getClusterSettings(); + pluginSettings = new OpenSearchSettings(clusterSettings); + client = (NodeClient) cluster().client(); + dataSourceService = createDataSourceService(); + 
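+    // register a Glue datasource for the spec tests; the role ARN and OpenSearch URI below are test fixtures, not live endpoints.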
dataSourceService.createDataSource(new DataSourceMetadata(DATASOURCE, DataSourceType.S3GLUE, + ImmutableList.of(), ImmutableMap.of("glue.auth.type", "iam_role", + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + + ".com:9200", + "glue.indexstore.opensearch.auth", "noauth"), null)); + stateStore = new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService); + createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME); + } + + @After + public void clean() { + client.admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings(Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()).get(); + } + + @Test + public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // disable session + enableSession(false); + + // 1. create async query. + CreateAsyncQueryResponse response = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, null)); + assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); + emrsClient.startJobRunCalled(1); + + // 2. fetch async query result. + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals("RUNNING", asyncQueryResults.getStatus()); + emrsClient.getJobRunResultCalled(1); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void createAsyncQueryCreateJobWithCorrectParameters() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + enableSession(false); + CreateAsyncQueryResponse response = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, null)); + String params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertNull(response.getSessionId()); + assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); + assertFalse(params.contains(String.format("%s=%s", + FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertFalse(params.contains(String.format("%s=%s", + FLINT_JOB_SESSION_ID, response.getSessionId()))); + + // enable session + enableSession(true); + response = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, null)); + params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); + assertTrue(params.contains(String.format("%s=%s", + FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertTrue(params.contains(String.format("%s=%s", + FLINT_JOB_SESSION_ID, response.getSessionId()))); + } + + @Test + public void withSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. 
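+    // a session-enabled query should return a sessionId and store the statement as WAITING.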
+ CreateAsyncQueryResponse response = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, null)); + assertNotNull(response.getSessionId()); + Optional statementModel = getStatement(stateStore).apply(response.getQueryId()); + assertTrue(statementModel.isPresent()); + assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); + + // 2. fetch async query result. + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + } + + @Test + public void reuseSessionWhenCreateAsyncQuery() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // 2. reuse session id + CreateAsyncQueryResponse second = asyncQueryExecutorService + .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, + LangType.SQL, first.getSessionId())); + + assertEquals(first.getSessionId(), second.getSessionId()); + assertNotEquals(first.getQueryId(), second.getQueryId()); + // one session doc. + assertEquals(1, search(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + // two statement docs has same sessionId. 
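+    // (each createAsyncQuery call stores one statement doc, and both reference the shared sessionId)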
+ assertEquals(2, search(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + + Optional firstModel = getStatement(stateStore).apply(first.getQueryId()); + assertTrue(firstModel.isPresent()); + assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); + assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); + assertEquals(first.getQueryId(), firstModel.get().getQueryId()); + Optional secondModel = getStatement(stateStore).apply(second.getQueryId()); + assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); + assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); + assertEquals(second.getQueryId(), secondModel.get().getQueryId()); + } + + private DataSourceServiceImpl createDataSourceService() { + String masterKey = "1234567890"; + DataSourceMetadataStorage dataSourceMetadataStorage = + new OpenSearchDataSourceMetadataStorage( + client, clusterService, new EncryptorImpl(masterKey)); + return new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new GlueDataSourceFactory(pluginSettings)) + .build(), + dataSourceMetadataStorage, + meta -> { + }); + } + + private AsyncQueryExecutorService createAsyncQueryExecutorService( + EMRServerlessClient emrServerlessClient) { + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, + this.dataSourceService, + new DataSourceUserAuthorizationHelperImpl(client), + jobExecutionResponseReader, + new FlintIndexMetadataReaderImpl(client), + client, + new SessionManager( + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService), + emrServerlessClient, + pluginSettings)); + return new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + this::sparkExecutionEngineConfig); + } + + public static class LocalEMRSClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + private int getJobResult = 0; + + @Getter + private StartJobRequest jobRequest; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + jobRequest = startJobRequest; + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + getJobResult++; + JobRun jobRun = new JobRun(); + jobRun.setState("RUNNING"); + return new GetJobRunResult().withJobRun(jobRun); + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return new CancelJobRunResult().withJobRunId(jobId); + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + + public void getJobRunResultCalled(int expectedTimes) { + assertEquals(expectedTimes, getJobResult); + } + } + + public SparkExecutionEngineConfig sparkExecutionEngineConfig() { + return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); + } + + public void enableSession(boolean enabled) { + client.admin() + .cluster() + 
.prepareUpdateSettings() + .setTransientSettings(Settings.builder().put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled).build()).get(); + } + + int search(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SPARK_REQUEST_BUFFER_INDEX_NAME); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + SearchResponse searchResponse = client.search(searchRequest).actionGet(); + + return searchResponse.getHits().getHits().length; + } +} From 1941c249e500af4687945609a4b8ddbf02b90661 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 18 Oct 2023 13:23:27 -0700 Subject: [PATCH 16/20] format code Signed-off-by: Peng Huo --- ...AsyncQueryExecutorServiceImplSpecTest.java | 117 +++++++++++------- 1 file changed, 70 insertions(+), 47 deletions(-) diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 1512c28a53..06d3607ec0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -68,7 +68,6 @@ import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; - public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { public static final String DATASOURCE = "mys3"; @@ -98,22 +97,34 @@ public void setup() { pluginSettings = new OpenSearchSettings(clusterSettings); client = (NodeClient) cluster().client(); dataSourceService = createDataSourceService(); - dataSourceService.createDataSource(new DataSourceMetadata(DATASOURCE, DataSourceType.S3GLUE, - ImmutableList.of(), ImmutableMap.of("glue.auth.type", "iam_role", - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", - "glue.indexstore.opensearch.uri", "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + - ".com:9200", - "glue.indexstore.opensearch.auth", "noauth"), null)); + dataSourceService.createDataSource( + new DataSourceMetadata( + DATASOURCE, + DataSourceType.S3GLUE, + ImmutableList.of(), + ImmutableMap.of( + "glue.auth.type", + "iam_role", + "glue.auth.role_arn", + "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", + "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200", + "glue.indexstore.opensearch.auth", + "noauth"), + null)); stateStore = new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService); createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME); } @After public void clean() { - client.admin() + client + .admin() .cluster() .prepareUpdateSettings() - .setTransientSettings(Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()).get(); + .setTransientSettings( + Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()) + .get(); } @Test @@ -126,9 +137,9 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { enableSession(false); // 1. create async query. 
- CreateAsyncQueryResponse response = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, null)); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); emrsClient.startJobRunCalled(1); @@ -151,28 +162,30 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { createAsyncQueryExecutorService(emrsClient); enableSession(false); - CreateAsyncQueryResponse response = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, null)); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertNull(response.getSessionId()); assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); - assertFalse(params.contains(String.format("%s=%s", - FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); - assertFalse(params.contains(String.format("%s=%s", - FLINT_JOB_SESSION_ID, response.getSessionId()))); + assertFalse( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertFalse( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); // enable session enableSession(true); - response = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, null)); + response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); params = emrsClient.getJobRequest().getSparkSubmitParams(); assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); - assertTrue(params.contains(String.format("%s=%s", - FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); - assertTrue(params.contains(String.format("%s=%s", - FLINT_JOB_SESSION_ID, response.getSessionId()))); + assertTrue( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertTrue( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); } @Test @@ -185,9 +198,9 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { enableSession(true); // 1. create async query. - CreateAsyncQueryResponse response = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, null)); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); assertNotNull(response.getSessionId()); Optional statementModel = getStatement(stateStore).apply(response.getQueryId()); assertTrue(statementModel.isPresent()); @@ -213,26 +226,33 @@ public void reuseSessionWhenCreateAsyncQuery() { enableSession(true); // 1. create async query. - CreateAsyncQueryResponse first = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, null)); + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); assertNotNull(first.getSessionId()); // 2. 
reuse session id - CreateAsyncQueryResponse second = asyncQueryExecutorService - .createAsyncQuery(new CreateAsyncQueryRequest("select 1", DATASOURCE, - LangType.SQL, first.getSessionId())); + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); assertEquals(first.getSessionId(), second.getSessionId()); assertNotEquals(first.getQueryId(), second.getQueryId()); // one session doc. - assertEquals(1, search(QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) - .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + assertEquals( + 1, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); // two statement docs has same sessionId. - assertEquals(2, search(QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) - .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + assertEquals( + 2, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); Optional firstModel = getStatement(stateStore).apply(first.getQueryId()); assertTrue(firstModel.isPresent()); @@ -255,8 +275,7 @@ private DataSourceServiceImpl createDataSourceService() { .add(new GlueDataSourceFactory(pluginSettings)) .build(), dataSourceMetadataStorage, - meta -> { - }); + meta -> {}); } private AsyncQueryExecutorService createAsyncQueryExecutorService( @@ -288,8 +307,7 @@ public static class LocalEMRSClient implements EMRServerlessClient { private int cancelJobRunCalled = 0; private int getJobResult = 0; - @Getter - private StartJobRequest jobRequest; + @Getter private StartJobRequest jobRequest; @Override public String startJobRun(StartJobRequest startJobRequest) { @@ -330,10 +348,15 @@ public SparkExecutionEngineConfig sparkExecutionEngineConfig() { } public void enableSession(boolean enabled) { - client.admin() + client + .admin() .cluster() .prepareUpdateSettings() - .setTransientSettings(Settings.builder().put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled).build()).get(); + .setTransientSettings( + Settings.builder() + .put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled) + .build()) + .get(); } int search(QueryBuilder query) { From 11da8ddd48ed6afc9ae4886bf850b3beddf5d177 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 18 Oct 2023 16:03:21 -0700 Subject: [PATCH 17/20] bind request index to datasource Signed-off-by: Peng Huo --- .../org/opensearch/sql/plugin/SQLPlugin.java | 5 +- .../execution/session/InteractiveSession.java | 11 +- .../spark/execution/session/SessionId.java | 21 ++- .../execution/session/SessionManager.java | 5 +- .../spark/execution/statement/Statement.java | 20 +- .../execution/statement/StatementModel.java | 10 + .../execution/statestore/StateStore.java | 175 +++++++++++------- ...AsyncQueryExecutorServiceImplSpecTest.java | 18 +- .../session/InteractiveSessionTest.java | 22 ++- .../execution/statement/StatementTest.java | 35 ++-- 10 files changed, 209 insertions(+), 113 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 009651025d..027a4dd502 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ 
b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -7,7 +7,6 @@
 import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
 import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
 
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
 import com.amazonaws.services.emrserverless.AWSEMRServerless;
@@ -323,9 +322,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             new FlintIndexMetadataReaderImpl(client),
             client,
             new SessionManager(
-                new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService),
-                emrServerlessClient,
-                pluginSettings));
+                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
index d7be5619fd..af22db5838 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -52,7 +52,7 @@ public void open(CreateSessionRequest createSessionRequest) {
       sessionModel =
           initInteractiveSession(
               applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
-      createSession(stateStore).apply(sessionModel);
+      createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel);
     } catch (VersionConflictEngineException e) {
       String errorMsg = "session already exist. " + sessionId;
       LOG.error(errorMsg);
@@ -63,7 +63,8 @@ public void open(CreateSessionRequest createSessionRequest) {
   /** todo. StatementSweeper will delete doc. */
   @Override
   public void close() {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -73,7 +74,8 @@ public void close() {
 
   /** Submit statement. If submit successfully, Statement in waiting state. */
   public StatementId submit(QueryRequest request) {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -88,6 +90,7 @@ public StatementId submit(QueryRequest request) {
               .stateStore(stateStore)
               .statementId(statementId)
               .langType(LangType.SQL)
+              .datasourceName(sessionModel.getDatasourceName())
               .query(request.getQuery())
               .queryId(statementId.getId())
               .build();
@@ -107,7 +110,7 @@ public StatementId submit(QueryRequest request) {
 
   @Override
   public Optional<Statement> get(StatementId stID) {
-    return StateStore.getStatement(stateStore)
+    return StateStore.getStatement(stateStore, sessionModel.getDatasourceName())
         .apply(stID.getId())
         .map(
             model ->
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
index 861d906b9b..b3bd716925 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
@@ -5,15 +5,32 @@
 
 package org.opensearch.sql.spark.execution.session;
 
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
 import lombok.Data;
 import org.apache.commons.lang3.RandomStringUtils;
 
 @Data
 public class SessionId {
+  public static final int PREFIX_LEN = 10;
+
   private final String sessionId;
 
-  public static SessionId newSessionId() {
-    return new SessionId(RandomStringUtils.randomAlphanumeric(16));
+  public static SessionId newSessionId(String datasourceName) {
+    return new SessionId(encode(datasourceName));
+  }
+
+  public String getDataSourceName() {
+    return decode(sessionId);
+  }
+
+  private static String decode(String sessionId) {
+    return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN);
+  }
+
+  private static String encode(String datasourceName) {
+    String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName;
+    return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
   }
 
   @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
index c34be7015f..c0f7bbcde8 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
@@ -28,7 +28,7 @@ public class SessionManager {
   public Session createSession(CreateSessionRequest request) {
     InteractiveSession session =
         InteractiveSession.builder()
-            .sessionId(newSessionId())
+            .sessionId(newSessionId(request.getDatasourceName()))
            .stateStore(stateStore)
             .serverlessClient(emrServerlessClient)
             .build();
@@ -37,7 +37,8 @@ public Session createSession(CreateSessionRequest request) {
   }
 
   public Optional<Session> getSession(SessionId sid) {
-    Optional<SessionModel> model = StateStore.getSession(stateStore).apply(sid.getSessionId());
+    Optional<SessionModel> model =
+        StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId());
     if (model.isPresent()) {
       InteractiveSession session =
           InteractiveSession.builder()
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
index 8fcedb5fca..d84c91bdb8 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
@@ -32,6 +32,7 @@ public class Statement {
   private final String jobId;
   private final StatementId statementId;
   private final LangType langType;
+  private final String datasourceName;
   private final String query;
   private final String queryId;
   private final StateStore stateStore;
@@ -42,8 +43,16 @@ public class Statement {
   public void open() {
     try {
       statementModel =
-          submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId);
-      statementModel = createStatement(stateStore).apply(statementModel);
+          submitStatement(
+              sessionId,
+              applicationId,
+              jobId,
+              statementId,
+              langType,
+              datasourceName,
+              query,
+              queryId);
+      statementModel = createStatement(stateStore, datasourceName).apply(statementModel);
     } catch (VersionConflictEngineException e) {
       String errorMsg = "statement already exist. " + statementId;
       LOG.error(errorMsg);
@@ -61,7 +70,8 @@ public void cancel() {
     }
     try {
       this.statementModel =
-          updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED);
+          updateStatementState(stateStore, statementModel.getDatasourceName())
+              .apply(this.statementModel, StatementState.CANCELLED);
     } catch (DocumentMissingException e) {
       String errorMsg =
           String.format("cancel statement failed. no statement found. statement: %s.", statementId);
@@ -69,7 +79,9 @@ public void cancel() {
       throw new IllegalStateException(errorMsg);
     } catch (VersionConflictEngineException e) {
       this.statementModel =
-          getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel);
+          getStatement(stateStore, statementModel.getDatasourceName())
+              .apply(statementModel.getId())
+              .orElse(this.statementModel);
       String errorMsg =
           String.format(
               "cancel statement failed. current statementState: %s " + "statement: %s.",
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
index c7f681c541..2a1043bf73 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
@@ -6,6 +6,7 @@
 package org.opensearch.sql.spark.execution.statement;
 
 import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID;
+import static org.opensearch.sql.spark.execution.session.SessionModel.DATASOURCE_NAME;
 import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID;
 import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
 
@@ -45,6 +46,7 @@ public class StatementModel extends StateModel {
   private final String applicationId;
   private final String jobId;
   private final LangType langType;
+  private final String datasourceName;
   private final String query;
   private final String queryId;
   private final long submitTime;
@@ -65,6 +67,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         .field(APPLICATION_ID, applicationId)
         .field(JOB_ID, jobId)
         .field(LANG, langType.getText())
+        .field(DATASOURCE_NAME, datasourceName)
         .field(QUERY, query)
         .field(QUERY_ID, queryId)
         .field(SUBMIT_TIME, submitTime)
@@ -82,6 +85,7 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT
         .applicationId(copy.applicationId)
         .jobId(copy.jobId)
         .langType(copy.langType)
+        .datasourceName(copy.datasourceName)
         .query(copy.query)
         .queryId(copy.queryId)
         .submitTime(copy.submitTime)
@@ -101,6 +105,7 @@ public static StatementModel copyWithState(
         .applicationId(copy.applicationId)
         .jobId(copy.jobId)
         .langType(copy.langType)
+        .datasourceName(copy.datasourceName)
         .query(copy.query)
         .queryId(copy.queryId)
         .submitTime(copy.submitTime)
@@ -143,6 +148,9 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon
         case LANG:
           builder.langType(LangType.fromString(parser.text()));
           break;
+        case DATASOURCE_NAME:
+          builder.datasourceName(parser.text());
+          break;
         case QUERY:
           builder.query(parser.text());
           break;
@@ -168,6 +176,7 @@ public static StatementModel submitStatement(
       String jobId,
       StatementId statementId,
       LangType langType,
+      String datasourceName,
       String query,
       String queryId) {
     return builder()
@@ -178,6 +187,7 @@ public static StatementModel submitStatement(
         .applicationId(applicationId)
         .jobId(jobId)
         .langType(langType)
+        .datasourceName(datasourceName)
         .query(query)
         .queryId(queryId)
         .submitTime(System.currentTimeMillis())
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
index 498f51fa28..a36ee3ef45 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.spark.execution.statestore;
 
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.charset.StandardCharsets;
@@ -41,24 +43,28 @@
 import org.opensearch.sql.spark.execution.statement.StatementModel;
 import org.opensearch.sql.spark.execution.statement.StatementState;
 
+/**
+ * State Store maintains the state of Session and Statement. State Store creates/updates/gets the
+ * state doc on the request index regardless of the user's FGAC permissions.
+ */
 @RequiredArgsConstructor
 public class StateStore {
-  public static Function<String, String> SETTINGS_FILE_NAME =
-      indexName ->
-          String.format("%s_settings.yml", indexName.substring(indexName.indexOf('.') + 1));
-  public static Function<String, String> MAPPING_FILE_NAME =
-      indexName -> String.format("%s_mapping.yml", indexName.substring(indexName.indexOf('.') + 1));
+  public static String SETTINGS_FILE_NAME = "query_execution_request_settings.yml";
+  public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml";
+  public static Function<String, String> DATASOURCE_TO_REQUEST_INDEX =
+      datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName);
+  public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME);
 
   private static final Logger LOG = LogManager.getLogger();
 
-  private final String indexName;
   private final Client client;
   private final ClusterService clusterService;
 
-  protected <T extends StateModel> T create(T st, StateModel.CopyBuilder<T> builder) {
+  protected <T extends StateModel> T create(
+      T st, StateModel.CopyBuilder<T> builder, String indexName) {
     try {
       if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
-        createIndex();
+        createIndex(indexName);
       }
       IndexRequest indexRequest =
           new IndexRequest(indexName)
@@ -68,44 +74,52 @@ protected <T extends StateModel> T create(T st, StateModel.CopyBuilder<T> builde
               .setIfPrimaryTerm(st.getPrimaryTerm())
               .create(true)
               .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
-      IndexResponse indexResponse = client.index(indexRequest).actionGet();
-      if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
-        LOG.debug("Successfully created doc. id: {}", st.getId());
-        return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
-      } else {
-        throw new RuntimeException(
-            String.format(
-                Locale.ROOT,
-                "Failed create doc. id: %s, error: %s",
-                st.getId(),
-                indexResponse.getResult().getLowercase()));
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        IndexResponse indexResponse = client.index(indexRequest).actionGet();
+        if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
+          LOG.debug("Successfully created doc. id: {}", st.getId());
+          return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
+        } else {
+          throw new RuntimeException(
+              String.format(
+                  Locale.ROOT,
+                  "Failed create doc. id: %s, error: %s",
+                  st.getId(),
+                  indexResponse.getResult().getLowercase()));
+        }
       }
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
   }
 
-  protected <T extends StateModel> Optional<T> get(String sid, StateModel.FromXContent<T> builder) {
+  protected <T extends StateModel> Optional<T> get(
+      String sid, StateModel.FromXContent<T> builder, String indexName) {
     try {
       if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
-        createIndex();
+        createIndex(indexName);
         return Optional.empty();
       }
       GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true);
-      GetResponse getResponse = client.get(getRequest).actionGet();
-      if (getResponse.isExists()) {
-        XContentParser parser =
-            XContentType.JSON
-                .xContent()
-                .createParser(
-                    NamedXContentRegistry.EMPTY,
-                    LoggingDeprecationHandler.INSTANCE,
-                    getResponse.getSourceAsString());
-        parser.nextToken();
-        return Optional.of(
-            builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm()));
-      } else {
-        return Optional.empty();
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        GetResponse getResponse = client.get(getRequest).actionGet();
+        if (getResponse.isExists()) {
+          XContentParser parser =
+              XContentType.JSON
+                  .xContent()
+                  .createParser(
+                      NamedXContentRegistry.EMPTY,
+                      LoggingDeprecationHandler.INSTANCE,
+                      getResponse.getSourceAsString());
+          parser.nextToken();
+          return Optional.of(
+              builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm()));
+        } else {
+          return Optional.empty();
+        }
       }
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -113,7 +127,7 @@ protected <T extends StateModel> Optional<T> get(String sid, StateModel.FromXCon
   }
 
   protected <T extends StateModel, S> T updateState(
-      T st, S state, StateModel.StateCopyBuilder<T, S> builder) {
+      T st, S state, StateModel.StateCopyBuilder<T, S> builder, String indexName) {
     try {
       T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm());
       UpdateRequest updateRequest =
@@ -125,24 +139,28 @@ protected <T extends StateModel, S> T updateState(
           .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
           .fetchSource(true)
           .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
-      UpdateResponse updateResponse = client.update(updateRequest).actionGet();
-      if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) {
-        LOG.debug("Successfully update doc. id: {}", st.getId());
-        return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm());
-      } else {
-        throw new RuntimeException(
-            String.format(
-                Locale.ROOT,
-                "Failed update doc. id: %s, error: %s",
-                st.getId(),
-                updateResponse.getResult().getLowercase()));
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        UpdateResponse updateResponse = client.update(updateRequest).actionGet();
+        if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) {
+          LOG.debug("Successfully update doc. id: {}", st.getId());
+          return builder.of(
+              model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm());
+        } else {
+          throw new RuntimeException(
+              String.format(
+                  Locale.ROOT,
+                  "Failed update doc. id: %s, error: %s",
+                  st.getId(),
+                  updateResponse.getResult().getLowercase()));
+        }
       }
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
   }
 
-  private void createIndex() {
+  private void createIndex(String indexName) {
     try {
       CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName);
       createIndexRequest
@@ -165,37 +183,66 @@ private void createIndex() {
     }
   }
 
-  private String loadConfigFromResource(Function<String, String> indexToResource)
-      throws IOException {
-    InputStream fileStream =
-        StateStore.class.getClassLoader().getResourceAsStream(indexToResource.apply(indexName));
+  private String loadConfigFromResource(String fileName) throws IOException {
+    InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName);
     return IOUtils.toString(fileStream, StandardCharsets.UTF_8);
   }
 
   /** Helper Functions */
-  public static Function<StatementModel, StatementModel> createStatement(StateStore stateStore) {
-    return (st) -> stateStore.create(st, StatementModel::copy);
+  public static Function<StatementModel, StatementModel> createStatement(
+      StateStore stateStore, String datasourceName) {
+    return (st) ->
+        stateStore.create(
+            st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
 
-  public static Function<String, Optional<StatementModel>> getStatement(StateStore stateStore) {
-    return (docId) -> stateStore.get(docId, StatementModel::fromXContent);
+  public static Function<String, Optional<StatementModel>> getStatement(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
 
   public static BiFunction<StatementModel, StatementState, StatementModel> updateStatementState(
-      StateStore stateStore) {
-    return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState);
+      StateStore stateStore, String datasourceName) {
+    return (old, state) ->
+        stateStore.updateState(
+            old,
+            state,
+            StatementModel::copyWithState,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
 
-  public static Function<SessionModel, SessionModel> createSession(StateStore stateStore) {
-    return (session) -> stateStore.create(session, SessionModel::of);
+  public static Function<SessionModel, SessionModel> createSession(
+      StateStore stateStore, String datasourceName) {
+    return (session) ->
+        stateStore.create(
+            session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
 
-  public static Function<String, Optional<SessionModel>> getSession(StateStore stateStore) {
-    return (docId) -> stateStore.get(docId, SessionModel::fromXContent);
+  public static Function<String, Optional<SessionModel>> getSession(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<String, Optional<SessionModel>> searchSession(StateStore stateStore) {
+    return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX);
   }
 
   public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState(
-      StateStore stateStore) {
-    return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState);
+      StateStore stateStore, String datasourceName) {
+    return (old, state) ->
+        stateStore.updateState(
+            old,
+            state,
+            SessionModel::copyWithState,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) {
+    String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName);
+    return () -> stateStore.createIndex(indexName);
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
index 06d3607ec0..3eb8958eb2 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -15,6 +15,7 @@
 import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE;
 import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID;
 import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
 
 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
@@ -112,7 +113,7 @@ public void setup() {
                 "glue.indexstore.opensearch.auth",
                 "noauth"),
             null));
-    stateStore = new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService);
+    stateStore = new StateStore(client, clusterService);
     createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME);
   }
 
@@ -202,7 +203,8 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
         asyncQueryExecutorService.createAsyncQuery(
             new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
     assertNotNull(response.getSessionId());
-    Optional<StatementModel> statementModel = getStatement(stateStore).apply(response.getQueryId());
+    Optional<StatementModel> statementModel =
+        getStatement(stateStore, DATASOURCE).apply(response.getQueryId());
     assertTrue(statementModel.isPresent());
     assertEquals(StatementState.WAITING, statementModel.get().getStatementState());
 
@@ -254,12 +256,14 @@ public void reuseSessionWhenCreateAsyncQuery() {
                 .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE))
                 .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId()))));
 
-    Optional<StatementModel> firstModel = getStatement(stateStore).apply(first.getQueryId());
+    Optional<StatementModel> firstModel =
+        getStatement(stateStore, DATASOURCE).apply(first.getQueryId());
     assertTrue(firstModel.isPresent());
     assertEquals(StatementState.WAITING, firstModel.get().getStatementState());
     assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId());
     assertEquals(first.getQueryId(), firstModel.get().getQueryId());
-    Optional<StatementModel> secondModel = getStatement(stateStore).apply(second.getQueryId());
+    Optional<StatementModel> secondModel =
+        getStatement(stateStore, DATASOURCE).apply(second.getQueryId());
     assertEquals(StatementState.WAITING, secondModel.get().getStatementState());
     assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId());
     assertEquals(second.getQueryId(), secondModel.get().getQueryId());
@@ -292,9 +296,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             new FlintIndexMetadataReaderImpl(client),
             client,
             new SessionManager(
-                new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client, clusterService),
-                emrServerlessClient,
-                pluginSettings));
+                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
@@ -361,7 +363,7 @@ public void enableSession(boolean enabled) {
 
   int search(QueryBuilder query) {
     SearchRequest searchRequest = new SearchRequest();
-    searchRequest.indices(SPARK_REQUEST_BUFFER_INDEX_NAME);
+    searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE));
     SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
     searchSourceBuilder.query(query);
     searchRequest.source(searchSourceBuilder);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
index 04be8bc49d..06a8d8c73c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
@@ -8,6 +8,7 @@
 import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession;
 import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting;
 import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
 
 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
@@ -24,14 +25,14 @@
 import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
-import org.opensearch.sql.spark.data.constants.SparkConstants;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.test.OpenSearchIntegTestCase;
 
 /** mock-maker-inline does not work with OpenSearchTestCase. */
 public class InteractiveSessionTest extends OpenSearchIntegTestCase {
 
-  private static final String indexName = SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
+  private static final String DS_NAME = "mys3";
+  private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME);
 
   private TestEMRServerlessClient emrsClient;
   private StartJobRequest startJobRequest;
@@ -41,7 +42,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase {
   public void setup() {
     emrsClient = new TestEMRServerlessClient();
     startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
-    stateStore = new StateStore(indexName, client(), clusterService());
+    stateStore = new StateStore(client(), clusterService());
   }
 
   @After
@@ -55,7 +56,7 @@ public void clean() {
   public void openCloseSession() {
     InteractiveSession session =
         InteractiveSession.builder()
-            .sessionId(SessionId.newSessionId())
+            .sessionId(SessionId.newSessionId(DS_NAME))
             .stateStore(stateStore)
             .serverlessClient(emrsClient)
             .build();
@@ -76,7 +77,7 @@ public void openCloseSession() {
 
   @Test
   public void openSessionFailedConflict() {
-    SessionId sessionId = new SessionId("duplicate-session-id");
+    SessionId sessionId = SessionId.newSessionId(DS_NAME);
     InteractiveSession session =
         InteractiveSession.builder()
             .sessionId(sessionId)
@@ -94,12 +95,12 @@ public void openSessionFailedConflict() {
     IllegalStateException exception =
         assertThrows(
             IllegalStateException.class, () -> duplicateSession.open(createSessionRequest()));
-    assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage());
+    assertEquals("session already exist. " + sessionId, exception.getMessage());
   }
 
   @Test
   public void closeNotExistSession() {
-    SessionId sessionId = SessionId.newSessionId();
+    SessionId sessionId = SessionId.newSessionId(DS_NAME);
     InteractiveSession session =
         InteractiveSession.builder()
             .sessionId(sessionId)
@@ -141,7 +142,8 @@ public void sessionManagerGetSessionNotExist() {
     SessionManager sessionManager =
         new SessionManager(stateStore, emrsClient, sessionSetting(false));
 
-    Optional<Session> managerSession = sessionManager.getSession(new SessionId("no-exist"));
+    Optional<Session> managerSession =
+        sessionManager.getSession(SessionId.newSessionId("no-exist"));
     assertTrue(managerSession.isEmpty());
   }
 
@@ -158,7 +160,7 @@ public TestSession assertSessionState(SessionState expected) {
       assertEquals(expected, session.getSessionModel().getSessionState());
 
       Optional<SessionModel> sessionStoreState =
-          getSession(stateStore).apply(session.getSessionModel().getId());
+          getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId());
       assertTrue(sessionStoreState.isPresent());
       assertEquals(expected, sessionStoreState.get().getSessionState());
 
@@ -194,7 +196,7 @@ public static CreateSessionRequest createSessionRequest() {
         SparkSubmitParameters.Builder.builder(),
         ImmutableMap.of(),
         "resultIndex",
-        "datasource");
+        DS_NAME);
   }
 
   public static class TestEMRServerlessClient implements EMRServerlessClient {
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index e0f07d3445..ff3ddd1bef 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -10,11 +10,11 @@
 import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED;
 import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
 import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;
 
-import java.util.HashMap;
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
 import org.junit.After;
@@ -22,8 +22,6 @@
 import org.junit.Test;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.delete.DeleteRequest;
-import org.opensearch.sql.spark.client.StartJobRequest;
-import org.opensearch.sql.spark.data.constants.SparkConstants;
 import org.opensearch.sql.spark.execution.session.InteractiveSessionTest;
 import org.opensearch.sql.spark.execution.session.Session;
 import org.opensearch.sql.spark.execution.session.SessionId;
@@ -35,17 +33,16 @@
 
 public class StatementTest extends OpenSearchIntegTestCase {
 
-  private static final String indexName = SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
+  private static final String DS_NAME = "mys3";
+  private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME);
 
-  private StartJobRequest startJobRequest;
   private StateStore stateStore;
   private InteractiveSessionTest.TestEMRServerlessClient emrsClient =
       new InteractiveSessionTest.TestEMRServerlessClient();
 
   @Before
   public void setup() {
-    startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
-    stateStore = new StateStore(indexName, client(), clusterService());
+    stateStore = new StateStore(client(), clusterService());
   }
 
   @After
@@ -64,6 +61,7 @@ public void openThenCancelStatement() {
             .jobId("jobId")
             .statementId(new StatementId("statementId"))
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
@@ -89,6 +87,7 @@ public void openFailedBecauseConflict() {
             .jobId("jobId")
             .statementId(new StatementId("statementId"))
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
@@ -103,6 +102,7 @@ public void openFailedBecauseConflict() {
             .jobId("jobId")
             .statementId(new StatementId("statementId"))
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
@@ -121,6 +121,7 @@ public void cancelNotExistStatement() {
             .jobId("jobId")
             .statementId(stId)
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
@@ -145,6 +146,7 @@ public void cancelFailedBecauseOfConflict() {
             .jobId("jobId")
             .statementId(stId)
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
             .build();
     st.open();
 
     StatementModel running =
-        updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED);
+        updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED);
 
     assertEquals(StatementState.CANCELLED, running.getStatementState());
 
@@ -174,6 +176,7 @@ public void cancelRunningStatementFailed() {
             .jobId("jobId")
             .statementId(stId)
             .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
             .query("query")
             .queryId("statementId")
             .stateStore(stateStore)
@@ -203,7 +206,7 @@ public void submitStatementInRunningSession() {
             .createSession(createSessionRequest());
 
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     assertFalse(statementId.getId().isEmpty());
@@ -225,7 +228,7 @@ public void failToSubmitStatementInDeadState() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
 
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD);
 
     IllegalStateException exception =
         assertThrows(
@@ -243,7 +246,7 @@ public void failToSubmitStatementInFailState() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
 
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL);
 
     IllegalStateException exception =
         assertThrows(
@@ -297,7 +300,7 @@ public void getStatementSuccess() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
 
     Optional<Statement> statement = session.get(statementId);
@@ -312,7 +315,7 @@ public void getStatementNotExist() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
     Optional<Statement> statement = session.get(StatementId.newStatementId());
     assertFalse(statement.isPresent());
@@ -330,7 +333,8 @@ public static TestStatement testStatement(Statement st, StateStore stateStore) {
     public TestStatement assertSessionState(StatementState expected) {
       assertEquals(expected, st.getStatementModel().getStatementState());
 
-      Optional<StatementModel> model = getStatement(stateStore).apply(st.getStatementId().getId());
+      Optional<StatementModel> model =
+          getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId());
       assertTrue(model.isPresent());
       assertEquals(expected, model.get().getStatementState());
 
@@ -340,7 +344,8 @@ public TestStatement assertSessionState(StatementState expected) {
     public TestStatement assertStatementId(StatementId expected) {
       assertEquals(expected, st.getStatementModel().getStatementId());
 
-      Optional<StatementModel> model = getStatement(stateStore).apply(st.getStatementId().getId());
+      Optional<StatementModel> model =
+          getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId());
       assertTrue(model.isPresent());
       assertEquals(expected, model.get().getStatementId());
       return this;
From e849604427caf4457fc80b4569ea82495a9c442f Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Wed, 18 Oct 2023 22:44:58 -0700
Subject: [PATCH 18/20] fix bug when fetching query result

Signed-off-by: Peng Huo
---
 .../model/SparkSubmitParameters.java          |  5 +++--
 .../spark/client/EmrServerlessClientImpl.java |  4 ++--
 .../dispatcher/SparkQueryDispatcher.java      | 19 +++++++++++++------
 .../execution/session/InteractiveSession.java |  2 +-
 .../response/JobExecutionResponseReader.java  |  4 ++++
 .../dispatcher/SparkQueryDispatcherTest.java  |  6 +++---
 6 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java
index e8a1d8bb06..db78abb2a8 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java
@@ -12,6 +12,7 @@
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI;
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.*;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 
 import java.net.URI;
 import java.net.URISyntaxException;
@@ -146,8 +147,8 @@ public Builder extraParameters(String params) {
       return this;
     }
 
-    public Builder sessionExecution(String sessionId) {
-      config.put(FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME);
+    public Builder sessionExecution(String sessionId, String datasourceName) {
+      config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
       config.put(FLINT_JOB_SESSION_ID, sessionId);
       return this;
     }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
index 335f3b6fc8..2f5b00e5bd 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
@@ -6,7 +6,6 @@
 package org.opensearch.sql.spark.client;
 
 import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR;
 
 import com.amazonaws.services.emrserverless.AWSEMRServerless;
 import com.amazonaws.services.emrserverless.model.CancelJobRunRequest;
@@ -49,7 +48,8 @@ public String startJobRun(StartJobRequest startJobRequest) {
             new JobDriver()
                 .withSparkSubmit(
                     new SparkSubmit()
-                        .withEntryPoint(SPARK_SQL_APPLICATION_JAR)
+                        .withEntryPoint(
+                            "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job-assembly-0.1.0-SNAPSHOT.jar")
                         .withEntryPointArguments(startJobRequest.getQuery(), resultIndex)
                         .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams())));
     StartJobRunResult startJobRunResult =
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index b09e0ef9e2..2bd1ae67b9 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -97,12 +97,19 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata)
       return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result();
     }
 
-    // either empty json when the result is not available or data with status
-    // Fetch from Result Index
-    JSONObject result =
-        jobExecutionResponseReader.getResultFromOpensearchIndex(
-            asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
-
+    JSONObject result;
+    if (asyncQueryJobMetadata.getSessionId() == null) {
+      // either empty json when the result is not available or data with status
+      // Fetch from Result Index
+      result =
+          jobExecutionResponseReader.getResultFromOpensearchIndex(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    } else {
+      // when session is enabled, the jobId in asyncQueryJobMetadata is actually the queryId.
+      result =
+          jobExecutionResponseReader.getResultWithQueryId(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    }
     // if result index document has a status, we are gonna use the status directly; otherwise, we
     // will use emr-s job status.
     // That a job is successful does not mean there is no error in execution. For example, even if
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
index af22db5838..4428c3b83d 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -45,7 +45,7 @@ public void open(CreateSessionRequest createSessionRequest) {
       // append session id;
       createSessionRequest
           .getSparkSubmitParametersBuilder()
-          .sessionExecution(sessionId.getSessionId());
+          .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
 
       String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
       String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java
index d3cbd68dce..2614992463 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java
@@ -39,6 +39,10 @@ public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex)
     return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex);
   }
 
+  public JSONObject getResultWithQueryId(String queryId, String resultIndex) {
+    return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex);
+  }
+
   private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) {
     SearchRequest searchRequest = new SearchRequest();
     String searchResultIndex = resultIndex == null ? SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex;
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 58fe626dae..15211dec01 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -726,7 +726,7 @@ void testGetQueryResponseWithSession() {
 
     doReturn(new JSONObject())
         .when(jobExecutionResponseReader)
-        .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any());
+        .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
     JSONObject result =
         sparkQueryDispatcher.getQueryResponse(
             asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID));
@@ -740,7 +740,7 @@ void testGetQueryResponseWithInvalidSession() {
     doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID)));
     doReturn(new JSONObject())
         .when(jobExecutionResponseReader)
-        .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any());
+        .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
     IllegalArgumentException exception =
         Assertions.assertThrows(
             IllegalArgumentException.class,
@@ -759,7 +759,7 @@ void testGetQueryResponseWithStatementNotExist() {
     doReturn(Optional.empty()).when(session).get(any());
     doReturn(new JSONObject())
         .when(jobExecutionResponseReader)
-        .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any());
+        .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
 
     IllegalArgumentException exception =
         Assertions.assertThrows(
From 2a6041122db917c8b92424f12da0c1ef69adda9e Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Thu, 19 Oct 2023 07:42:19 -0700
Subject: [PATCH 19/20] revert entrypoint class

Signed-off-by: Peng Huo
---
 .../opensearch/sql/spark/client/EmrServerlessClientImpl.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
index 2f5b00e5bd..335f3b6fc8 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
@@ -6,6 +6,7 @@
 package org.opensearch.sql.spark.client;
 
 import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR;
 
 import com.amazonaws.services.emrserverless.AWSEMRServerless;
 import com.amazonaws.services.emrserverless.model.CancelJobRunRequest;
@@ -48,8 +49,7 @@ public String startJobRun(StartJobRequest startJobRequest) {
             new JobDriver()
                 .withSparkSubmit(
                     new SparkSubmit()
-                        .withEntryPoint(
-                            "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job-assembly-0.1.0-SNAPSHOT.jar")
+                        .withEntryPoint(SPARK_SQL_APPLICATION_JAR)
                         .withEntryPointArguments(startJobRequest.getQuery(), resultIndex)
                         .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams())));
     StartJobRunResult startJobRunResult =
From edff2706288ddaad8485c8ebc1452f15b2122b74 Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Fri, 20 Oct 2023 00:01:20 -0700
Subject: [PATCH 20/20] update mapping

Signed-off-by: Peng Huo
---
 spark/src/main/resources/query_execution_request_mapping.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml
index 135910466e..87bd927e6e 100644
--- a/spark/src/main/resources/query_execution_request_mapping.yml
+++ b/spark/src/main/resources/query_execution_request_mapping.yml
@@ -18,6 +18,8 @@ properties:
     type: keyword
   sessionId:
     type: keyword
+  sessionType:
+    type: keyword
   error:
     type: text
   lang:
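Taken together, the series replaces the single shared request buffer index with one request index per datasource, each created on demand from the shared query_execution_request_settings.yml and query_execution_request_mapping.yml resources, so the sessionType keyword added above lands in every such index. A last runnable sketch of the naming convention; the prefix value here is an assumption, since the real one comes from SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME.

import java.util.function.Function;

public class RequestIndexNamingSketch {
  // Illustrative prefix; the real value is SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME.
  static final String REQUEST_INDEX_PREFIX = ".query_execution_request";

  // Same shape as StateStore.DATASOURCE_TO_REQUEST_INDEX and ALL_REQUEST_INDEX in the series.
  static final Function<String, String> DATASOURCE_TO_REQUEST_INDEX =
      ds -> String.format("%s_%s", REQUEST_INDEX_PREFIX, ds);
  static final String ALL_REQUEST_INDEX = String.format("%s_*", REQUEST_INDEX_PREFIX);

  public static void main(String[] args) {
    System.out.println(DATASOURCE_TO_REQUEST_INDEX.apply("mys3")); // per-datasource state index
    System.out.println(ALL_REQUEST_INDEX); // wildcard used by searchSession across datasources
  }
}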