Skip to content

Commit

Permalink
Call LeaseManager for BatchQuery (#3153)
Browse files Browse the repository at this point in the history
* Call LeaseManager for BatchQuery

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Reformat code

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Fix unit test for coverage

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Reformat

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

---------

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 authored Nov 14, 2024
1 parent 95a1643 commit 5b3cdd8
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
import org.opensearch.sql.spark.dispatcher.model.JobType;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
import org.opensearch.sql.spark.metrics.MetricsService;
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
Expand Down Expand Up @@ -75,13 +76,23 @@ public String cancelJob(
return asyncQueryJobMetadata.getQueryId();
}

/**
* This method allows RefreshQueryHandler to override the job type when calling
* leaseManager.borrow.
*/
protected void borrow(String datasource) {
leaseManager.borrow(new LeaseRequest(JobType.BATCH, datasource));
}

@Override
public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
String clusterName = dispatchQueryRequest.getClusterName();
Map<String, String> tags = context.getTags();
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();

this.borrow(dispatchQueryRequest.getDatasource());

tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
StartJobRequest startJobRequest =
new StartJobRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ public String cancelJob(
return asyncQueryJobMetadata.getQueryId();
}

@Override
protected void borrow(String datasource) {
leaseManager.borrow(new LeaseRequest(JobType.REFRESH, datasource));
}

@Override
public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource()));

DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context);
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.opensearch.sql.spark.flint.IndexDMLResultStorageService;
import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
import org.opensearch.sql.spark.metrics.MetricsService;
import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection;
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
Expand Down Expand Up @@ -137,6 +138,7 @@ public class AsyncQueryCoreIntegTest {
@Captor ArgumentCaptor<FlintIndexOptions> flintIndexOptionsArgumentCaptor;
@Captor ArgumentCaptor<StartJobRunRequest> startJobRunRequestArgumentCaptor;
@Captor ArgumentCaptor<CreateSessionRequest> createSessionRequestArgumentCaptor;
@Captor ArgumentCaptor<LeaseRequest> leaseRequestArgumentCaptor;

AsyncQueryExecutorService asyncQueryExecutorService;

Expand Down Expand Up @@ -267,7 +269,8 @@ public void createVacuumIndexQuery() {
assertEquals(SESSION_ID, response.getSessionId());
verifyGetQueryIdCalled();
verifyGetSessionIdCalled();
verify(leaseManager).borrow(any());
verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture());
assertEquals(JobType.INTERACTIVE, leaseRequestArgumentCaptor.getValue().getJobType());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE);
}
Expand Down Expand Up @@ -356,11 +359,38 @@ public void createStreamingQuery() {
assertEquals(QUERY_ID, response.getQueryId());
assertNull(response.getSessionId());
verifyGetQueryIdCalled();
verify(leaseManager).borrow(any());
verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture());
assertEquals(JobType.STREAMING, leaseRequestArgumentCaptor.getValue().getJobType());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.STREAMING);
}

@Test
public void createBatchQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
when(awsemrServerless.startJobRun(any()))
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));

CreateAsyncQueryResponse response =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"CREATE INDEX index_name ON table_name(l_orderkey, l_quantity)"
+ " WITH (auto_refresh = false)",
DATASOURCE_NAME,
LangType.SQL),
asyncQueryRequestContext);

assertEquals(QUERY_ID, response.getQueryId());
assertNull(response.getSessionId());
verifyGetQueryIdCalled();
verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture());
assertEquals(JobType.BATCH, leaseRequestArgumentCaptor.getValue().getJobType());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.BATCH);
}

private void verifyStartJobRunCalled() {
verify(awsemrServerless).startJobRun(startJobRunRequestArgumentCaptor.capture());
StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue();
Expand Down Expand Up @@ -413,7 +443,8 @@ public void createRefreshQuery() {
assertEquals(QUERY_ID, response.getQueryId());
assertNull(response.getSessionId());
verifyGetQueryIdCalled();
verify(leaseManager).borrow(any());
verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture());
assertEquals(JobType.REFRESH, leaseRequestArgumentCaptor.getValue().getJobType());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.REFRESH);
}
Expand All @@ -439,7 +470,8 @@ public void createInteractiveQuery() {
assertEquals(SESSION_ID, response.getSessionId());
verifyGetQueryIdCalled();
verifyGetSessionIdCalled();
verify(leaseManager).borrow(any());
verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture());
assertEquals(JobType.INTERACTIVE, leaseRequestArgumentCaptor.getValue().getJobType());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ private void createIndex(String indexName) {
}
}

private long count(String indexName, QueryBuilder query) {
@VisibleForTesting
public long count(String indexName, QueryBuilder query) {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(query);
searchSourceBuilder.size(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public String description() {

@Override
public boolean test(LeaseRequest leaseRequest) {
if (leaseRequest.getJobType() == JobType.INTERACTIVE) {
if (leaseRequest.getJobType() != JobType.REFRESH
&& leaseRequest.getJobType() != JobType.STREAMING) {
return true;
}
return activeRefreshJobCount(stateStore, ALL_DATASOURCE).get() < refreshJobLimit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.sql.spark.leasemanager;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -23,19 +25,36 @@ class DefaultLeaseManagerTest {
@Mock private StateStore stateStore;

@Test
public void concurrentSessionRuleOnlyApplyToInteractiveQuery() {
assertTrue(
new DefaultLeaseManager.ConcurrentSessionRule(settings, stateStore)
.test(new LeaseRequest(JobType.BATCH, "mys3")));
assertTrue(
new DefaultLeaseManager.ConcurrentSessionRule(settings, stateStore)
.test(new LeaseRequest(JobType.STREAMING, "mys3")));
public void leaseManagerRejectsJobs() {
when(stateStore.count(any(), any())).thenReturn(3L);
when(settings.getSettingValue(any())).thenReturn(3);
DefaultLeaseManager defaultLeaseManager = new DefaultLeaseManager(settings, stateStore);

defaultLeaseManager.borrow(getLeaseRequest(JobType.BATCH));
assertThrows(
ConcurrencyLimitExceededException.class,
() -> defaultLeaseManager.borrow(getLeaseRequest(JobType.INTERACTIVE)));
assertThrows(
ConcurrencyLimitExceededException.class,
() -> defaultLeaseManager.borrow(getLeaseRequest(JobType.STREAMING)));
assertThrows(
ConcurrencyLimitExceededException.class,
() -> defaultLeaseManager.borrow(getLeaseRequest(JobType.REFRESH)));
}

@Test
public void concurrentRefreshRuleOnlyNotAppliedToInteractiveQuery() {
assertTrue(
new DefaultLeaseManager.ConcurrentRefreshJobRule(settings, stateStore)
.test(new LeaseRequest(JobType.INTERACTIVE, "mys3")));
public void leaseManagerAcceptsJobs() {
when(stateStore.count(any(), any())).thenReturn(2L);
when(settings.getSettingValue(any())).thenReturn(3);
DefaultLeaseManager defaultLeaseManager = new DefaultLeaseManager(settings, stateStore);

defaultLeaseManager.borrow(getLeaseRequest(JobType.BATCH));
defaultLeaseManager.borrow(getLeaseRequest(JobType.INTERACTIVE));
defaultLeaseManager.borrow(getLeaseRequest(JobType.STREAMING));
defaultLeaseManager.borrow(getLeaseRequest(JobType.REFRESH));
}

private LeaseRequest getLeaseRequest(JobType jobType) {
return new LeaseRequest(jobType, "mys3");
}
}

0 comments on commit 5b3cdd8

Please sign in to comment.