Skip to content

Commit

Permalink
response to CR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
LuckyPickleZZ committed Apr 24, 2024
1 parent 6997eec commit ce6eced
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.oceanbase.odc.core.sql.execute.model.SqlExecuteStatus;
import com.oceanbase.odc.core.sql.execute.model.SqlTuple;
import com.oceanbase.odc.service.session.interceptor.NlsFormatInterceptor;
import com.oceanbase.odc.service.session.model.AsyncExecuteContext;
import com.oceanbase.odc.service.session.model.SqlExecuteResult;

/**
Expand All @@ -43,7 +44,7 @@ public void afterCompletion_mysql_notingSet() throws Exception {
ConnectionSession session = getConnectionSession(ConnectType.OB_MYSQL);
NlsFormatInterceptor interceptor = new NlsFormatInterceptor();
SqlExecuteResult r = getResponse("set session nls_date_format='DD-MON-RR'", SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -52,7 +53,7 @@ public void afterCompletion_failedSqlResult_notingSet() throws Exception {
ConnectionSession session = getConnectionSession(ConnectType.OB_ORACLE);
NlsFormatInterceptor interceptor = new NlsFormatInterceptor();
SqlExecuteResult r = getResponse("set session nls_date_format='DD-MON-RR'", SqlExecuteStatus.FAILED);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -65,7 +66,7 @@ public void afterCompletion_multiSqls_notingSet() throws Exception {
+ "begin\n"
+ "dbms_output.put_line('aaaa');\n"
+ "end;", SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -74,7 +75,7 @@ public void afterCompletion_noSetVarExists_notingSet() throws Exception {
ConnectionSession session = getConnectionSession(ConnectType.OB_ORACLE);
NlsFormatInterceptor interceptor = new NlsFormatInterceptor();
SqlExecuteResult r = getResponse("-- comment\nselect 123 from dual;", SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertNull(ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -85,7 +86,7 @@ public void afterCompletion_commentWithSetVar_setSucceed() throws Exception {
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("-- comment\nset session nls_date_format='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -96,7 +97,7 @@ public void afterCompletion_multiCommentsWithSetVar_setSucceed() throws Exceptio
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_date_format='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -107,7 +108,7 @@ public void afterCompletion_nlsTimestampFormat_setSucceed() throws Exception {
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_timestamp_format='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertEquals(expect, ConnectionSessionUtil.getNlsTimestampFormat(session));
}

Expand All @@ -118,7 +119,7 @@ public void afterCompletion_nlsTimestampTZFormat_setSucceed() throws Exception {
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("/*asdasdasd*/ set session nls_timestamp_tz_format='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertEquals(expect, ConnectionSessionUtil.getNlsTimestampTZFormat(session));
}

