diff --git a/session/session.go b/session/session.go index 8b6b7cab6b981..f604e8f4933ab 100644 --- a/session/session.go +++ b/session/session.go @@ -1030,12 +1030,6 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu } func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) { - s.SetValue(sessionctx.QueryString, stmt.OriginText()) - if _, ok := stmtNode.(ast.DDLNode); ok { - s.SetValue(sessionctx.LastExecuteDDL, true) - } else { - s.ClearValue(sessionctx.LastExecuteDDL) - } logStmt(stmtNode, s.sessionVars) startTime := time.Now() recordSet, err := runStmt(ctx, s, stmt) diff --git a/session/session_test.go b/session/session_test.go index 1d78bc50673d9..0ac2c99a84e45 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -182,6 +182,20 @@ func (s *testSessionSuite) TestQueryString(c *C) { tk.MustExec("create table mutil1 (a int);create table multi2 (a int)") queryStr := tk.Se.Value(sessionctx.QueryString) c.Assert(queryStr, Equals, "create table multi2 (a int)") + + // Test execution of DDL through the "ExecutePreparedStmt" interface. + _, err := tk.Se.Execute(context.Background(), "use test;") + c.Assert(err, IsNil) + _, err = tk.Se.Execute(context.Background(), "CREATE TABLE t (id bigint PRIMARY KEY, age int)") + c.Assert(err, IsNil) + _, err = tk.Se.Execute(context.Background(), "show create table t") + c.Assert(err, IsNil) + id, _, _, err := tk.Se.PrepareStmt("CREATE TABLE t2(id bigint PRIMARY KEY, age int)") + c.Assert(err, IsNil) + _, err = tk.Se.ExecutePreparedStmt(context.Background(), id) + c.Assert(err, IsNil) + qs := tk.Se.Value(sessionctx.QueryString) + c.Assert(qs.(string), Equals, "CREATE TABLE t2(id bigint PRIMARY KEY, age int)") } func (s *testSessionSuite) TestAffectedRows(c *C) { diff --git a/session/tidb.go b/session/tidb.go index d86a48aaf0640..568edddb240c5 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -234,6 +234,13 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) span1.LogKV("sql", s.OriginText()) defer span1.Finish() } + sctx.SetValue(sessionctx.QueryString, s.OriginText()) + if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DDLNode); ok { + sctx.SetValue(sessionctx.LastExecuteDDL, true) + } else { + sctx.ClearValue(sessionctx.LastExecuteDDL) + } + se := sctx.(*session) sessVars := se.GetSessionVars() // Save origTxnCtx here to avoid it reset in the transaction retry.