diff --git a/ddl/schematracker/dm_tracker.go b/ddl/schematracker/dm_tracker.go index afb3a75c1974b..75f8fa35b429d 100644 --- a/ddl/schematracker/dm_tracker.go +++ b/ddl/schematracker/dm_tracker.go @@ -49,6 +49,9 @@ var _ ddl.DDL = SchemaTracker{} // SchemaTracker is used to track schema changes by DM. It implements DDL interface and by applying DDL, it updates the // table structure to keep tracked with upstream changes. +// It embeds an InfoStore which stores DBInfo and TableInfo. The DBInfo and TableInfo can be treated as immutable, so +// after reading them by SchemaByName or TableByName, later modifications made by SchemaTracker will not change them. +// SchemaTracker is not thread-safe. type SchemaTracker struct { *InfoStore } @@ -108,16 +111,22 @@ func (d SchemaTracker) CreateSchemaWithInfo(ctx sessionctx.Context, dbInfo *mode } // AlterSchema implements the DDL interface. -func (d SchemaTracker) AlterSchema(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) error { +func (d SchemaTracker) AlterSchema(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) (err error) { dbInfo := d.SchemaByName(stmt.Name) if dbInfo == nil { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(stmt.Name.O) } + dbInfo = dbInfo.Clone() + defer func() { + if err == nil { + d.PutSchema(dbInfo) + } + }() + // Resolve target charset and collation from options. var ( toCharset, toCollate string - err error ) for _, val := range stmt.Options { @@ -173,9 +182,15 @@ func (d SchemaTracker) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStm return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) } // suppress ErrTooLongKey + strictSQLModeBackup := ctx.GetSessionVars().StrictSQLMode ctx.GetSessionVars().StrictSQLMode = false // support drop PK + enableClusteredIndexBackup := ctx.GetSessionVars().EnableClusteredIndex ctx.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOff + defer func() { + ctx.GetSessionVars().StrictSQLMode = strictSQLModeBackup + ctx.GetSessionVars().EnableClusteredIndex = enableClusteredIndexBackup + }() var ( referTbl *model.TableInfo @@ -354,6 +369,13 @@ func (d SchemaTracker) CreateIndex(ctx sessionctx.Context, stmt *ast.CreateIndex stmt.IndexPartSpecifications, stmt.IndexOption, stmt.IfNotExists) } +func (d SchemaTracker) putTableIfNoError(err error, dbName model.CIStr, tbInfo *model.TableInfo) { + if err != nil { + return + } + _ = d.PutTable(dbName, tbInfo) +} + // createIndex is shared by CreateIndex and AlterTable. func (d SchemaTracker) createIndex( ctx sessionctx.Context, @@ -363,12 +385,15 @@ func (d SchemaTracker) createIndex( indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption, ifNotExists bool, -) error { +) (err error) { unique := keyType == ast.IndexKeyTypeUnique - tblInfo, err := d.TableByName(ti.Schema, ti.Name) + tblInfo, err := d.TableClonedByName(ti.Schema, ti.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ti.Schema, tblInfo) + t := tables.MockTableFromMeta(tblInfo) // Deal with anonymous index. @@ -432,12 +457,14 @@ func (d SchemaTracker) DropIndex(ctx sessionctx.Context, stmt *ast.DropIndexStmt } // dropIndex is shared by DropIndex and AlterTable. -func (d SchemaTracker) dropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, ifExists bool) error { - tblInfo, err := d.TableByName(ti.Schema, ti.Name) +func (d SchemaTracker) dropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, ifExists bool) (err error) { + tblInfo, err := d.TableClonedByName(ti.Schema, ti.Name) if err != nil { return infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name) } + defer d.putTableIfNoError(err, ti.Schema, tblInfo) + indexInfo := tblInfo.FindIndexByName(indexName.L) if indexInfo == nil { if ifExists { @@ -464,16 +491,19 @@ func (d SchemaTracker) dropIndex(ctx sessionctx.Context, ti ast.Ident, indexName } // addColumn is used by AlterTable. -func (d SchemaTracker) addColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { +func (d SchemaTracker) addColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) (err error) { specNewColumn := spec.NewColumns[0] schema := d.SchemaByName(ti.Schema) if schema == nil { return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) } - tblInfo, err := d.TableByName(ti.Schema, ti.Name) + tblInfo, err := d.TableClonedByName(ti.Schema, ti.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ti.Schema, tblInfo) + t := tables.MockTableFromMeta(tblInfo) colName := specNewColumn.Name.Name.O @@ -504,12 +534,14 @@ func (d SchemaTracker) addColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast } // dropColumn is used by AlterTable. -func (d *SchemaTracker) dropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { - tblInfo, err := d.TableByName(ti.Schema, ti.Name) +func (d *SchemaTracker) dropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) (err error) { + tblInfo, err := d.TableClonedByName(ti.Schema, ti.Name) if err != nil { return err } + defer d.putTableIfNoError(err, ti.Schema, tblInfo) + colName := spec.OldColumnName.Name colInfo := tblInfo.FindPublicColumnByName(colName.L) if colInfo == nil { @@ -546,14 +578,17 @@ func (d *SchemaTracker) dropColumn(ctx sessionctx.Context, ti ast.Ident, spec *a } // renameColumn is used by AlterTable. -func (d SchemaTracker) renameColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { +func (d SchemaTracker) renameColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { oldColName := spec.OldColumnName.Name newColName := spec.NewColumnName.Name - tblInfo, err := d.TableByName(ident.Schema, ident.Name) + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + tbl := tables.MockTableFromMeta(tblInfo) oldCol := table.FindCol(tbl.VisibleCols(), oldColName.L) @@ -595,12 +630,15 @@ func (d SchemaTracker) renameColumn(ctx sessionctx.Context, ident ast.Ident, spe } // alterColumn is used by AlterTable. -func (d SchemaTracker) alterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { +func (d SchemaTracker) alterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { specNewColumn := spec.NewColumns[0] - tblInfo, err := d.TableByName(ident.Schema, ident.Name) + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + t := tables.MockTableFromMeta(tblInfo) colName := specNewColumn.Name.Name @@ -664,11 +702,14 @@ func (d SchemaTracker) handleModifyColumn( ident ast.Ident, originalColName model.CIStr, spec *ast.AlterTableSpec, -) error { - tblInfo, err := d.TableByName(ident.Schema, ident.Name) +) (err error) { + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + schema := d.SchemaByName(ident.Schema) t := tables.MockTableFromMeta(tblInfo) job, err := ddl.GetModifiableColumnJob(ctx, sctx, nil, ident, originalColName, schema, t, spec) @@ -714,11 +755,14 @@ func (d SchemaTracker) handleModifyColumn( } // renameIndex is used by AlterTable. -func (d SchemaTracker) renameIndex(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - tblInfo, err := d.TableByName(ident.Schema, ident.Name) +func (d SchemaTracker) renameIndex(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return err } + + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + duplicate, err := ddl.ValidateRenameIndex(spec.FromKey, spec.ToKey, tblInfo) if duplicate { return nil @@ -732,12 +776,14 @@ func (d SchemaTracker) renameIndex(ctx sessionctx.Context, ident ast.Ident, spec } // addTablePartitions is used by AlterTable. -func (d SchemaTracker) addTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - tblInfo, err := d.TableByName(ident.Schema, ident.Name) +func (d SchemaTracker) addTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return errors.Trace(err) } + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + pi := tblInfo.GetPartitionInfo() if pi == nil { return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) @@ -752,12 +798,14 @@ func (d SchemaTracker) addTablePartitions(ctx sessionctx.Context, ident ast.Iden } // dropTablePartitions is used by AlterTable. -func (d SchemaTracker) dropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - tblInfo, err := d.TableByName(ident.Schema, ident.Name) +func (d SchemaTracker) dropTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + tblInfo, err := d.TableClonedByName(ident.Schema, ident.Name) if err != nil { return errors.Trace(err) } + defer d.putTableIfNoError(err, ident.Schema, tblInfo) + pi := tblInfo.GetPartitionInfo() if pi == nil { return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) @@ -799,12 +847,14 @@ func (d SchemaTracker) createPrimaryKey( indexName model.CIStr, indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption, -) error { - tblInfo, err := d.TableByName(ti.Schema, ti.Name) +) (err error) { + tblInfo, err := d.TableClonedByName(ti.Schema, ti.Name) if err != nil { return errors.Trace(err) } + defer d.putTableIfNoError(err, ti.Schema, tblInfo) + indexName = model.NewCIStr(mysql.PrimaryKeyName) if indexInfo := tblInfo.FindIndexByName(indexName.L); indexInfo != nil || // If the table's PKIsHandle is true, it also means that this table has a primary key. @@ -888,7 +938,7 @@ func (d SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, case ast.AlterTableRenameIndex: err = d.renameIndex(sctx, ident, spec) case ast.AlterTableDropPartition: - err = d.dropTablePartition(sctx, ident, spec) + err = d.dropTablePartitions(sctx, ident, spec) case ast.AlterTableAddConstraint: constr := spec.Constraint switch spec.Constraint.Tp { @@ -925,7 +975,9 @@ func (d SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, case ast.TableOptionAutoIdCache: case ast.TableOptionAutoRandomBase: case ast.TableOptionComment: + tblInfo = tblInfo.Clone() tblInfo.Comment = opt.StrValue + _ = d.PutTable(ident.Schema, tblInfo) case ast.TableOptionCharset, ast.TableOptionCollate: // getCharsetAndCollateInTableOption will get the last charset and collate in the options, // so it should be handled only once. @@ -939,6 +991,7 @@ func (d SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, } needsOverwriteCols := ddl.NeedToOverwriteColCharset(spec.Options) + tblInfo = tblInfo.Clone() if toCharset != "" { tblInfo.Charset = toCharset } @@ -957,6 +1010,7 @@ func (d SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, } } } + _ = d.PutTable(ident.Schema, tblInfo) handledCharsetOrCollate = true case ast.TableOptionPlacementPolicy: @@ -970,11 +1024,13 @@ func (d SchemaTracker) AlterTable(ctx context.Context, sctx sessionctx.Context, } } case ast.AlterTableIndexInvisible: + tblInfo = tblInfo.Clone() idx := tblInfo.FindIndexByName(spec.IndexName.L) if idx == nil { return errors.Trace(infoschema.ErrKeyNotExists.GenWithStackByArgs(spec.IndexName.O, ident.Name)) } idx.Invisible = spec.Visibility == ast.IndexVisibilityInvisible + _ = d.PutTable(ident.Schema, tblInfo) case ast.AlterTablePartitionOptions, ast.AlterTableDropForeignKey, ast.AlterTableCoalescePartitions, diff --git a/ddl/schematracker/dm_tracker_test.go b/ddl/schematracker/dm_tracker_test.go index 8cfd34cde0590..01998d3dc0134 100644 --- a/ddl/schematracker/dm_tracker_test.go +++ b/ddl/schematracker/dm_tracker_test.go @@ -98,6 +98,12 @@ func execAlter(t *testing.T, tracker schematracker.SchemaTracker, sql string) { require.NoError(t, err) } +func mustTableByName(t *testing.T, tracker schematracker.SchemaTracker, schema, table string) *model.TableInfo { + tblInfo, err := tracker.TableByName(model.NewCIStr(schema), model.NewCIStr(table)) + require.NoError(t, err) + return tblInfo +} + func TestAlterPK(t *testing.T) { sql := "create table test.t (c1 int primary key, c2 blob);" @@ -105,20 +111,24 @@ func TestAlterPK(t *testing.T) { tracker.CreateTestDB() execCreate(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") require.Equal(t, 1, len(tblInfo.Indices)) sql = "alter table test.t drop primary key;" execAlter(t, tracker, sql) + // TableInfo should be immutable. + require.Equal(t, 1, len(tblInfo.Indices)) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 0, len(tblInfo.Indices)) sql = "alter table test.t add primary key(c1);" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 1, len(tblInfo.Indices)) sql = "alter table test.t drop primary key;" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 0, len(tblInfo.Indices)) } @@ -129,20 +139,22 @@ func TestDropColumn(t *testing.T) { tracker.CreateTestDB() execCreate(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") require.Equal(t, 1, len(tblInfo.Indices)) sql = "alter table test.t drop column b" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 0, len(tblInfo.Indices)) sql = "alter table test.t add index idx_2_col(a, c)" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 1, len(tblInfo.Indices)) sql = "alter table test.t drop column c" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Equal(t, 1, len(tblInfo.Indices)) require.Equal(t, 1, len(tblInfo.Columns)) } @@ -172,8 +184,7 @@ func TestIndexLength(t *testing.T) { tracker.CreateTestDB() execCreate(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") expected := "CREATE TABLE `t` (\n" + " `a` text DEFAULT NULL,\n" + @@ -185,7 +196,7 @@ func TestIndexLength(t *testing.T) { ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin" checkShowCreateTable(t, tblInfo, expected) - err = tracker.DeleteTable(model.NewCIStr("test"), model.NewCIStr("t")) + err := tracker.DeleteTable(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) sql = "create table test.t(a text, b text charset ascii, c blob);" @@ -198,9 +209,7 @@ func TestIndexLength(t *testing.T) { sql = "alter table test.t add index (c(3072))" execAlter(t, tracker, sql) - tblInfo, err = tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) - + tblInfo = mustTableByName(t, tracker, "test", "t") checkShowCreateTable(t, tblInfo, expected) } @@ -225,8 +234,7 @@ func TestIssue5092(t *testing.T) { sql = "alter table test.t add column b2 int after b1, add column c2 int first" execAlter(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") expected := "CREATE TABLE `t` (\n" + " `c2` int(11) DEFAULT NULL,\n" + @@ -303,8 +311,7 @@ func TestAddExpressionIndex(t *testing.T) { sql = "alter table test.t add index idx_multi((a+b),(a+1), b);" execAlter(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") expected := "CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + @@ -319,6 +326,8 @@ func TestAddExpressionIndex(t *testing.T) { sql = "alter table test.t drop index idx_multi;" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t") + expected = "CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` double DEFAULT NULL\n" + @@ -330,8 +339,7 @@ func TestAddExpressionIndex(t *testing.T) { sql = "alter table test.t2 add unique index ei_ab ((concat(a, b)));" execAlter(t, tracker, sql) - tblInfo, err = tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t2")) - require.NoError(t, err) + tblInfo = mustTableByName(t, tracker, "test", "t2") expected = "CREATE TABLE `t2` (\n" + " `a` varchar(10) DEFAULT NULL,\n" + @@ -343,6 +351,8 @@ func TestAddExpressionIndex(t *testing.T) { sql = "alter table test.t2 alter index ei_ab invisible;" execAlter(t, tracker, sql) + tblInfo = mustTableByName(t, tracker, "test", "t2") + expected = "CREATE TABLE `t2` (\n" + " `a` varchar(10) DEFAULT NULL,\n" + " `b` varchar(10) DEFAULT NULL,\n" + @@ -353,8 +363,7 @@ func TestAddExpressionIndex(t *testing.T) { sql = "create table test.t3(a int, key((a+1)), key((a+2)), key idx((a+3)), key((a+4)), UNIQUE KEY ((a * 2)));" execCreate(t, tracker, sql) - tblInfo, err = tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t3")) - require.NoError(t, err) + tblInfo = mustTableByName(t, tracker, "test", "t3") expected = "CREATE TABLE `t3` (\n" + " `a` int(11) DEFAULT NULL,\n" + @@ -381,8 +390,7 @@ func TestAddExpressionIndex(t *testing.T) { sql = "alter table test.t4 add index idx((a+c));" execAlter(t, tracker, sql) - tblInfo, err = tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t4")) - require.NoError(t, err) + tblInfo = mustTableByName(t, tracker, "test", "t4") expected = "CREATE TABLE `t4` (\n" + " `a` int(11) DEFAULT NULL,\n" + @@ -408,8 +416,7 @@ func TestAtomicMultiSchemaChange(t *testing.T) { sql = "alter table test.t add b int, add c int;" execAlter(t, tracker, sql) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") require.Len(t, tblInfo.Columns, 3) sql = "alter table test.t add d int, add a int;" @@ -422,11 +429,45 @@ func TestAtomicMultiSchemaChange(t *testing.T) { err = tracker.AlterTable(ctx, sctx, stmt.(*ast.AlterTableStmt)) require.True(t, infoschema.ErrColumnExists.Equal(err)) - tblInfo, err = tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo = mustTableByName(t, tracker, "test", "t") require.Len(t, tblInfo.Columns, 3) } +func TestImmutableTableInfo(t *testing.T) { + sql := "create table test.t (a varchar(20)) charset latin1;" + + tracker := schematracker.NewSchemaTracker(2) + tracker.CreateTestDB() + execCreate(t, tracker, sql) + + tblInfo := mustTableByName(t, tracker, "test", "t") + require.Equal(t, "", tblInfo.Comment) + + sql = "alter table test.t comment = '123';" + execAlter(t, tracker, sql) + require.Equal(t, "", tblInfo.Comment) + + tblInfo = mustTableByName(t, tracker, "test", "t") + require.Equal(t, "123", tblInfo.Comment) + + require.Equal(t, "latin1", tblInfo.Charset) + require.Equal(t, "latin1_bin", tblInfo.Collate) + require.Equal(t, "latin1", tblInfo.Columns[0].GetCharset()) + require.Equal(t, "latin1_bin", tblInfo.Columns[0].GetCollate()) + + sql = "alter table test.t convert to character set utf8mb4 collate utf8mb4_general_ci;" + execAlter(t, tracker, sql) + require.Equal(t, "latin1", tblInfo.Charset) + require.Equal(t, "latin1_bin", tblInfo.Collate) + require.Equal(t, "latin1", tblInfo.Columns[0].GetCharset()) + require.Equal(t, "latin1_bin", tblInfo.Columns[0].GetCollate()) + tblInfo = mustTableByName(t, tracker, "test", "t") + require.Equal(t, "utf8mb4", tblInfo.Charset) + require.Equal(t, "utf8mb4_general_ci", tblInfo.Collate) + require.Equal(t, "utf8mb4", tblInfo.Columns[0].GetCharset()) + require.Equal(t, "utf8mb4_general_ci", tblInfo.Columns[0].GetCollate()) +} + var _ sqlexec.RestrictedSQLExecutor = (*mockRestrictedSQLExecutor)(nil) type mockRestrictedSQLExecutor struct { @@ -462,7 +503,6 @@ func TestModifyFromNullToNotNull(t *testing.T) { err = tracker.AlterTable(ctx, executorCtx, stmt.(*ast.AlterTableStmt)) require.NoError(t, err) - tblInfo, err := tracker.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) - require.NoError(t, err) + tblInfo := mustTableByName(t, tracker, "test", "t") require.Len(t, tblInfo.Columns, 2) } diff --git a/ddl/schematracker/info_store.go b/ddl/schematracker/info_store.go index 6c0739d960b3c..d6bb970591c8b 100644 --- a/ddl/schematracker/info_store.go +++ b/ddl/schematracker/info_store.go @@ -88,6 +88,15 @@ func (i *InfoStore) TableByName(schema, table model.CIStr) (*model.TableInfo, er return tbl, nil } +// TableClonedByName is like TableByName, plus it will clone the TableInfo. +func (i *InfoStore) TableClonedByName(schema, table model.CIStr) (*model.TableInfo, error) { + tbl, err := i.TableByName(schema, table) + if err != nil { + return nil, err + } + return tbl.Clone(), nil +} + // PutTable puts a TableInfo, it will overwrite the old one. If the schema doesn't exist, it will return ErrDatabaseNotExists. func (i *InfoStore) PutTable(schemaName model.CIStr, tblInfo *model.TableInfo) error { schemaKey := i.ciStr2Key(schemaName)