diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java index 40c0d8b567ba12..d0d2fffa278b2c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java @@ -110,7 +110,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { } ctx.addPreparedStatementContext(name, new PreparedStatementContext(this, ctx, ctx.getStatementContext(), name)); - if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE) { + if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE && !ctx.isProxy()) { executor.sendStmtPrepareOK(Integer.parseInt(name), labels); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 54de7ed1339ffb..6d86cc7f037f40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -76,6 +76,8 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import io.netty.util.concurrent.FastThreadLocal; +import lombok.Getter; +import lombok.Setter; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -83,6 +85,7 @@ import org.xnio.StreamConnection; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -248,6 +251,10 @@ public enum ConnectType { // it's default thread-safe private boolean isProxy = false; + @Getter + @Setter + private ByteBuffer prepareExecuteBuffer; + private MysqlHandshakePacket mysqlHandshakePacket; public void setUserQueryTimeout(int queryTimeout) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java index cbc0aee98a5c8d..06f5397d3dccd8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java @@ -47,6 +47,7 @@ import org.apache.doris.datasource.InternalCatalog; import org.apache.doris.metric.MetricRepo; import org.apache.doris.mysql.MysqlChannel; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.mysql.MysqlPacket; import org.apache.doris.mysql.MysqlSerializer; import org.apache.doris.mysql.MysqlServerStatusFlag; @@ -60,6 +61,7 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.stats.StatsErrorEstimator; import org.apache.doris.nereids.trees.plans.commands.ExplainCommand; +import org.apache.doris.nereids.trees.plans.commands.PrepareCommand; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalSqlCache; import org.apache.doris.plugin.DialectConverterPlugin; @@ -86,6 +88,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -714,7 +717,24 @@ public TMasterOpResult proxyExecute(TMasterOpRequest request) throws TException UUID uuid = UUID.randomUUID(); queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits()); } - executor.queryRetry(queryId); + if (request.isSetPrepareExecuteBuffer()) { + ctx.setCommand(MysqlCommand.COM_STMT_PREPARE); + executor.execute(); + ctx.setCommand(MysqlCommand.COM_STMT_EXECUTE); + String preparedStmtId = executor.getPrepareStmtName(); + PreparedStatementContext preparedStatementContext = ctx.getPreparedStementContext(preparedStmtId); + if (preparedStatementContext == null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Something error, just support nereids preparedStmtId:{}", preparedStmtId); + } + throw new RuntimeException("Prepare failed when proxy execute"); + } + handleExecute(preparedStatementContext.command, Long.parseLong(preparedStmtId), + preparedStatementContext, + ByteBuffer.wrap(request.getPrepareExecuteBuffer()).order(ByteOrder.LITTLE_ENDIAN), queryId); + } else { + executor.queryRetry(queryId); + } } catch (IOException e) { // Client failed. LOG.warn("Process one query failed because IOException: ", e); @@ -784,4 +804,10 @@ private Map userVariableFromThrift(Map t throw new TException(e.getMessage()); } } + + + protected void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx, + ByteBuffer packetBuf, TUniqueId queryId) { + throw new NotSupportedException("Just MysqlConnectProcessor support execute"); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java index 1f7d87bdfe35b3..285a752cb32305 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java @@ -25,6 +25,7 @@ import org.apache.doris.common.Config; import org.apache.doris.common.DdlException; import org.apache.doris.common.ErrorCode; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.thrift.FrontendService; import org.apache.doris.thrift.TExpr; import org.apache.doris.thrift.TExprNode; @@ -236,6 +237,12 @@ private TMasterOpRequest buildStmtForwardParams() throws AnalysisException { if (ctx.isTxnModel()) { params.setTxnLoadInfo(ctx.getTxnEntry().getTxnLoadInfoInObserver()); } + + if (ctx.getCommand() == MysqlCommand.COM_STMT_EXECUTE) { + if (null != ctx.getPrepareExecuteBuffer()) { + params.setPrepareExecuteBuffer(ctx.getPrepareExecuteBuffer()); + } + } return params; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java index 97b5061a212907..50990a753c35fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java @@ -43,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.PlaceholderId; import org.apache.doris.nereids.trees.plans.commands.ExecuteCommand; import org.apache.doris.nereids.trees.plans.commands.PrepareCommand; +import org.apache.doris.thrift.TUniqueId; import com.google.common.base.Preconditions; import com.google.common.base.Strings; @@ -100,7 +101,18 @@ private void debugPacket() { } } - private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx) { + private String getHexStr(ByteBuffer packetBuf) { + byte[] bytes = packetBuf.array(); + StringBuilder hex = new StringBuilder(); + for (int i = packetBuf.position(); i < packetBuf.limit(); ++i) { + hex.append(String.format("%02X ", bytes[i])); + } + return hex.toString(); + } + + @Override + protected void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx, + ByteBuffer packetBuf, TUniqueId queryId) { int paramCount = prepareCommand.placeholderCount(); LOG.debug("execute prepared statement {}, paramCount {}", stmtId, paramCount); // null bitmap @@ -108,6 +120,12 @@ private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedS try { StatementContext statementContext = prepCtx.statementContext; if (paramCount > 0) { + if (LOG.isDebugEnabled()) { + LOG.debug("execute param buf: {}, array: {}", packetBuf, getHexStr(packetBuf)); + } + if (!ctx.isProxy()) { + ctx.setPrepareExecuteBuffer(packetBuf.duplicate()); + } byte[] nullbitmapData = new byte[(paramCount + 7) / 8]; packetBuf.get(nullbitmapData); // new_params_bind_flag @@ -148,7 +166,11 @@ private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedS stmt.setOrigStmt(prepareCommand.getOriginalStmt()); executor = new StmtExecutor(ctx, stmt); ctx.setExecutor(executor); - executor.execute(); + if (null != queryId) { + executor.execute(queryId); + } else { + executor.execute(); + } if (ctx.getSessionVariable().isEnablePreparedStmtAuditLog()) { stmtStr = executeStmt.toSql(); stmtStr = stmtStr + " /*originalSql = " + prepareCommand.getOriginalStmt().originStmt + "*/"; @@ -191,7 +213,7 @@ private void handleExecute() { "msg: Not supported such prepared statement"); return; } - handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext); + handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext, packetBuf, null); } // Process COM_QUERY statement, diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java index 7555337756352b..368e3f94ab1298 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java @@ -145,6 +145,7 @@ import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.minidump.MinidumpUtils; import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.trees.expressions.Placeholder; import org.apache.doris.nereids.trees.plans.commands.Command; import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand; import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand; @@ -279,6 +280,7 @@ public class StmtExecutor { private Data.PQueryStatistics.Builder statisticsForAuditLog; private boolean isCached; private String stmtName; + private String prepareStmtName; // for prox private String mysqlLoadId; // Handle selects that fe can do without be private boolean isHandleQueryInFe = false; @@ -701,8 +703,12 @@ private void executeByNereids(TUniqueId queryId) throws Exception { } long stmtId = Config.prepared_stmt_start_id > 0 ? Config.prepared_stmt_start_id : context.getPreparedStmtId(); - logicalPlan = new PrepareCommand(String.valueOf(stmtId), - logicalPlan, statementContext.getPlaceholders(), originStmt); + this.prepareStmtName = String.valueOf(stmtId); + // When proxy executing, this.statementContext is created in constructor. + // But context.statementContext is created in LogicalPlanBuilder. + List placeholders = context == null + ? statementContext.getPlaceholders() : context.getStatementContext().getPlaceholders(); + logicalPlan = new PrepareCommand(prepareStmtName, logicalPlan, placeholders, originStmt); } // when we in transaction mode, we only support insert into command and transaction command if (context.isTxnModel()) { @@ -723,8 +729,7 @@ private void executeByNereids(TUniqueId queryId) throws Exception { if (logicalPlan instanceof InsertIntoTableCommand) { profileType = ProfileType.LOAD; } - if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE - || context.getCommand() == MysqlCommand.COM_STMT_EXECUTE) { + if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) { throw new UserException("Forward master command is not supported for prepare statement"); } if (isProxy) { @@ -3687,4 +3692,8 @@ public void sendProxyQueryResult() throws IOException { context.getMysqlChannel().sendOnePacket(byteBuffer); } } + + public String getPrepareStmtName() { + return this.prepareStmtName; + } } diff --git a/gensrc/thrift/FrontendService.thrift b/gensrc/thrift/FrontendService.thrift index 92205c5ae0f011..0723de2cf2f8d3 100644 --- a/gensrc/thrift/FrontendService.thrift +++ b/gensrc/thrift/FrontendService.thrift @@ -597,6 +597,7 @@ struct TMasterOpRequest { // transaction load 29: optional TTxnLoadInfo txnLoadInfo 30: optional TGroupCommitInfo groupCommitInfo + 31: optional binary prepareExecuteBuffer // selectdb cloud 1000: optional string cloud_cluster diff --git a/regression-test/suites/query_p0/test_forward_qeury.groovy b/regression-test/suites/query_p0/test_forward_qeury.groovy index d4761c835a26e0..e2b11e9535ff93 100644 --- a/regression-test/suites/query_p0/test_forward_qeury.groovy +++ b/regression-test/suites/query_p0/test_forward_qeury.groovy @@ -43,7 +43,12 @@ suite("test_forward_query", 'docker') { cluster.injectDebugPoints(NodeType.FE, ['StmtExecutor.forward_all_queries' : [forwardAllQueries:true, execute:1]]) - def ret = sql """ SELECT * FROM ${tbl} """ + def stmt = prepareStatement("""INSERT INTO ${tbl} VALUES(?);""") + stmt.setInt(1, 2) + stmt.executeUpdate() + + def ret = sql """ SELECT * FROM ${tbl} order by k1""" assertEquals(ret[0][0], 1) + assertEquals(ret[1][0], 2) } }