diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index 3ebab9da7b640..52d04e67c1883 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -106,12 +106,12 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { sql4 := `create table t4 ( a int not null, - b int not null + b int not null ) partition by range( id ) ( partition p1 values less than maxvalue, - partition p2 values less than (1991), - partition p3 values less than (1995) + partition p2 values less than (1991), + partition p3 values less than (1995) );` assertErrorCode(c, tk, sql4, tmysql.ErrPartitionMaxvalue) @@ -121,10 +121,10 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { c INT NOT NULL ) partition by range columns(a,b,c) ( - partition p0 values less than (10,5,1), - partition p2 values less than (50,maxvalue,10), - partition p3 values less than (65,30,13), - partition p4 values less than (maxvalue,30,40) + partition p0 values less than (10,5,1), + partition p2 values less than (50,maxvalue,10), + partition p3 values less than (65,30,13), + partition p4 values less than (maxvalue,30,40) );`) c.Assert(err, IsNil) @@ -139,13 +139,13 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { sql7 := `create table t7 ( a int not null, - b int not null + b int not null ) partition by range( id ) ( partition p1 values less than (1991), partition p2 values less than maxvalue, - partition p3 values less than maxvalue, - partition p4 values less than (1995), + partition p3 values less than maxvalue, + partition p4 values less than (1995), partition p5 values less than maxvalue );` assertErrorCode(c, tk, sql7, tmysql.ErrPartitionMaxvalue) @@ -230,6 +230,9 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { assertErrorCode(c, tk, `create table t31 (a int not null) partition by range( a );`, tmysql.ErrPartitionsMustBeDefined) assertErrorCode(c, tk, `create table t32 (a int not null) partition by range columns( a );`, tmysql.ErrPartitionsMustBeDefined) assertErrorCode(c, tk, `create table t33 (a int, b int) partition by hash(a) partitions 0;`, tmysql.ErrNoParts) + assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(a) partitions 30;`, tmysql.ErrFieldTypeNotAllowedAsPartitionField) + // TODO: fix this one + // assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(unix_timestamp(a)) partitions 30;`, tmysql.ErrPartitionFuncNotAllowed) } func (s *testIntegrationSuite7) TestCreateTableWithHashPartition(c *C) { diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 16650d6dbe9d6..73df8ba066690 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1141,7 +1141,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS err = checkPartitionByRangeColumn(ctx, tbInfo, pi, s) } case model.PartitionTypeHash: - err = checkPartitionByHash(pi) + err = checkPartitionByHash(ctx, pi, s, cols, tbInfo) } if err != nil { return nil, errors.Trace(err) @@ -1365,41 +1365,41 @@ func buildViewInfoWithTableColumns(ctx sessionctx.Context, s *ast.CreateViewStmt return viewInfo, tableColumns } -func checkPartitionByHash(pi *model.PartitionInfo) error { +func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, tbInfo *model.TableInfo) error { if err := checkAddPartitionTooManyPartitions(pi.Num); err != nil { - return errors.Trace(err) + return err } - if err := checkNoHashPartitions(pi.Num); err != nil { - return errors.Trace(err) + if err := checkNoHashPartitions(ctx, pi.Num); err != nil { + return err } - return nil + if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { + return err + } + return checkPartitionFuncType(ctx, s, cols, tbInfo) } func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error { if err := checkPartitionNameUnique(tbInfo, pi); err != nil { - return errors.Trace(err) + return err } if err := checkCreatePartitionValue(ctx, tbInfo, pi, cols); err != nil { - return errors.Trace(err) + return err } if err := checkAddPartitionTooManyPartitions(uint64(len(pi.Definitions))); err != nil { - return errors.Trace(err) + return err } if err := checkNoRangePartitions(len(pi.Definitions)); err != nil { - return errors.Trace(err) + return err } if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { - return errors.Trace(err) + return err } - if err := checkPartitionFuncType(ctx, s, cols, tbInfo); err != nil { - return errors.Trace(err) - } - return nil + return checkPartitionFuncType(ctx, s, cols, tbInfo) } func checkPartitionByRangeColumn(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt) error { diff --git a/ddl/partition.go b/ddl/partition.go index 62481a2bdaf53..535ca98f1674b 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -202,7 +202,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols buf := new(bytes.Buffer) s.Partition.Expr.Format(buf) exprStr := buf.String() - if s.Partition.Tp == model.PartitionTypeRange { + if s.Partition.Tp == model.PartitionTypeRange || s.Partition.Tp == model.PartitionTypeHash { // if partition by columnExpr, check the column type if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok { for _, col := range cols { @@ -215,13 +215,19 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols } } - e, err := expression.ParseSimpleExprWithTableInfo(ctx, buf.String(), tblInfo) + e, err := expression.ParseSimpleExprWithTableInfo(ctx, exprStr, tblInfo) if err != nil { return errors.Trace(err) } if e.GetType().EvalType() == types.ETInt { return nil } + if s.Partition.Tp == model.PartitionTypeHash { + if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok { + return ErrNotAllowedTypeInPartition.GenWithStackByArgs(exprStr) + } + } + return ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION") } @@ -428,7 +434,7 @@ func checkAddPartitionTooManyPartitions(piDefs uint64) error { return nil } -func checkNoHashPartitions(partitionNum uint64) error { +func checkNoHashPartitions(ctx sessionctx.Context, partitionNum uint64) error { if partitionNum == 0 { return ErrNoParts.GenWithStackByArgs("partitions") }