diff --git a/server/integration-test/src/test/java/com/oceanbase/odc/service/session/NlsFormatInterceptorTest.java b/server/integration-test/src/test/java/com/oceanbase/odc/service/session/NlsFormatInterceptorTest.java index 13b3ff6fb9..5bfdf17fcf 100644 --- a/server/integration-test/src/test/java/com/oceanbase/odc/service/session/NlsFormatInterceptorTest.java +++ b/server/integration-test/src/test/java/com/oceanbase/odc/service/session/NlsFormatInterceptorTest.java @@ -27,6 +27,7 @@ import com.oceanbase.odc.core.sql.execute.model.SqlExecuteStatus; import com.oceanbase.odc.core.sql.execute.model.SqlTuple; import com.oceanbase.odc.service.session.interceptor.NlsFormatInterceptor; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlExecuteResult; /** @@ -43,7 +44,7 @@ public void afterCompletion_mysql_notingSet() throws Exception { ConnectionSession session = getConnectionSession(ConnectType.OB_MYSQL); NlsFormatInterceptor interceptor = new NlsFormatInterceptor(); SqlExecuteResult r = getResponse("set session nls_date_format='DD-MON-RR'", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -52,7 +53,7 @@ public void afterCompletion_failedSqlResult_notingSet() throws Exception { ConnectionSession session = getConnectionSession(ConnectType.OB_ORACLE); NlsFormatInterceptor interceptor = new NlsFormatInterceptor(); SqlExecuteResult r = getResponse("set session nls_date_format='DD-MON-RR'", SqlExecuteStatus.FAILED); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -65,7 +66,7 @@ public void afterCompletion_multiSqls_notingSet() throws Exception { + "begin\n" + "dbms_output.put_line('aaaa');\n" + "end;", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -74,7 +75,7 @@ public void afterCompletion_noSetVarExists_notingSet() throws Exception { ConnectionSession session = getConnectionSession(ConnectType.OB_ORACLE); NlsFormatInterceptor interceptor = new NlsFormatInterceptor(); SqlExecuteResult r = getResponse("-- comment\nselect 123 from dual;", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -85,7 +86,7 @@ public void afterCompletion_commentWithSetVar_setSucceed() throws Exception { String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("-- comment\nset session nls_date_format='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -96,7 +97,7 @@ public void afterCompletion_multiCommentsWithSetVar_setSucceed() throws Exceptio String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_date_format='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -107,7 +108,7 @@ public void afterCompletion_nlsTimestampFormat_setSucceed() throws Exception { String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_timestamp_format='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertEquals(expect, ConnectionSessionUtil.getNlsTimestampFormat(session)); } @@ -118,7 +119,7 @@ public void afterCompletion_nlsTimestampTZFormat_setSucceed() throws Exception { String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_timestamp_tz_format='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertEquals(expect, ConnectionSessionUtil.getNlsTimestampTZFormat(session)); } @@ -129,7 +130,7 @@ public void afterCompletion_setGlobal_nothingSet() throws Exception { String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("/*asdasdasd*/ set global nls_timestamp_tz_format='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertNull(ConnectionSessionUtil.getNlsTimestampTZFormat(session)); } @@ -140,7 +141,7 @@ public void afterCompletion_alterSession_setSucceed() throws Exception { String expect = "DD-MON-RR"; SqlExecuteResult r = getResponse("/*asdsd*/ alter session \n\t\r set \"nls_date_format\"='" + expect + "';", SqlExecuteStatus.SUCCESS); - interceptor.afterCompletion(r, session, new HashMap<>()); + interceptor.afterCompletion(r, session, getContext()); Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session)); } @@ -159,4 +160,8 @@ private ConnectionSession getConnectionSession(ConnectType type) { return session; } + private AsyncExecuteContext getContext() { + return new AsyncExecuteContext(null, new HashMap<>()); + } + } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/DataMaskingInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/DataMaskingInterceptor.java index 2cc642cf6f..6bc34cea53 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/DataMaskingInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/datasecurity/DataMaskingInterceptor.java @@ -20,7 +20,6 @@ import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -40,6 +39,7 @@ import com.oceanbase.odc.service.datasecurity.util.DataMaskingUtil; import com.oceanbase.odc.service.db.browser.DBSchemaAccessors; import com.oceanbase.odc.service.session.interceptor.BaseTimeConsumingInterceptor; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.DBResultSetMetaData; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; @@ -64,14 +64,14 @@ public class DataMaskingInterceptor extends BaseTimeConsumingInterceptor { @Override public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { return true; } @Override @SuppressWarnings("all") public void doAfterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) throws Exception { + @NonNull AsyncExecuteContext context) throws Exception { // TODO: May intercept sensitive column operation (WHERE / ORDER BY / HAVING) if (!maskingService.isMaskingEnabled()) { return; diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/integration/ExternalSqlInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/integration/ExternalSqlInterceptor.java index 37976a3c62..84602fd2c0 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/integration/ExternalSqlInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/integration/ExternalSqlInterceptor.java @@ -16,7 +16,6 @@ package com.oceanbase.odc.service.integration; import java.util.List; -import java.util.Map; import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; @@ -48,6 +47,7 @@ import com.oceanbase.odc.service.regulation.ruleset.model.Rule.RuleViolation; import com.oceanbase.odc.service.regulation.ruleset.model.SqlConsoleRules; import com.oceanbase.odc.service.session.interceptor.BaseTimeConsumingInterceptor; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -81,7 +81,7 @@ public class ExternalSqlInterceptor extends BaseTimeConsumingInterceptor { @Override public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { Long ruleSetId = ConnectionSessionUtil.getRuleSetId(session); if (Objects.isNull(ruleSetId) || isIndividualTeam()) { return true; @@ -147,7 +147,7 @@ public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyn @Override public void afterCompletion(@NonNull SqlExecuteResult response, - @NonNull ConnectionSession session, @NonNull Map context) {} + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) {} @Override diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectConsoleService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectConsoleService.java index 0e0f7bbdc1..0140483656 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectConsoleService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectConsoleService.java @@ -128,7 +128,6 @@ public class ConnectConsoleService { public static final int DEFAULT_GET_RESULT_TIMEOUT_SECONDS = 3; public static final String SHOW_TABLE_COLUMN_INFO = "SHOW_TABLE_COLUMN_INFO"; - private static final Long EXECUTING_CONTEXT_WAIT_MILLIS = 1100L; @Autowired private ConnectSessionService sessionService; @@ -270,11 +269,12 @@ public SqlAsyncExecuteResp execute(@NotNull String sessionId, context.put(SHOW_TABLE_COLUMN_INFO, request.getShowTableColumnInfo()); context.put(SqlCheckInterceptor.NEED_SQL_CHECK_KEY, needSqlRuleCheck); context.put(SqlConsoleInterceptor.NEED_SQL_CONSOLE_CHECK, needSqlRuleCheck); + AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context); List stages = sqlTuples.stream() .map(s -> s.getSqlWatch().start(SqlExecuteStages.SQL_PRE_CHECK)) .collect(Collectors.toList()); try { - if (!sqlInterceptService.preHandle(request, response, connectionSession, context)) { + if (!sqlInterceptService.preHandle(request, response, connectionSession, executeContext)) { return response; } } finally { @@ -352,11 +352,12 @@ public SqlAsyncExecuteResp executeV2(@NotNull String sessionId, context.put(SHOW_TABLE_COLUMN_INFO, request.getShowTableColumnInfo()); context.put(SqlCheckInterceptor.NEED_SQL_CHECK_KEY, needSqlRuleCheck); context.put(SqlConsoleInterceptor.NEED_SQL_CONSOLE_CHECK, needSqlRuleCheck); + AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context); List stages = sqlTuples.stream() .map(s -> s.getSqlWatch().start(SqlExecuteStages.SQL_PRE_CHECK)) .collect(Collectors.toList()); try { - if (!sqlInterceptService.preHandle(request, response, connectionSession, context)) { + if (!sqlInterceptService.preHandle(request, response, connectionSession, executeContext)) { return response; } } finally { @@ -373,7 +374,6 @@ public SqlAsyncExecuteResp executeV2(@NotNull String sessionId, Objects.nonNull(request.getContinueExecutionOnError()) ? request.getContinueExecutionOnError() : userConfigFacade.isContinueExecutionOnError(); boolean stopOnError = !continueExecutionOnError; - AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context); OdcStatementCallBack statementCallBack = new OdcStatementCallBack(sqlTuples, connectionSession, request.getAutoCommit(), queryLimit, stopOnError, executeContext); @@ -415,10 +415,10 @@ public List getAsyncResult(@NotNull String sessionId, String r Map context = ConnectionSessionUtil.getFutureJdbcContext(connectionSession, requestId); ConnectionSessionUtil.removeFutureJdbc(connectionSession, requestId); return resultList.stream().map(jdbcGeneralResult -> { - Map cxt = context == null ? new HashMap<>() : context; - SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, cxt); + Map ctx = context == null ? new HashMap<>() : context; + SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, ctx); try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) { - sqlInterceptService.afterCompletion(result, connectionSession, cxt); + sqlInterceptService.afterCompletion(result, connectionSession, new AsyncExecuteContext(null, ctx)); } catch (Exception e) { throw new IllegalStateException(e); } @@ -440,24 +440,26 @@ public AsyncExecuteResultResp getAsyncResultV2(@NotNull String sessionId, String ConnectionSession connectionSession = sessionService.nullSafeGet(sessionId); AsyncExecuteContext context = (AsyncExecuteContext) ConnectionSessionUtil.getExecuteContext(connectionSession, requestId); - List resultList = context.getResults(); - List results = resultList.stream().map(jdbcGeneralResult -> { - SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, context.getContextMap()); - try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) { - sqlInterceptService.afterCompletion(result, connectionSession, context.getContextMap()); - } catch (Exception e) { - throw new IllegalStateException(e); - } - return result; - }).collect(Collectors.toList()); - if (context.isFinished()) { - ConnectionSessionUtil.removeExecuteContext(connectionSession, requestId); - return new AsyncExecuteResultResp(true, context, results); - } else { - if (log.isDebugEnabled()) { - log.debug("Get sql execution result timed out, sessionId={}, requestId={}", sessionId, requestId); + boolean shouldRemoveContext = context.isFinished(); + try { + List resultList = context.getFinishedSqlExecutionResults(); + List results = resultList.stream().map(jdbcGeneralResult -> { + SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, context.getContextMap()); + try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) { + sqlInterceptService.afterCompletion(result, connectionSession, context); + } catch (Exception e) { + throw new IllegalStateException(e); + } + return result; + }).collect(Collectors.toList()); + return new AsyncExecuteResultResp(context, results); + } catch (Exception e) { + shouldRemoveContext = true; + throw e; + } finally { + if (shouldRemoveContext) { + ConnectionSessionUtil.removeExecuteContext(connectionSession, requestId); } - return new AsyncExecuteResultResp(false, context, results); } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OBExecutionListener.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OBExecutionListener.java index ca1c835b16..9c2792ad1f 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OBExecutionListener.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OBExecutionListener.java @@ -56,7 +56,7 @@ public void onExecutionStart(SqlTuple sqlTuple, AsyncExecuteContext context) {} public void onExecutionEnd(SqlTuple sqlTuple, List results, AsyncExecuteContext context) {} @Override - public void onExecutionCanceled(SqlTuple sqlTuple, List results, AsyncExecuteContext context) {} + public void onExecutionCancelled(SqlTuple sqlTuple, List results, AsyncExecuteContext context) {} public void onExecutionStartAfter(SqlTuple sqlTuple, AsyncExecuteContext context) { if (CollectionUtils.isEmpty(sessionIds)) { diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OdcStatementCallBack.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OdcStatementCallBack.java index ecec407fea..7c8d148088 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OdcStatementCallBack.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/OdcStatementCallBack.java @@ -178,11 +178,7 @@ public List doInStatement(Statement statement) throws SQLExce return this.sqls.stream().map(sqlTuple -> { JdbcGeneralResult result = JdbcGeneralResult.canceledResult(sqlTuple); result.setConnectionReset(true); - if (context != null) { - context.addResult(result); - } - listeners.forEach( - listener -> listener.onExecutionCanceled(sqlTuple, Collections.singletonList(result), context)); + onExecutionCancelled(sqlTuple, Collections.singletonList(result)); return result; }).collect(Collectors.toList()); } @@ -203,12 +199,7 @@ public List doInStatement(Statement statement) throws SQLExce // eat exception } } - if (context != null) { - context.setCurrentExecutingSql(sqlTuple.getExecutedSql()); - context.addCount(); - context.setCurrentExecutingSqlTraceId(null); - } - listeners.forEach(listener -> listener.onExecutionStart(sqlTuple, context)); + onExecutionStart(sqlTuple); try { applyConnectionSettings(statement); } catch (Exception e) { @@ -219,37 +210,16 @@ public List doInStatement(Statement statement) throws SQLExce if (Thread.currentThread().isInterrupted() || ConnectionSessionUtil.isConsoleSessionKillQuery(connectionSession)) { executeResults = Collections.singletonList(JdbcGeneralResult.canceledResult(sqlTuple)); - listeners.forEach(listener -> listener.onExecutionCanceled(sqlTuple, executeResults, context)); + onExecutionCancelled(sqlTuple, executeResults); } else { CountDownLatch latch = new CountDownLatch(1); - handle = executor.submit(() -> { - long startTs = System.currentTimeMillis(); - List sortedListeners = listeners.stream() - .filter(listener -> listener.getOnExecutionStartAfterMillis() != null - && listener.getOnExecutionStartAfterMillis() > 0) - .sorted(Comparator - .comparingLong(SqlExecutionListener::getOnExecutionStartAfterMillis)) - .collect(Collectors.toList()); - for (SqlExecutionListener listener : sortedListeners) { - long waitTs = System.currentTimeMillis() - startTs; - Long expectedTs = listener.getOnExecutionStartAfterMillis(); - if (!latch.await(expectedTs - waitTs, TimeUnit.MILLISECONDS)) { - listener.onExecutionStartAfter(sqlTuple, context); - } else { - break; - } - } - return null; - }); + handle = executor.submit(() -> onExecutionStartAfterMillis(sqlTuple, latch)); executeResults = doExecuteSql(statement, sqlTuple, latch); - listeners.forEach(listener -> listener.onExecutionEnd(sqlTuple, executeResults, context)); + onExecutionEnd(sqlTuple, executeResults); } } else { executeResults = Collections.singletonList(JdbcGeneralResult.canceledResult(sqlTuple)); - listeners.forEach(listener -> listener.onExecutionCanceled(sqlTuple, executeResults, context)); - } - if (context != null) { - context.addResults(executeResults); + onExecutionCancelled(sqlTuple, executeResults); } returnVal.addAll(executeResults); } @@ -274,6 +244,7 @@ public List doInStatement(Statement statement) throws SQLExce log.info("Clear dbms_output cache, dbmsInfo={}", dbmsInfo); } } + executor.shutdownNow(); } return returnVal; } @@ -596,6 +567,73 @@ private void rollback(Connection connection) { } } + private void onExecutionStart(SqlTuple sqlTuple) { + if (context != null) { + context.setCurrentExecutingSql(sqlTuple.getExecutedSql()); + context.incrementTotalExecutedSqlCount(); + context.setCurrentExecutingSqlTraceId(null); + } + listeners.forEach(listener -> { + try { + listener.onExecutionStart(sqlTuple, context); + } catch (Exception e) { + log.warn("An error occurred in listener {}.", listener.getClass(), e); + } + }); + } + + private void onExecutionCancelled(SqlTuple sqlTuple, List results) { + if (context != null) { + context.addSqlExecutionResults(results); + } + listeners.forEach(listener -> { + try { + listener.onExecutionCancelled(sqlTuple, results, context); + } catch (Exception e) { + log.warn("An error occurred in listener {}.", listener.getClass(), e); + } + }); + } + + private void onExecutionEnd(SqlTuple sqlTuple, List results) { + if (context != null) { + context.addSqlExecutionResults(results); + } + listeners.forEach(listener -> { + try { + listener.onExecutionEnd(sqlTuple, results, context); + } catch (Exception e) { + log.warn("An error occurred in listener {}.", listener.getClass(), e); + } + }); + } + + private Void onExecutionStartAfterMillis(SqlTuple sqlTuple, CountDownLatch latch) { + long startTs = System.currentTimeMillis(); + List sortedListeners = listeners.stream() + .filter(listener -> listener.getOnExecutionStartAfterMillis() != null + && listener.getOnExecutionStartAfterMillis() > 0) + .sorted(Comparator + .comparingLong(SqlExecutionListener::getOnExecutionStartAfterMillis)) + .collect(Collectors.toList()); + for (SqlExecutionListener listener : sortedListeners) { + long waitTs = System.currentTimeMillis() - startTs; + Long expectedTs = listener.getOnExecutionStartAfterMillis(); + try { + if (!latch.await(expectedTs - waitTs, TimeUnit.MILLISECONDS)) { + listener.onExecutionStartAfter(sqlTuple, context); + } else { + break; + } + } catch (InterruptedException e) { + return null; + } catch (Exception e) { + log.warn("An error occurred in listener {}.", listener.getClass(), e); + } + } + return null; + } + @Getter @ToString static class FunctionDefinition { diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SqlExecutionListener.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SqlExecutionListener.java index 1c07102214..1143412a21 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SqlExecutionListener.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SqlExecutionListener.java @@ -31,7 +31,7 @@ public interface SqlExecutionListener { void onExecutionEnd(SqlTuple sqlTuple, List results, AsyncExecuteContext context); - void onExecutionCanceled(SqlTuple sqlTuple, List results, AsyncExecuteContext context); + void onExecutionCancelled(SqlTuple sqlTuple, List results, AsyncExecuteContext context); void onExecutionStartAfter(SqlTuple sqlTuple, AsyncExecuteContext context); diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/BaseTimeConsumingInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/BaseTimeConsumingInterceptor.java index 5ba5bd2bf2..77475f26a4 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/BaseTimeConsumingInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/BaseTimeConsumingInterceptor.java @@ -17,11 +17,11 @@ package com.oceanbase.odc.service.session.interceptor; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import com.oceanbase.odc.common.util.TraceStage; import com.oceanbase.odc.core.session.ConnectionSession; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -32,7 +32,7 @@ public abstract class BaseTimeConsumingInterceptor implements SqlExecuteIntercep @Override public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { List stageList = response.getSqls().stream() .map(v -> v.getSqlTuple().getSqlWatch().start(getExecuteStageName())) .collect(Collectors.toList()); @@ -51,19 +51,19 @@ public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncE @Override public void afterCompletion(@NonNull SqlExecuteResult response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { try (TraceStage stage = response.getSqlTuple().getSqlWatch().start(getExecuteStageName())) { doAfterCompletion(response, session, context); } } protected boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { return true; } protected void doAfterCompletion(@NonNull SqlExecuteResult response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception {} + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception {} protected abstract String getExecuteStageName(); diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/DatabasePermissionInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/DatabasePermissionInterceptor.java index 2a3fa79b3f..ba1696c133 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/DatabasePermissionInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/DatabasePermissionInterceptor.java @@ -37,6 +37,7 @@ import com.oceanbase.odc.service.iam.auth.AuthenticationFacade; import com.oceanbase.odc.service.permission.database.model.DatabasePermissionType; import com.oceanbase.odc.service.permission.database.model.UnauthorizedDatabase; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -69,7 +70,7 @@ public int getOrder() { @Override public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { if (authenticationFacade.currentUser().getOrganizationType() == OrganizationType.INDIVIDUAL) { return true; } @@ -100,7 +101,7 @@ public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyn @Override public void afterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) {} + @NonNull AsyncExecuteContext context) {} @Override protected String getExecuteStageName() { diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/NlsFormatInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/NlsFormatInterceptor.java index 1a2e79d350..53e74ef523 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/NlsFormatInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/NlsFormatInterceptor.java @@ -17,7 +17,6 @@ import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @@ -36,6 +35,7 @@ import com.oceanbase.odc.core.sql.parser.AbstractSyntaxTreeFactories; import com.oceanbase.odc.core.sql.split.OffsetString; import com.oceanbase.odc.core.sql.split.SqlCommentProcessor; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -67,13 +67,13 @@ public class NlsFormatInterceptor extends BaseTimeConsumingInterceptor { @Override public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { return true; } @Override public void doAfterCompletion(@NonNull SqlExecuteResult response, - @NonNull ConnectionSession session, @NonNull Map context) { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { DialectType dialect = session.getDialectType(); if (response.getStatus() != SqlExecuteStatus.SUCCESS || dialect != DialectType.OB_ORACLE) { return; diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlCheckInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlCheckInterceptor.java index e0d1403db4..19b8585156 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlCheckInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlCheckInterceptor.java @@ -37,6 +37,7 @@ import com.oceanbase.odc.service.iam.auth.AuthenticationFacade; import com.oceanbase.odc.service.regulation.ruleset.RuleService; import com.oceanbase.odc.service.regulation.ruleset.model.Rule; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -74,11 +75,12 @@ public class SqlCheckInterceptor extends BaseTimeConsumingInterceptor { @Override public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) { - boolean sqlCheckIntercepted = handle(request, response, session, context); - context.put(SQL_CHECK_INTERCEPTED, sqlCheckIntercepted); - if (Objects.nonNull(context.get(SqlConsoleInterceptor.SQL_CONSOLE_INTERCEPTED))) { - return sqlCheckIntercepted && (Boolean) context.get(SqlConsoleInterceptor.SQL_CONSOLE_INTERCEPTED); + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { + Map ctx = context.getContextMap(); + boolean sqlCheckIntercepted = handle(request, response, session, ctx); + ctx.put(SQL_CHECK_INTERCEPTED, sqlCheckIntercepted); + if (Objects.nonNull(ctx.get(SqlConsoleInterceptor.SQL_CONSOLE_INTERCEPTED))) { + return sqlCheckIntercepted && (Boolean) ctx.get(SqlConsoleInterceptor.SQL_CONSOLE_INTERCEPTED); } else { return true; } @@ -132,11 +134,12 @@ protected String getExecuteStageName() { @Override @SuppressWarnings("all") public void afterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) throws Exception { - if (!context.containsKey(SQL_CHECK_RESULT_KEY)) { + @NonNull AsyncExecuteContext context) throws Exception { + Map ctx = context.getContextMap(); + if (!ctx.containsKey(SQL_CHECK_RESULT_KEY)) { return; } - Map> map = (Map>) context.get(SQL_CHECK_RESULT_KEY); + Map> map = (Map>) ctx.get(SQL_CHECK_RESULT_KEY); List results = map.get(response.getSqlTuple().getOffset()); if (CollectionUtils.isEmpty(results)) { return; diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlConsoleInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlConsoleInterceptor.java index 7610a3f1ff..35de0eff18 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlConsoleInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlConsoleInterceptor.java @@ -42,6 +42,7 @@ import com.oceanbase.odc.service.regulation.ruleset.model.Rule; import com.oceanbase.odc.service.regulation.ruleset.model.Rule.RuleViolation; import com.oceanbase.odc.service.regulation.ruleset.model.SqlConsoleRules; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -74,11 +75,12 @@ public class SqlConsoleInterceptor extends BaseTimeConsumingInterceptor { @Override public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) { - boolean sqlConsoleIntercepted = handle(request, response, session, context); - context.put(SQL_CONSOLE_INTERCEPTED, sqlConsoleIntercepted); - if (Objects.nonNull(context.get(SqlCheckInterceptor.SQL_CHECK_INTERCEPTED))) { - return sqlConsoleIntercepted && (Boolean) context.get(SqlCheckInterceptor.SQL_CHECK_INTERCEPTED); + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) { + Map ctx = context.getContextMap(); + boolean sqlConsoleIntercepted = handle(request, response, session, ctx); + ctx.put(SQL_CONSOLE_INTERCEPTED, sqlConsoleIntercepted); + if (Objects.nonNull(ctx.get(SqlCheckInterceptor.SQL_CHECK_INTERCEPTED))) { + return sqlConsoleIntercepted && (Boolean) ctx.get(SqlCheckInterceptor.SQL_CHECK_INTERCEPTED); } else { return true; } @@ -199,7 +201,7 @@ protected String getExecuteStageName() { @Override public void doAfterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) { + @NonNull AsyncExecuteContext context) { if (response.getStatus() != SqlExecuteStatus.SUCCESS) { return; } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptor.java index 5fb77ecaa9..225e350614 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptor.java @@ -15,11 +15,10 @@ */ package com.oceanbase.odc.service.session.interceptor; -import java.util.Map; - import org.springframework.core.Ordered; import com.oceanbase.odc.core.session.ConnectionSession; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -44,11 +43,11 @@ public interface SqlExecuteInterceptor extends Ordered { * @return whether to execute this sql */ default boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { return true; } default void afterCompletion(@NonNull SqlExecuteResult response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception {} + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception {} } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptorService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptorService.java index 76ad156cc9..662ed3bb5b 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptorService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SqlExecuteInterceptorService.java @@ -29,6 +29,7 @@ import com.oceanbase.odc.core.authority.util.SkipAuthorize; import com.oceanbase.odc.core.session.ConnectionSession; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq; import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp; import com.oceanbase.odc.service.session.model.SqlExecuteResult; @@ -61,7 +62,7 @@ public void init() { } public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response, - @NonNull ConnectionSession session, @NonNull Map context) throws Exception { + @NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) throws Exception { for (SqlExecuteInterceptor interceptor : interceptors) { if (interceptor.preHandle(request, response, session, context)) { continue; @@ -72,7 +73,7 @@ public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncE } public void afterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) throws Exception { + @NonNull AsyncExecuteContext context) throws Exception { for (SqlExecuteInterceptor interceptor : interceptors) { interceptor.afterCompletion(response, session, context); } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SwitchDatabaseInterceptor.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SwitchDatabaseInterceptor.java index 73a9b31fa7..9d41f99e8e 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SwitchDatabaseInterceptor.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/interceptor/SwitchDatabaseInterceptor.java @@ -16,7 +16,6 @@ package com.oceanbase.odc.service.session.interceptor; import java.util.Collections; -import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -25,6 +24,7 @@ import com.oceanbase.odc.core.session.ConnectionSession; import com.oceanbase.odc.core.session.ConnectionSessionUtil; import com.oceanbase.odc.core.sql.execute.model.SqlExecuteStatus; +import com.oceanbase.odc.service.session.model.AsyncExecuteContext; import com.oceanbase.odc.service.session.model.SqlExecuteResult; import com.oceanbase.odc.service.session.util.SchemaExtractor; @@ -41,7 +41,7 @@ public class SwitchDatabaseInterceptor implements SqlExecuteInterceptor { @Override public void afterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session, - @NonNull Map context) throws Exception { + @NonNull AsyncExecuteContext context) throws Exception { if (response.getStatus() != SqlExecuteStatus.SUCCESS) { return; } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteContext.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteContext.java index 8102935dad..69e5b9d770 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteContext.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteContext.java @@ -44,7 +44,7 @@ public class AsyncExecuteContext { private Future> future; private String currentExecutingSqlTraceId; private String currentExecutingSql; - private int totalSqlExecutedCount = 0; + private int totalExecutedSqlCount = 0; public AsyncExecuteContext(List sqlTuples, Map contextMap) { this.sqlTuples = sqlTuples; @@ -52,18 +52,25 @@ public AsyncExecuteContext(List sqlTuples, Map context } public boolean isFinished() { - return future.isDone(); + return future != null && future.isDone(); } - public void addCount() { - totalSqlExecutedCount++; + public boolean isCancelled() { + return future != null && future.isCancelled(); } - public int getTotal() { + public void incrementTotalExecutedSqlCount() { + totalExecutedSqlCount++; + } + + public int getToBeExecutedSqlCount() { return sqlTuples.size(); } - public List getResults() { + /** + * only return the incremental results + */ + public List getFinishedSqlExecutionResults() { List copiedResults = new ArrayList<>(); while (!results.isEmpty()) { copiedResults.add(results.poll()); @@ -71,11 +78,7 @@ public List getResults() { return copiedResults; } - public void addResult(JdbcGeneralResult result) { - this.results.add(result); - } - - public void addResults(List results) { + public void addSqlExecutionResults(List results) { this.results.addAll(results); } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteResultResp.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteResultResp.java index eb6edac6e9..0eb747b342 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteResultResp.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/model/AsyncExecuteResultResp.java @@ -32,12 +32,12 @@ public class AsyncExecuteResultResp { private boolean finished; private String sql; - public AsyncExecuteResultResp(boolean finished, AsyncExecuteContext context, List results) { - this.finished = finished; + public AsyncExecuteResultResp(AsyncExecuteContext context, List results) { + this.finished = context.isFinished(); this.results = results; traceId = context.getCurrentExecutingSqlTraceId(); - total = context.getTotal(); - count = context.getTotalSqlExecutedCount(); + total = context.getToBeExecutedSqlCount(); + count = context.getTotalExecutedSqlCount(); sql = context.getCurrentExecutingSql(); } }