Skip to content

Commit

Permalink
Support create table SQL statement (#424) (#653)
Browse files Browse the repository at this point in the history
* add node config support (#464)

* Support MySQL CAST_CHAR function.

* format style

* Support MySQL CAST_TIME function. (#570)

* Support MySQL CAST_DATE function. (#569)

* Support MySQL CAST_DATETIME function. (#568)

* Support MySQL CAST_TIME/CAST_DATE/CAST_DATETIME function

* Resolve Conversation

* Support CREATE TABLE

* add: IfNotExists

* fix: reformat imports

* Resolve Conversation
  • Loading branch information
csynineyang authored Apr 18, 2023
1 parent a8c1c6a commit a00866f
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/executor/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast
err = errNoDatabaseSelected
}
case *ast.TruncateTableStmt, *ast.DropTableStmt, *ast.ExplainStmt, *ast.DropIndexStmt, *ast.CreateIndexStmt,
*ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt:
*ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt, *ast.CreateTableStmt:
res, warn, err = executeStmt(ctx, schemaless, rt)
case *ast.DropTriggerStmt, *ast.SetStmt, *ast.KillStmt:
res, warn, err = rt.Execute(ctx)
Expand Down
23 changes: 23 additions & 0 deletions pkg/runtime/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) {
return cc.convOptimizeTable(stmt), nil
case *ast.CheckTableStmt:
return cc.convCheckTableStmt(stmt), nil
case *ast.CreateTableStmt:
return cc.convCreateTableStmt(stmt), nil
case *ast.RenameTableStmt:
return cc.convRenameTableStmt(stmt), nil
case *ast.KillStmt:
Expand Down Expand Up @@ -1687,6 +1689,27 @@ func (cc *convCtx) convCheckTableStmt(stmt *ast.CheckTableStmt) Statement {
return &CheckTableStmt{Tables: tables}
}

func (cc *convCtx) convCreateTableStmt(stmt *ast.CreateTableStmt) Statement {
table := &TableName{
stmt.Table.Name.String(),
}
var refTable *TableName
if stmt.ReferTable != nil {
refTable = &TableName{
stmt.ReferTable.Name.String(),
}
}

return &CreateTableStmt{
IfNotExists: stmt.IfNotExists,
Table: table,
ReferTable: refTable,
Cols: stmt.Cols,
Constraints: stmt.Constraints,
Options: stmt.Options,
}
}

func (cc *convCtx) convRenameTableStmt(stmt *ast.RenameTableStmt) Statement {
tableToTables := make([]*TableToTable, len(stmt.TableToTables))
for i, tableToTable := range stmt.TableToTables {
Expand Down
109 changes: 109 additions & 0 deletions pkg/runtime/ast/create_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ast

import (
"strings"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format"
)

var _ Statement = (*CreateTableStmt)(nil)

type CreateTableStmt struct {
IfNotExists bool
//TemporaryKeyword
// Meanless when TemporaryKeyword is not TemporaryGlobal.
// ON COMMIT DELETE ROWS => true
// ON COMMIT PRESERVE ROW => false
//OnCommitDelete bool
Table *TableName
ReferTable *TableName
Cols []*ast.ColumnDef
Constraints []*ast.Constraint
Options []*ast.TableOption
//Partition *PartitionOptions
//OnDuplicate OnDuplicateKeyHandlingType
//Select ResultSetNode
}

func NewCreateTableStmt() *CreateTableStmt {
return &CreateTableStmt{}
}

func (c *CreateTableStmt) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("CREATE TABLE ")
if c.IfNotExists {
sb.WriteString(" IF NOT EXISTS ")
}
if err := c.Table.Restore(flag, sb, args); err != nil {
return errors.Wrapf(err, "An error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table)
}

if c.ReferTable != nil {
sb.WriteString(" LIKE ")
if err := c.ReferTable.Restore(flag, sb, args); err != nil {
return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ReferTable")
}
}

rsCtx := format.NewRestoreCtx(format.RestoreFlags(flag), sb)

lenCols := len(c.Cols)
lenConstraints := len(c.Constraints)
if lenCols+lenConstraints > 0 {
sb.WriteString(" (")
for i, col := range c.Cols {
if i > 0 {
sb.WriteString(",")
}
if err := col.Restore(rsCtx); err != nil {
return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ColumnDef: [%v]", i)
}
}
for i, constraint := range c.Constraints {
if i > 0 || lenCols >= 1 {
sb.WriteString(",")
}
if err := constraint.Restore(rsCtx); err != nil {
return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt Constraints: [%v]", i)
}
}
sb.WriteString(")")
}

for i, option := range c.Options {
sb.WriteString(" ")
if err := option.Restore(rsCtx); err != nil {
return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt TableOption: [%v]", i)
}
}

return nil
}

func (c *CreateTableStmt) Mode() SQLType {
return SQLTypeCreateTable
}
2 changes: 2 additions & 0 deletions pkg/runtime/ast/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const (
SQLTypeKill // KILL
SQLTypeCheckTable // CHECK TABLE
SQLTypeRenameTable // RENAME TABLE
SQLTypeCreateTable // CREATE TABLE
)

var _sqlTypeNames = [...]string{
Expand Down Expand Up @@ -97,6 +98,7 @@ var _sqlTypeNames = [...]string{
SQLTypeKill: "KILL",
SQLTypeCheckTable: "CHECK TABLE",
SQLTypeRenameTable: "RENAME TABLE",
SQLTypeCreateTable: "CREATE TABLE",
}

// SQLType represents the type of SQL.
Expand Down
85 changes: 85 additions & 0 deletions pkg/runtime/optimize/ddl/create_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ddl

import (
"context"
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
"github.com/arana-db/arana/pkg/runtime/ast"
"github.com/arana-db/arana/pkg/runtime/optimize"
"github.com/arana-db/arana/pkg/runtime/plan/ddl"
"github.com/arana-db/arana/pkg/runtime/plan/dml"
"github.com/arana-db/arana/pkg/util/log"
)

func init() {
optimize.Register(ast.SQLTypeCreateTable, optimizeCreateTable)
}

func optimizeCreateTable(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) {
stmt := o.Stmt.(*ast.CreateTableStmt)

var (
shards rule.DatabaseTables
fullScan bool
)
vt, ok := o.Rule.VTable(stmt.Table.Suffix())
fullScan = ok

log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan)

toSingle := func(db, tbl string) (proto.Plan, error) {
ret := &ddl.CreateTablePlan{
Stmt: stmt,
Database: db,
Tables: []string{tbl},
}
ret.BindArgs(o.Args)

return ret, nil
}

// Go through first table if not full scan.
if !fullScan {
return toSingle("", stmt.Table.Suffix())
}

// expand all shards if all shards matched
shards = vt.Topology().Enumerate()

plans := make([]proto.Plan, 0, len(shards))
for k, v := range shards {
next := &ddl.CreateTablePlan{
Database: k,
Tables: v,
Stmt: stmt,
}
next.BindArgs(o.Args)
plans = append(plans, next)
}

tmpPlan := &dml.CompositePlan{
Plans: plans,
}

return tmpPlan, nil
}
113 changes: 113 additions & 0 deletions pkg/runtime/plan/ddl/create_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ddl

import (
"context"
"strings"
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
"github.com/arana-db/arana/pkg/runtime/plan"
)

type CreateTablePlan struct {
plan.BasePlan
Stmt *ast.CreateTableStmt
Database string
Tables []string
}

func NewCreateTablePlan(
stmt *ast.CreateTableStmt,
db string,
tb []string,
) *CreateTablePlan {
return &CreateTablePlan{
Stmt: stmt,
Database: db,
Tables: tb,
}
}

// Type get plan type
func (c *CreateTablePlan) Type() proto.PlanType {
return proto.PlanTypeExec
}

func (c *CreateTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
var (
sb strings.Builder
args []int
err error
)

ctx, span := plan.Tracer.Start(ctx, "CreateTable.ExecIn")
defer span.End()

switch len(c.Tables) {
case 0:
// no table reset
return resultx.New(), nil
case 1:
// single shard table
if err := c.Stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
return nil, err
}
if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil {
return nil, err
}
default:
// multiple shard tables
stmt := new(ast.CreateTableStmt)
*stmt = *c.Stmt // do copy

restore := func(table string) error {
sb.Reset()
if err = c.resetTable(stmt, table); err != nil {
return err
}
if err = stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
return err
}
if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil {
return err
}
return nil
}

for i := 0; i < len(c.Tables); i++ {
if err := restore(c.Tables[i]); err != nil {
return nil, err
}
}
}

return resultx.New(), nil
}

func (c *CreateTablePlan) resetTable(stmt *ast.CreateTableStmt, table string) error {
stmt.Table = &ast.TableName{
table,
}

return nil
}

0 comments on commit a00866f

Please sign in to comment.