diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java index a844dcb9500796..ebddcd68845185 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/PrepareCommand.java @@ -109,7 +109,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { } ctx.addPreparedStatementContext(name, new PreparedStatementContext(this, ctx, ctx.getStatementContext(), name)); - if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE) { + if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE && !ctx.isProxy()) { executor.sendStmtPrepareOK(Integer.parseInt(name), labels); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 544333905bcf8a..3d68225c63f51e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -69,6 +69,8 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import io.netty.util.concurrent.FastThreadLocal; +import lombok.Getter; +import lombok.Setter; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -76,6 +78,7 @@ import org.xnio.StreamConnection; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -239,6 +242,10 @@ public enum ConnectType { // it's default thread-safe private boolean isProxy = false; + @Getter + @Setter + private ByteBuffer prepareExecuteBuffer; + private MysqlHandshakePacket mysqlHandshakePacket; public void setUserQueryTimeout(int queryTimeout) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java index 503e31ef362751..bf59d9a4a9cdd4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java @@ -61,6 +61,7 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.stats.StatsErrorEstimator; import org.apache.doris.nereids.trees.plans.commands.ExplainCommand; +import org.apache.doris.nereids.trees.plans.commands.PrepareCommand; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalSqlCache; import org.apache.doris.plugin.DialectConverterPlugin; @@ -87,6 +88,7 @@ import java.io.IOException; import java.io.StringReader; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -732,8 +734,24 @@ public TMasterOpResult proxyExecute(TMasterOpRequest request) throws TException UUID uuid = UUID.randomUUID(); queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits()); } - - executor.execute(queryId); + if (request.isSetPrepareExecuteBuffer()) { + ctx.setCommand(MysqlCommand.COM_STMT_PREPARE); + executor.execute(); + ctx.setCommand(MysqlCommand.COM_STMT_EXECUTE); + String preparedStmtId = executor.getPrepareStmtName(); + PreparedStatementContext preparedStatementContext = ctx.getPreparedStementContext(preparedStmtId); + if (preparedStatementContext == null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Something error, just support nereids preparedStmtId:{}", preparedStmtId); + } + throw new RuntimeException("Prepare failed when proxy execute"); + } + handleExecute(preparedStatementContext.command, Long.parseLong(preparedStmtId), + preparedStatementContext, + ByteBuffer.wrap(request.getPrepareExecuteBuffer()).order(ByteOrder.LITTLE_ENDIAN), queryId); + } else { + executor.execute(queryId); + } } catch (IOException e) { // Client failed. LOG.warn("Process one query failed because IOException: ", e); @@ -796,4 +814,10 @@ private Map userVariableFromThrift(Map t throw new TException(e.getMessage()); } } + + + protected void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx, + ByteBuffer packetBuf, TUniqueId queryId) { + throw new NotSupportedException("Just MysqlConnectProcessor support execute"); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java index 23f11f41173140..0e7b5a2f473d88 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MasterOpExecutor.java @@ -23,6 +23,7 @@ import org.apache.doris.common.ClientPool; import org.apache.doris.common.DdlException; import org.apache.doris.common.ErrorCode; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.thrift.FrontendService; import org.apache.doris.thrift.TExpr; import org.apache.doris.thrift.TExprNode; @@ -183,6 +184,12 @@ private TMasterOpRequest buildStmtForwardParams() { if (null != ctx.queryId()) { params.setQueryId(ctx.queryId()); } + + if (ctx.getCommand() == MysqlCommand.COM_STMT_EXECUTE) { + if (null != ctx.getPrepareExecuteBuffer()) { + params.setPrepareExecuteBuffer(ctx.getPrepareExecuteBuffer()); + } + } return params; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java index 0c89b4a8403743..c54c05ae5bd429 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java @@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.PlaceholderId; import org.apache.doris.nereids.trees.plans.commands.ExecuteCommand; import org.apache.doris.nereids.trees.plans.commands.PrepareCommand; +import org.apache.doris.thrift.TUniqueId; import com.google.common.base.Preconditions; import com.google.common.base.Strings; @@ -170,7 +171,18 @@ private void handleExecute(PrepareStmt prepareStmt, long stmtId) { } } - private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx) { + private String getHexStr(ByteBuffer packetBuf) { + byte[] bytes = packetBuf.array(); + StringBuilder hex = new StringBuilder(); + for (int i = packetBuf.position(); i < packetBuf.limit(); ++i) { + hex.append(String.format("%02X ", bytes[i])); + } + return hex.toString(); + } + + @Override + protected void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx, + ByteBuffer packetBuf, TUniqueId queryId) { int paramCount = prepareCommand.placeholderCount(); LOG.debug("execute prepared statement {}, paramCount {}", stmtId, paramCount); // null bitmap @@ -178,6 +190,12 @@ private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedS try { StatementContext statementContext = prepCtx.statementContext; if (paramCount > 0) { + if (LOG.isDebugEnabled()) { + LOG.debug("execute param buf: {}, array: {}", packetBuf, getHexStr(packetBuf)); + } + if (!ctx.isProxy()) { + ctx.setPrepareExecuteBuffer(packetBuf.duplicate()); + } byte[] nullbitmapData = new byte[(paramCount + 7) / 8]; packetBuf.get(nullbitmapData); // new_params_bind_flag @@ -218,7 +236,11 @@ private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedS stmt.setOrigStmt(prepareCommand.getOriginalStmt()); executor = new StmtExecutor(ctx, stmt); ctx.setExecutor(executor); - executor.execute(); + if (null != queryId) { + executor.execute(queryId); + } else { + executor.execute(); + } if (ctx.getSessionVariable().isEnablePreparedStmtAuditLog()) { stmtStr = executeStmt.toSql(); stmtStr = stmtStr + " /*originalSql = " + prepareCommand.getOriginalStmt().originStmt + "*/"; @@ -266,7 +288,7 @@ private void handleExecute() { "msg: Not supported such prepared statement"); return; } - handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext); + handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext, packetBuf, null); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java index 1d1a93e92e60ed..f5ab4137f29698 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java @@ -143,6 +143,7 @@ import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.minidump.MinidumpUtils; import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.trees.expressions.Placeholder; import org.apache.doris.nereids.trees.plans.commands.Command; import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand; import org.apache.doris.nereids.trees.plans.commands.DeleteFromCommand; @@ -269,6 +270,7 @@ public class StmtExecutor { private boolean isCached; private String stmtName; private StatementBase prepareStmt = null; + private String prepareStmtName; // for prox private String mysqlLoadId; // Distinguish from prepare and execute command private boolean isExecuteStmt = false; @@ -682,8 +684,12 @@ private void executeByNereids(TUniqueId queryId) throws Exception { } long stmtId = Config.prepared_stmt_start_id > 0 ? Config.prepared_stmt_start_id : context.getPreparedStmtId(); - logicalPlan = new PrepareCommand(String.valueOf(stmtId), - logicalPlan, statementContext.getPlaceholders(), originStmt); + this.prepareStmtName = String.valueOf(stmtId); + // When proxy executing, this.statementContext is created in constructor. + // But context.statementContext is created in LogicalPlanBuilder. + List placeholders = context == null + ? statementContext.getPlaceholders() : context.getStatementContext().getPlaceholders(); + logicalPlan = new PrepareCommand(prepareStmtName, logicalPlan, placeholders, originStmt); } // when we in transaction mode, we only support insert into command and transaction command if (context.isTxnModel()) { @@ -3488,4 +3494,8 @@ public void sendProxyQueryResult() throws IOException { context.getMysqlChannel().sendOnePacket(byteBuffer); } } + + public String getPrepareStmtName() { + return this.prepareStmtName; + } } diff --git a/gensrc/thrift/FrontendService.thrift b/gensrc/thrift/FrontendService.thrift index 08a4aa9c061729..8903eba29e4a50 100644 --- a/gensrc/thrift/FrontendService.thrift +++ b/gensrc/thrift/FrontendService.thrift @@ -571,6 +571,7 @@ struct TMasterOpRequest { // transaction load 29: optional TTxnLoadInfo txnLoadInfo 30: optional TGroupCommitInfo groupCommitInfo + 31: optional binary prepareExecuteBuffer } struct TColumnDefinition { diff --git a/regression-test/suites/query_p0/test_forward_qeury.groovy b/regression-test/suites/query_p0/test_forward_qeury.groovy index d4761c835a26e0..e2b11e9535ff93 100644 --- a/regression-test/suites/query_p0/test_forward_qeury.groovy +++ b/regression-test/suites/query_p0/test_forward_qeury.groovy @@ -43,7 +43,12 @@ suite("test_forward_query", 'docker') { cluster.injectDebugPoints(NodeType.FE, ['StmtExecutor.forward_all_queries' : [forwardAllQueries:true, execute:1]]) - def ret = sql """ SELECT * FROM ${tbl} """ + def stmt = prepareStatement("""INSERT INTO ${tbl} VALUES(?);""") + stmt.setInt(1, 2) + stmt.executeUpdate() + + def ret = sql """ SELECT * FROM ${tbl} order by k1""" assertEquals(ret[0][0], 1) + assertEquals(ret[1][0], 2) } }