Expand All @@ -129,7 +130,7 @@ public void afterCompletion_setGlobal_nothingSet() throws Exception {
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("/*asdasdasd*/ set global nls_timestamp_tz_format='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertNull(ConnectionSessionUtil.getNlsTimestampTZFormat(session));
}

Expand All @@ -140,7 +141,7 @@ public void afterCompletion_alterSession_setSucceed() throws Exception {
String expect = "DD-MON-RR";
SqlExecuteResult r = getResponse("/*asdsd*/ alter session \n\t\r set \"nls_date_format\"='" + expect + "';",
SqlExecuteStatus.SUCCESS);
interceptor.afterCompletion(r, session, new HashMap<>());
interceptor.afterCompletion(r, session, getContext());
Assert.assertEquals(expect, ConnectionSessionUtil.getNlsDateFormat(session));
}

Expand All @@ -159,4 +160,8 @@ private ConnectionSession getConnectionSession(ConnectType type) {
return session;
}

private AsyncExecuteContext getContext() {
return new AsyncExecuteContext(null, new HashMap<>());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
Expand All @@ -40,6 +39,7 @@
import com.oceanbase.odc.service.datasecurity.util.DataMaskingUtil;
import com.oceanbase.odc.service.db.browser.DBSchemaAccessors;
import com.oceanbase.odc.service.session.interceptor.BaseTimeConsumingInterceptor;
import com.oceanbase.odc.service.session.model.AsyncExecuteContext;
import com.oceanbase.odc.service.session.model.DBResultSetMetaData;
import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq;
import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp;
Expand All @@ -64,14 +64,14 @@ public class DataMaskingInterceptor extends BaseTimeConsumingInterceptor {

@Override
public boolean preHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response,
@NonNull ConnectionSession session, @NonNull Map<String, Object> context) {
@NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) {
return true;
}

@Override
@SuppressWarnings("all")
public void doAfterCompletion(@NonNull SqlExecuteResult response, @NonNull ConnectionSession session,
@NonNull Map<String, Object> context) throws Exception {
@NonNull AsyncExecuteContext context) throws Exception {
// TODO: May intercept sensitive column operation (WHERE / ORDER BY / HAVING)
if (!maskingService.isMaskingEnabled()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.oceanbase.odc.service.integration;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -48,6 +47,7 @@
import com.oceanbase.odc.service.regulation.ruleset.model.Rule.RuleViolation;
import com.oceanbase.odc.service.regulation.ruleset.model.SqlConsoleRules;
import com.oceanbase.odc.service.session.interceptor.BaseTimeConsumingInterceptor;
import com.oceanbase.odc.service.session.model.AsyncExecuteContext;
import com.oceanbase.odc.service.session.model.SqlAsyncExecuteReq;
import com.oceanbase.odc.service.session.model.SqlAsyncExecuteResp;
import com.oceanbase.odc.service.session.model.SqlExecuteResult;
Expand Down Expand Up @@ -81,7 +81,7 @@ public class ExternalSqlInterceptor extends BaseTimeConsumingInterceptor {

@Override
public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyncExecuteResp response,
@NonNull ConnectionSession session, @NonNull Map<String, Object> context) {
@NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) {
Long ruleSetId = ConnectionSessionUtil.getRuleSetId(session);
if (Objects.isNull(ruleSetId) || isIndividualTeam()) {
return true;
Expand Down Expand Up @@ -147,7 +147,7 @@ public boolean doPreHandle(@NonNull SqlAsyncExecuteReq request, @NonNull SqlAsyn

@Override
public void afterCompletion(@NonNull SqlExecuteResult response,
@NonNull ConnectionSession session, @NonNull Map<String, Object> context) {}
@NonNull ConnectionSession session, @NonNull AsyncExecuteContext context) {}


@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ public class ConnectConsoleService {

public static final int DEFAULT_GET_RESULT_TIMEOUT_SECONDS = 3;
public static final String SHOW_TABLE_COLUMN_INFO = "SHOW_TABLE_COLUMN_INFO";
private static final Long EXECUTING_CONTEXT_WAIT_MILLIS = 1100L;

@Autowired
private ConnectSessionService sessionService;
Expand Down Expand Up @@ -270,11 +269,12 @@ public SqlAsyncExecuteResp execute(@NotNull String sessionId,
context.put(SHOW_TABLE_COLUMN_INFO, request.getShowTableColumnInfo());
context.put(SqlCheckInterceptor.NEED_SQL_CHECK_KEY, needSqlRuleCheck);
context.put(SqlConsoleInterceptor.NEED_SQL_CONSOLE_CHECK, needSqlRuleCheck);
AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context);
List<TraceStage> stages = sqlTuples.stream()
.map(s -> s.getSqlWatch().start(SqlExecuteStages.SQL_PRE_CHECK))
.collect(Collectors.toList());
try {
if (!sqlInterceptService.preHandle(request, response, connectionSession, context)) {
if (!sqlInterceptService.preHandle(request, response, connectionSession, executeContext)) {
return response;
}
} finally {
Expand Down Expand Up @@ -352,11 +352,12 @@ public SqlAsyncExecuteResp executeV2(@NotNull String sessionId,
context.put(SHOW_TABLE_COLUMN_INFO, request.getShowTableColumnInfo());
context.put(SqlCheckInterceptor.NEED_SQL_CHECK_KEY, needSqlRuleCheck);
context.put(SqlConsoleInterceptor.NEED_SQL_CONSOLE_CHECK, needSqlRuleCheck);
AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context);
List<TraceStage> stages = sqlTuples.stream()
.map(s -> s.getSqlWatch().start(SqlExecuteStages.SQL_PRE_CHECK))
.collect(Collectors.toList());
try {
if (!sqlInterceptService.preHandle(request, response, connectionSession, context)) {
if (!sqlInterceptService.preHandle(request, response, connectionSession, executeContext)) {
return response;
}
} finally {
Expand All @@ -373,7 +374,6 @@ public SqlAsyncExecuteResp executeV2(@NotNull String sessionId,
Objects.nonNull(request.getContinueExecutionOnError()) ? request.getContinueExecutionOnError()
: userConfigFacade.isContinueExecutionOnError();
boolean stopOnError = !continueExecutionOnError;
AsyncExecuteContext executeContext = new AsyncExecuteContext(sqlTuples, context);
OdcStatementCallBack statementCallBack = new OdcStatementCallBack(sqlTuples, connectionSession,
request.getAutoCommit(), queryLimit, stopOnError, executeContext);

Expand Down Expand Up @@ -415,10 +415,10 @@ public List<SqlExecuteResult> getAsyncResult(@NotNull String sessionId, String r
Map<String, Object> context = ConnectionSessionUtil.getFutureJdbcContext(connectionSession, requestId);
ConnectionSessionUtil.removeFutureJdbc(connectionSession, requestId);
return resultList.stream().map(jdbcGeneralResult -> {
Map<String, Object> cxt = context == null ? new HashMap<>() : context;
SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, cxt);
Map<String, Object> ctx = context == null ? new HashMap<>() : context;
SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, ctx);
try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) {
sqlInterceptService.afterCompletion(result, connectionSession, cxt);
sqlInterceptService.afterCompletion(result, connectionSession, new AsyncExecuteContext(null, ctx));
} catch (Exception e) {
throw new IllegalStateException(e);
}
Expand All @@ -440,24 +440,26 @@ public AsyncExecuteResultResp getAsyncResultV2(@NotNull String sessionId, String
ConnectionSession connectionSession = sessionService.nullSafeGet(sessionId);
AsyncExecuteContext context =
(AsyncExecuteContext) ConnectionSessionUtil.getExecuteContext(connectionSession, requestId);
List<JdbcGeneralResult> resultList = context.getResults();
List<SqlExecuteResult> results = resultList.stream().map(jdbcGeneralResult -> {
SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, context.getContextMap());
try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) {
sqlInterceptService.afterCompletion(result, connectionSession, context.getContextMap());
} catch (Exception e) {
throw new IllegalStateException(e);
}
return result;
}).collect(Collectors.toList());
if (context.isFinished()) {
ConnectionSessionUtil.removeExecuteContext(connectionSession, requestId);
return new AsyncExecuteResultResp(true, context, results);
} else {
if (log.isDebugEnabled()) {
log.debug("Get sql execution result timed out, sessionId={}, requestId={}", sessionId, requestId);
boolean shouldRemoveContext = context.isFinished();
try {
List<JdbcGeneralResult> resultList = context.getFinishedSqlExecutionResults();
List<SqlExecuteResult> results = resultList.stream().map(jdbcGeneralResult -> {
SqlExecuteResult result = generateResult(connectionSession, jdbcGeneralResult, context.getContextMap());
try (TraceStage stage = result.getSqlTuple().getSqlWatch().start(SqlExecuteStages.SQL_AFTER_CHECK)) {
sqlInterceptService.afterCompletion(result, connectionSession, context);
} catch (Exception e) {
throw new IllegalStateException(e);
}
return result;
}).collect(Collectors.toList());
return new AsyncExecuteResultResp(context, results);
} catch (Exception e) {
shouldRemoveContext = true;
throw e;
} finally {
if (shouldRemoveContext) {
ConnectionSessionUtil.removeExecuteContext(connectionSession, requestId);
}
return new AsyncExecuteResultResp(false, context, results);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void onExecutionStart(SqlTuple sqlTuple, AsyncExecuteContext context) {}
public void onExecutionEnd(SqlTuple sqlTuple, List<JdbcGeneralResult> results, AsyncExecuteContext context) {}

@Override
public void onExecutionCanceled(SqlTuple sqlTuple, List<JdbcGeneralResult> results, AsyncExecuteContext context) {}
public void onExecutionCancelled(SqlTuple sqlTuple, List<JdbcGeneralResult> results, AsyncExecuteContext context) {}

public void onExecutionStartAfter(SqlTuple sqlTuple, AsyncExecuteContext context) {
if (CollectionUtils.isEmpty(sessionIds)) {
Expand Down
Loading

0 comments on commit ce6eced

Please sign in to comment.