Skip to content

Commit

Permalink
extension: make RelatedTables work when the statement fails (#50989) (
Browse files Browse the repository at this point in the history
#53310)

close #50988
  • Loading branch information
ti-chi-bot authored May 16, 2024
1 parent 19b3128 commit ce9e337
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 102 deletions.
9 changes: 8 additions & 1 deletion pkg/extension/event_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,18 @@ func TestExtensionStmtEvents(t *testing.T) {
dispatchData: append([]byte{mysql.ComInitDB}, []byte("db1")...),
originalText: "use `db1`",
redactText: "use `db1`",
tables: []stmtctx.TableEntry{
{DB: "db1", Table: ""},
},
},
{
dispatchData: append([]byte{mysql.ComInitDB}, []byte("noexistdb")...),
originalText: "use `noexistdb`",
redactText: "use `noexistdb`",
err: "[schema:1049]Unknown database 'noexistdb'",
tables: []stmtctx.TableEntry{
{DB: "noexistdb", Table: ""},
},
},
{
sql: "set @@tidb_session_alias='alias123'",
Expand Down Expand Up @@ -448,7 +454,8 @@ func TestExtensionStmtEvents(t *testing.T) {
r := record.tables[j]
return l.DB < r.DB || (l.DB == r.DB && l.Table < r.Table)
})
require.Equal(t, subCase.tables, record.tables)
require.Equal(t, subCase.tables, record.tables,
"sql: %s\noriginalText: %s\n", subCase.sql, subCase.originalText)

require.Equal(t, len(subCase.executeParams), len(record.params))
for k, param := range subCase.executeParams {
Expand Down
2 changes: 2 additions & 0 deletions pkg/extension/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ type StmtEventInfo interface {
// AffectedRows will return the affected rows of the current statement
AffectedRows() uint64
// RelatedTables will return the related tables of the current statement
// For statements succeeding to build logical plan, it uses the `visitinfo` to get the related tables
// For statements failing to build logical plan, it traverses the ast node to get the related tables
RelatedTables() []stmtctx.TableEntry
// GetError will return the error when the current statement is failed
GetError() error
Expand Down
7 changes: 7 additions & 0 deletions pkg/parser/ast/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,13 @@ func (n *FlushStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(newNode)
}
n = newNode.(*FlushStmt)
for i, t := range n.Tables {
node, ok := t.Accept(v)
if !ok {
return n, false
}
n.Tables[i] = node.(*TableName)
}
return v.Leave(n)
}

Expand Down
1 change: 1 addition & 0 deletions pkg/planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ go_test(
"rule_join_reorder_test.go",
"runtime_filter_generator_test.go",
"stringer_test.go",
"util_test.go",
],
data = glob(["testdata/**"]),
embed = [":core"],
Expand Down
214 changes: 119 additions & 95 deletions pkg/planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3419,8 +3419,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
}

func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo {
var tableList []*ast.TableName
tableList = extractTableList(from, tableList, true)
tableList := ExtractTableList(from, true)
for _, field := range tableList {
if field.Name.L == name.TblName.L {
return field.TableInfo
Expand Down Expand Up @@ -6094,8 +6093,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
return nil, err
}

var tableList []*ast.TableName
tableList = extractTableList(update.TableRefs.TableRefs, tableList, false)
tableList := ExtractTableList(update.TableRefs.TableRefs, false)
for _, t := range tableList {
dbName := t.Schema.L
if dbName == "" {
Expand Down Expand Up @@ -6623,8 +6621,7 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (Plan
}
} else {
// Delete from a, b, c, d.
var tableList []*ast.TableName
tableList = extractTableList(ds.TableRefs.TableRefs, tableList, false)
tableList := ExtractTableList(ds.TableRefs.TableRefs, false)
for _, v := range tableList {
if isCTE(v) {
return nil, ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE")
Expand Down Expand Up @@ -7444,17 +7441,6 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error
return specsMap, nil
}

func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectList) {
for _, sel := range list.Selects {
switch s := sel.(type) {
case *ast.SelectStmt:
unfoldList.Selects = append(unfoldList.Selects, s)
case *ast.SetOprSelectList:
unfoldSelectList(s, unfoldList)
}
}
}

type updatableTableListResolver struct {
updatableTableList []*ast.TableName
}
Expand Down Expand Up @@ -7483,111 +7469,149 @@ func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
return inNode, true
}

// extractTableList extracts all the TableNames from node.
// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName
// If asName is true, extract AsName prior to OrigName.
// Privilege check should use OrigName, while expression may use AsName.
// TODO: extracting all tables by vistor model maybe a better way
func extractTableList(node ast.Node, input []*ast.TableName, asName bool) []*ast.TableName {
switch x := node.(type) {
case *ast.SelectStmt:
if x.From != nil {
input = extractTableList(x.From.TableRefs, input, asName)
}
if x.Where != nil {
input = extractTableList(x.Where, input, asName)
}
if x.With != nil {
for _, cte := range x.With.CTEs {
input = extractTableList(cte.Query, input, asName)
}
}
for _, f := range x.Fields.Fields {
if s, ok := f.Expr.(*ast.SubqueryExpr); ok {
input = extractTableList(s, input, asName)
}
}
case *ast.DeleteStmt:
input = extractTableList(x.TableRefs.TableRefs, input, asName)
if x.IsMultiTable {
for _, t := range x.Tables.Tables {
input = extractTableList(t, input, asName)
}
}
if x.Where != nil {
input = extractTableList(x.Where, input, asName)
}
if x.With != nil {
for _, cte := range x.With.CTEs {
input = extractTableList(cte.Query, input, asName)
}
}
case *ast.UpdateStmt:
input = extractTableList(x.TableRefs.TableRefs, input, asName)
for _, e := range x.List {
input = extractTableList(e.Expr, input, asName)
}
if x.Where != nil {
input = extractTableList(x.Where, input, asName)
}
if x.With != nil {
for _, cte := range x.With.CTEs {
input = extractTableList(cte.Query, input, asName)
func ExtractTableList(node ast.Node, asName bool) []*ast.TableName {
if node == nil {
return []*ast.TableName{}
}
e := &tableListExtractor{
asName: asName,
tableNames: []*ast.TableName{},
}
node.Accept(e)
tableNames := e.tableNames
m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName
for _, x := range tableNames {
k1, k2 := x.Schema.L, x.Name.L
// allow empty schema name OR empty table name
if k1 != "" || k2 != "" {
if _, ok := m[k1]; !ok {
m[k1] = make(map[string]*ast.TableName)
}
m[k1][k2] = x
}
case *ast.InsertStmt:
input = extractTableList(x.Table.TableRefs, input, asName)
input = extractTableList(x.Select, input, asName)
case *ast.SetOprStmt:
l := &ast.SetOprSelectList{}
unfoldSelectList(x.SelectList, l)
for _, s := range l.Selects {
input = extractTableList(s.(ast.ResultSetNode), input, asName)
}
case *ast.PatternInExpr:
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
input = extractTableList(s, input, asName)
}
tableNames = tableNames[:0]
for _, x := range m {
for _, v := range x {
tableNames = append(tableNames, v)
}
case *ast.ExistsSubqueryExpr:
if s, ok := x.Sel.(*ast.SubqueryExpr); ok {
input = extractTableList(s, input, asName)
}
return tableNames
}

// tableListExtractor extracts all the TableNames from node.
type tableListExtractor struct {
asName bool
tableNames []*ast.TableName
}

func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) {
innerExtract := func(inner ast.Node) []*ast.TableName {
if inner == nil {
return nil
}
case *ast.BinaryOperationExpr:
if s, ok := x.R.(*ast.SubqueryExpr); ok {
input = extractTableList(s, input, asName)
innerExtractor := &tableListExtractor{
asName: e.asName,
tableNames: []*ast.TableName{},
}
case *ast.SubqueryExpr:
input = extractTableList(x.Query, input, asName)
case *ast.Join:
input = extractTableList(x.Left, input, asName)
input = extractTableList(x.Right, input, asName)
inner.Accept(innerExtractor)
return innerExtractor.tableNames
}

switch x := n.(type) {
case *ast.TableName:
e.tableNames = append(e.tableNames, x)
case *ast.TableSource:
if s, ok := x.Source.(*ast.TableName); ok {
if x.AsName.L != "" && asName {
if x.AsName.L != "" && e.asName {
newTableName := *s
newTableName.Name = x.AsName
newTableName.Schema = model.NewCIStr("")
input = append(input, &newTableName)
e.tableNames = append(e.tableNames, &newTableName)
} else {
input = append(input, s)
e.tableNames = append(e.tableNames, s)
}
} else if s, ok := x.Source.(*ast.SelectStmt); ok {
if s.From != nil {
var innerList []*ast.TableName
innerList = extractTableList(s.From.TableRefs, innerList, asName)
innerList := innerExtract(s.From.TableRefs)
if len(innerList) > 0 {
innerTableName := innerList[0]
if x.AsName.L != "" && asName {
if x.AsName.L != "" && e.asName {
newTableName := *innerList[0]
newTableName.Name = x.AsName
newTableName.Schema = model.NewCIStr("")
innerTableName = &newTableName
}
input = append(input, innerTableName)
e.tableNames = append(e.tableNames, innerTableName)
}
}
}
return n, true

case *ast.ShowStmt:
if x.DBName != "" {
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
}
case *ast.CreateDatabaseStmt:
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
case *ast.AlterDatabaseStmt:
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})
case *ast.DropDatabaseStmt:
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name})

case *ast.FlashBackDatabaseStmt:
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)})
case *ast.FlashBackToTimestampStmt:
if x.DBName.L != "" {
e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName})
}
case *ast.FlashBackTableStmt:
if newName := x.NewName; newName != "" {
e.tableNames = append(e.tableNames, &ast.TableName{
Schema: x.Table.Schema,
Name: model.NewCIStr(newName)})
}

case *ast.GrantStmt:
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
e.tableNames = append(e.tableNames, &ast.TableName{
Schema: model.NewCIStr(x.Level.DBName),
Name: model.NewCIStr(x.Level.TableName),
})
}
}
case *ast.RevokeStmt:
if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone {
if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable {
e.tableNames = append(e.tableNames, &ast.TableName{
Schema: model.NewCIStr(x.Level.DBName),
Name: model.NewCIStr(x.Level.TableName),
})
}
}
case *ast.BRIEStmt:
if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore {
for _, v := range x.Schemas {
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)})
}
}
case *ast.UseStmt:
e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)})
case *ast.ExecuteStmt:
if v, ok := x.PrepStmt.(*PlanCacheStmt); ok {
e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...)
}
}
return input
return n, false
}

func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) {
return n, true
}

func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) {
Expand Down
3 changes: 1 addition & 2 deletions pkg/planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -1677,8 +1677,7 @@ func buildPointUpdatePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName
}
if tbl.GetPartitionInfo() != nil {
pt := t.(table.PartitionedTable)
var updateTableList []*ast.TableName
updateTableList = extractTableList(updateStmt.TableRefs.TableRefs, updateTableList, true)
updateTableList := ExtractTableList(updateStmt.TableRefs.TableRefs, true)
updatePlan.PartitionedTable = make([]table.PartitionedTable, 0, len(updateTableList))
for _, updateTable := range updateTableList {
if len(updateTable.PartitionNames) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def
}

// Check the bind operation is not on any temporary table.
tblNames := extractTableList(originNode, nil, false)
tblNames := ExtractTableList(originNode, false)
for _, tn := range tblNames {
tbl, err := p.tableByName(tn)
if err != nil {
Expand Down
Loading

0 comments on commit ce9e337

Please sign in to comment.