Skip to content

Commit

Permalink
ddl: check column and partition value have same type for range column…
Browse files Browse the repository at this point in the history
… partition (pingcap#12664)
  • Loading branch information
tiancaiamao authored and XiaTianliang committed Dec 21, 2019
1 parent 08236b3 commit aa6bfe3
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
17 changes: 17 additions & 0 deletions ddl/db_partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ create table log_message_1 (
"partition p1 values less than (1, 'a'))",
ddl.ErrRangeNotIncreasing,
},
{
"create table t (col datetime not null default '2000-01-01')" +
"partition by range columns (col) (" +
"PARTITION p0 VALUES LESS THAN (20190905)," +
"PARTITION p1 VALUES LESS THAN (20190906));",
ddl.ErrWrongTypeColumnValue,
},
}
for i, t := range cases {
_, err := tk.Exec(t.sql)
Expand Down Expand Up @@ -544,6 +551,16 @@ func (s *testIntegrationSuite5) TestAlterTableAddPartition(c *C) {
sql := "alter table t add partition ( partition p3 values less than ('2019-07-01'));"
tk.MustGetErrCode(sql, tmysql.ErrRangeNotIncreasing)
tk.MustExec("alter table t add partition ( partition p3 values less than ('2019-08-01'));")

// Add partition value's type should be the same with the column's type.
tk.MustExec("drop table if exists t;")
tk.MustExec(`create table t (
col date not null default '2000-01-01')
partition by range columns (col) (
PARTITION p0 VALUES LESS THAN ('20190905'),
PARTITION p1 VALUES LESS THAN ('20190906'));`)
sql = "alter table t add partition (partition p2 values less than (20190907));"
tk.MustGetErrCode(sql, tmysql.ErrWrongTypeColumnValue)
}

func (s *testIntegrationSuite5) TestAlterTableDropPartition(c *C) {
Expand Down
4 changes: 4 additions & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ var (
ErrTableCantHandleFt = terror.ClassDDL.New(codeErrTableCantHandleFt, mysql.MySQLErrName[mysql.ErrTableCantHandleFt])
// ErrFieldNotFoundPart returns an error when 'partition by columns' are not found in table columns.
ErrFieldNotFoundPart = terror.ClassDDL.New(codeFieldNotFoundPart, mysql.MySQLErrName[mysql.ErrFieldNotFoundPart])
// ErrWrongTypeColumnValue returns 'Partition column values of incorrect type'
ErrWrongTypeColumnValue = terror.ClassDDL.New(codeWrongTypeColumnValue, mysql.MySQLErrName[mysql.ErrWrongTypeColumnValue])
)

// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache.
Expand Down Expand Up @@ -760,6 +762,7 @@ const (
codeSystemVersioningWrongPartitions = terror.ErrCode(mysql.ErrSystemVersioningWrongPartitions)
codeWrongPartitionTypeExpectedSystemTime = terror.ErrCode(mysql.ErrWrongPartitionTypeExpectedSystemTime)
codeOnlyOnRangeListPartition = terror.ErrCode(mysql.ErrOnlyOnRangeListPartition)
codeWrongTypeColumnValue = terror.ErrCode(mysql.ErrWrongTypeColumnValue)
)

func init() {
Expand Down Expand Up @@ -834,6 +837,7 @@ func init() {
codeSystemVersioningWrongPartitions: mysql.ErrSystemVersioningWrongPartitions,
codeWrongPartitionTypeExpectedSystemTime: mysql.ErrWrongPartitionTypeExpectedSystemTime,
codeOnlyOnRangeListPartition: mysql.ErrOnlyOnRangeListPartition,
codeWrongTypeColumnValue: mysql.ErrWrongTypeColumnValue,
}
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
}
54 changes: 52 additions & 2 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,15 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *
return err
}

if s != nil {
for _, def := range s.Partition.Definitions {
exprs := def.Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs
if err := checkRangeColumnsTypeAndValuesMatch(ctx, tbInfo, pi.Columns, exprs); err != nil {
return err
}
}
}

return checkRangeColumnsPartitionValue(ctx, tbInfo, pi)
}

Expand Down Expand Up @@ -2212,7 +2221,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *
return errors.Trace(ErrPartitionMgmtOnNonpartitioned)
}

partInfo, err := buildPartitionInfo(meta, d, spec)
partInfo, err := buildPartitionInfo(ctx, meta, d, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -3527,7 +3536,7 @@ func validateCommentLength(vars *variable.SessionVars, comment string, maxLen in
return comment, nil
}

func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
func buildPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
if meta.Partition.Type == model.PartitionTypeRange {
if len(spec.PartDefinitions) == 0 {
return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type)
Expand All @@ -3554,6 +3563,12 @@ func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec)
}
// For RANGE partition only VALUES LESS THAN should be possible.
clause := def.Clause.(*ast.PartitionDefinitionClauseLessThan)
if len(part.Columns) > 0 {
if err := checkRangeColumnsTypeAndValuesMatch(ctx, meta, part.Columns, clause.Exprs); err != nil {
return nil, err
}
}

comment, _ := def.Comment()
piDef := model.PartitionDefinition{
Name: def.Name,
Expand All @@ -3572,6 +3587,41 @@ func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec)
return part, nil
}

func checkRangeColumnsTypeAndValuesMatch(ctx sessionctx.Context, meta *model.TableInfo, colNames []model.CIStr, exprs []ast.ExprNode) error {
// Validate() has already checked len(colNames) = len(exprs)
// create table ... partition by range columns (cols)
// partition p0 values less than (expr)
// check the type of cols[i] and expr is consistent.
for i, colExpr := range exprs {
if _, ok := colExpr.(*ast.MaxValueExpr); ok {
continue
}

colName := colNames[i]
colInfo := getColumnInfoByName(meta, colName.L)
if colInfo == nil {
return errors.Trace(ErrFieldNotFoundPart)
}
colType := &colInfo.FieldType

val, err := expression.EvalAstExpr(ctx, colExpr)
if err != nil {
return err
}

// Check val.ConvertTo(colType) doesn't work, so we need this case by case check.
switch colType.Tp {
case mysql.TypeDate, mysql.TypeDatetime:
switch val.Kind() {
case types.KindString, types.KindBytes:
default:
return ErrWrongTypeColumnValue.GenWithStackByArgs()
}
}
}
return nil
}

// LockTables uses to execute lock tables statement.
func (d *ddl) LockTables(ctx sessionctx.Context, stmt *ast.LockTablesStmt) error {
lockTables := make([]model.TableLockTpInfo, 0, len(stmt.TableLocks))
Expand Down

0 comments on commit aa6bfe3

Please sign in to comment.