diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index a49d849c..8502aebf 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -282,7 +282,7 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast err = errNoDatabaseSelected } case *ast.TruncateTableStmt, *ast.DropTableStmt, *ast.ExplainStmt, *ast.DropIndexStmt, *ast.CreateIndexStmt, - *ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt: + *ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt, *ast.CreateTableStmt: res, warn, err = executeStmt(ctx, schemaless, rt) case *ast.DropTriggerStmt, *ast.SetStmt, *ast.KillStmt: res, warn, err = rt.Execute(ctx) diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 232a2e27..354d29e2 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -131,6 +131,8 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { return cc.convOptimizeTable(stmt), nil case *ast.CheckTableStmt: return cc.convCheckTableStmt(stmt), nil + case *ast.CreateTableStmt: + return cc.convCreateTableStmt(stmt), nil case *ast.RenameTableStmt: return cc.convRenameTableStmt(stmt), nil case *ast.KillStmt: @@ -1687,6 +1689,27 @@ func (cc *convCtx) convCheckTableStmt(stmt *ast.CheckTableStmt) Statement { return &CheckTableStmt{Tables: tables} } +func (cc *convCtx) convCreateTableStmt(stmt *ast.CreateTableStmt) Statement { + table := &TableName{ + stmt.Table.Name.String(), + } + var refTable *TableName + if stmt.ReferTable != nil { + refTable = &TableName{ + stmt.ReferTable.Name.String(), + } + } + + return &CreateTableStmt{ + IfNotExists: stmt.IfNotExists, + Table: table, + ReferTable: refTable, + Cols: stmt.Cols, + Constraints: stmt.Constraints, + Options: stmt.Options, + } +} + func (cc *convCtx) convRenameTableStmt(stmt *ast.RenameTableStmt) Statement { tableToTables := make([]*TableToTable, len(stmt.TableToTables)) for i, tableToTable := range stmt.TableToTables { diff --git a/pkg/runtime/ast/create_table.go b/pkg/runtime/ast/create_table.go new file mode 100644 index 00000000..e1e131e7 --- /dev/null +++ b/pkg/runtime/ast/create_table.go @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" +) + +var _ Statement = (*CreateTableStmt)(nil) + +type CreateTableStmt struct { + IfNotExists bool + //TemporaryKeyword + // Meanless when TemporaryKeyword is not TemporaryGlobal. + // ON COMMIT DELETE ROWS => true + // ON COMMIT PRESERVE ROW => false + //OnCommitDelete bool + Table *TableName + ReferTable *TableName + Cols []*ast.ColumnDef + Constraints []*ast.Constraint + Options []*ast.TableOption + //Partition *PartitionOptions + //OnDuplicate OnDuplicateKeyHandlingType + //Select ResultSetNode +} + +func NewCreateTableStmt() *CreateTableStmt { + return &CreateTableStmt{} +} + +func (c *CreateTableStmt) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("CREATE TABLE ") + if c.IfNotExists { + sb.WriteString(" IF NOT EXISTS ") + } + if err := c.Table.Restore(flag, sb, args); err != nil { + return errors.Wrapf(err, "An error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table) + } + + if c.ReferTable != nil { + sb.WriteString(" LIKE ") + if err := c.ReferTable.Restore(flag, sb, args); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ReferTable") + } + } + + rsCtx := format.NewRestoreCtx(format.RestoreFlags(flag), sb) + + lenCols := len(c.Cols) + lenConstraints := len(c.Constraints) + if lenCols+lenConstraints > 0 { + sb.WriteString(" (") + for i, col := range c.Cols { + if i > 0 { + sb.WriteString(",") + } + if err := col.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ColumnDef: [%v]", i) + } + } + for i, constraint := range c.Constraints { + if i > 0 || lenCols >= 1 { + sb.WriteString(",") + } + if err := constraint.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt Constraints: [%v]", i) + } + } + sb.WriteString(")") + } + + for i, option := range c.Options { + sb.WriteString(" ") + if err := option.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt TableOption: [%v]", i) + } + } + + return nil +} + +func (c *CreateTableStmt) Mode() SQLType { + return SQLTypeCreateTable +} diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index 43401f35..0c0bbefb 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -62,6 +62,7 @@ const ( SQLTypeKill // KILL SQLTypeCheckTable // CHECK TABLE SQLTypeRenameTable // RENAME TABLE + SQLTypeCreateTable // CREATE TABLE ) var _sqlTypeNames = [...]string{ @@ -97,6 +98,7 @@ var _sqlTypeNames = [...]string{ SQLTypeKill: "KILL", SQLTypeCheckTable: "CHECK TABLE", SQLTypeRenameTable: "RENAME TABLE", + SQLTypeCreateTable: "CREATE TABLE", } // SQLType represents the type of SQL. diff --git a/pkg/runtime/optimize/ddl/create_table.go b/pkg/runtime/optimize/ddl/create_table.go new file mode 100644 index 00000000..81e1ef97 --- /dev/null +++ b/pkg/runtime/optimize/ddl/create_table.go @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" + "github.com/arana-db/arana/pkg/runtime/plan/dml" + "github.com/arana-db/arana/pkg/util/log" +) + +func init() { + optimize.Register(ast.SQLTypeCreateTable, optimizeCreateTable) +} + +func optimizeCreateTable(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.CreateTableStmt) + + var ( + shards rule.DatabaseTables + fullScan bool + ) + vt, ok := o.Rule.VTable(stmt.Table.Suffix()) + fullScan = ok + + log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) + + toSingle := func(db, tbl string) (proto.Plan, error) { + ret := &ddl.CreateTablePlan{ + Stmt: stmt, + Database: db, + Tables: []string{tbl}, + } + ret.BindArgs(o.Args) + + return ret, nil + } + + // Go through first table if not full scan. + if !fullScan { + return toSingle("", stmt.Table.Suffix()) + } + + // expand all shards if all shards matched + shards = vt.Topology().Enumerate() + + plans := make([]proto.Plan, 0, len(shards)) + for k, v := range shards { + next := &ddl.CreateTablePlan{ + Database: k, + Tables: v, + Stmt: stmt, + } + next.BindArgs(o.Args) + plans = append(plans, next) + } + + tmpPlan := &dml.CompositePlan{ + Plans: plans, + } + + return tmpPlan, nil +} diff --git a/pkg/runtime/plan/ddl/create_table.go b/pkg/runtime/plan/ddl/create_table.go new file mode 100644 index 00000000..b563eb7e --- /dev/null +++ b/pkg/runtime/plan/ddl/create_table.go @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" + "strings" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +type CreateTablePlan struct { + plan.BasePlan + Stmt *ast.CreateTableStmt + Database string + Tables []string +} + +func NewCreateTablePlan( + stmt *ast.CreateTableStmt, + db string, + tb []string, +) *CreateTablePlan { + return &CreateTablePlan{ + Stmt: stmt, + Database: db, + Tables: tb, + } +} + +// Type get plan type +func (c *CreateTablePlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (c *CreateTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + args []int + err error + ) + + ctx, span := plan.Tracer.Start(ctx, "CreateTable.ExecIn") + defer span.End() + + switch len(c.Tables) { + case 0: + // no table reset + return resultx.New(), nil + case 1: + // single shard table + if err := c.Stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, err + } + if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil { + return nil, err + } + default: + // multiple shard tables + stmt := new(ast.CreateTableStmt) + *stmt = *c.Stmt // do copy + + restore := func(table string) error { + sb.Reset() + if err = c.resetTable(stmt, table); err != nil { + return err + } + if err = stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return err + } + if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil { + return err + } + return nil + } + + for i := 0; i < len(c.Tables); i++ { + if err := restore(c.Tables[i]); err != nil { + return nil, err + } + } + } + + return resultx.New(), nil +} + +func (c *CreateTablePlan) resetTable(stmt *ast.CreateTableStmt, table string) error { + stmt.Table = &ast.TableName{ + table, + } + + return nil +}