diff --git a/planner/core/cache.go b/planner/core/cache.go index ca0a1c6340b44..0d01f2bb98f73 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -208,6 +208,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta type CachedPrepareStmt struct { PreparedAst *ast.Prepared VisitInfos []visitInfo + ColumnInfos interface{} Executor interface{} NormalizedSQL string NormalizedPlan string diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 677298c3c7fc4..6f3221cd21388 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -80,7 +80,8 @@ func (ts *TiDBStatement) Execute(ctx context.Context, args []types.Datum) (rs Re return } rs = &tidbResultSet{ - recordSet: tidbRecordset, + recordSet: tidbRecordset, + preparedStmt: ts.ctx.GetSessionVars().PreparedStmts[ts.id].(*core.CachedPrepareStmt), } return } @@ -296,10 +297,11 @@ func (tc *TiDBContext) GetStmtStats() *stmtstats.StatementStats { } type tidbResultSet struct { - recordSet sqlexec.RecordSet - columns []*ColumnInfo - rows []chunk.Row - closed int32 + recordSet sqlexec.RecordSet + columns []*ColumnInfo + rows []chunk.Row + closed int32 + preparedStmt *core.CachedPrepareStmt } func (trs *tidbResultSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { @@ -341,12 +343,23 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo { if trs.columns != nil { return trs.columns } - + // for prepare statement, try to get cached columnInfo array + if trs.preparedStmt != nil { + ps := trs.preparedStmt + if colInfos, ok := ps.ColumnInfos.([]*ColumnInfo); ok { + trs.columns = colInfos + } + } if trs.columns == nil { fields := trs.recordSet.Fields() for _, v := range fields { trs.columns = append(trs.columns, convertColumnInfo(v)) } + if trs.preparedStmt != nil { + // if ColumnInfo struct has allocated object, + // here maybe we need deep copy ColumnInfo to do caching + trs.preparedStmt.ColumnInfos = trs.columns + } } return trs.columns } diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 141681e1df24e..5bbf88ad0e392 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -366,47 +366,6 @@ func TestPrepareCount(t *testing.T) { require.NoError(t, qctx.Close()) } -func TestPrepareExecute(t *testing.T) { - ts, cleanup := createTidbTestSuite(t) - defer cleanup() - - qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - require.NoError(t, err) - - ctx := context.Background() - _, err = qctx.Execute(ctx, "use test") - require.NoError(t, err) - _, err = qctx.Execute(ctx, "create table t1(id int primary key, v int)") - require.NoError(t, err) - _, err = qctx.Execute(ctx, "insert into t1 values(1, 100)") - require.NoError(t, err) - - stmt, _, _, err := qctx.Prepare("select * from t1 where id=1") - require.NoError(t, err) - rs, err := stmt.Execute(ctx, nil) - require.NoError(t, err) - req := rs.NewChunk(nil) - require.NoError(t, rs.Next(ctx, req)) - require.Equal(t, 2, req.NumCols()) - require.Equal(t, req.NumCols(), len(rs.Columns())) - require.Equal(t, 1, req.NumRows()) - require.Equal(t, int64(1), req.GetRow(0).GetInt64(0)) - require.Equal(t, int64(100), req.GetRow(0).GetInt64(1)) - - // issue #33509 - _, err = qctx.Execute(ctx, "alter table t1 drop column v") - require.NoError(t, err) - - rs, err = stmt.Execute(ctx, nil) - require.NoError(t, err) - req = rs.NewChunk(nil) - require.NoError(t, rs.Next(ctx, req)) - require.Equal(t, 1, req.NumCols()) - require.Equal(t, req.NumCols(), len(rs.Columns())) - require.Equal(t, 1, req.NumRows()) - require.Equal(t, int64(1), req.GetRow(0).GetInt64(0)) -} - func TestDefaultCharacterAndCollation(t *testing.T) { ts, cleanup := createTidbTestSuite(t) defer cleanup()