diff --git a/pkg/planner/BUILD.bazel b/pkg/planner/BUILD.bazel index 7103c9c40fc51..3d6398176dc79 100644 --- a/pkg/planner/BUILD.bazel +++ b/pkg/planner/BUILD.bazel @@ -18,7 +18,10 @@ go_library( "//pkg/planner/core", "//pkg/planner/core/base", "//pkg/planner/core/resolve", + "//pkg/planner/indexadvisor", + "//pkg/planner/property", "//pkg/planner/util/debugtrace", + "//pkg/planner/util/optimizetrace", "//pkg/privilege", "//pkg/sessionctx", "//pkg/sessionctx/variable", diff --git a/pkg/planner/indexadvisor/BUILD.bazel b/pkg/planner/indexadvisor/BUILD.bazel new file mode 100644 index 0000000000000..fa19fb7585b0a --- /dev/null +++ b/pkg/planner/indexadvisor/BUILD.bazel @@ -0,0 +1,48 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "indexadvisor", + srcs = [ + "model.go", + "optimizer.go", + "utils.go", + ], + importpath = "github.com/pingcap/tidb/pkg/planner/indexadvisor", + visibility = ["//visibility:public"], + deps = [ + "//pkg/domain", + "//pkg/infoschema", + "//pkg/meta/model", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/parser/opcode", + "//pkg/planner/util/fixcontrol", + "//pkg/sessionctx", + "//pkg/types", + "//pkg/types/parser_driver", + "//pkg/util/logutil", + "//pkg/util/parser", + "//pkg/util/set", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "indexadvisor_test", + timeout = "short", + srcs = [ + "optimizer_test.go", + "utils_test.go", + ], + flaky = True, + shard_count = 17, + deps = [ + ":indexadvisor", + "//pkg/parser/mysql", + "//pkg/testkit", + "//pkg/util/set", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/planner/indexadvisor/model.go b/pkg/planner/indexadvisor/model.go new file mode 100644 index 0000000000000..14e767235480f --- /dev/null +++ b/pkg/planner/indexadvisor/model.go @@ -0,0 +1,137 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexadvisor + +import ( + "fmt" + "math" + "strings" +) + +// Query represents a Query statement. +type Query struct { // DQL or DML + Alias string + SchemaName string + Text string + Frequency int + CostPerMon float64 +} + +// Key returns the key of the Query. +func (q Query) Key() string { + return q.Text +} + +// Column represents a column. +type Column struct { + SchemaName string + TableName string + ColumnName string +} + +// NewColumn creates a new column. +func NewColumn(schemaName, tableName, columnName string) Column { + return Column{SchemaName: strings.ToLower(schemaName), + TableName: strings.ToLower(tableName), ColumnName: strings.ToLower(columnName)} +} + +// NewColumns creates new columns. +func NewColumns(schemaName, tableName string, columnNames ...string) []Column { + cols := make([]Column, 0, len(columnNames)) + for _, col := range columnNames { + cols = append(cols, NewColumn(schemaName, tableName, col)) + } + return cols +} + +// Key returns the key of the column. +func (c Column) Key() string { + return fmt.Sprintf("%v.%v.%v", c.SchemaName, c.TableName, c.ColumnName) +} + +// Index represents an index. +type Index struct { + SchemaName string + TableName string + IndexName string + Columns []Column +} + +// NewIndex creates a new index. +func NewIndex(schemaName, tableName, indexName string, columns ...string) Index { + return Index{SchemaName: strings.ToLower(schemaName), TableName: strings.ToLower(tableName), + IndexName: strings.ToLower(indexName), Columns: NewColumns(schemaName, tableName, columns...)} +} + +// NewIndexWithColumns creates a new index with columns. +func NewIndexWithColumns(indexName string, columns ...Column) Index { + names := make([]string, len(columns)) + for i, col := range columns { + names[i] = col.ColumnName + } + return NewIndex(columns[0].SchemaName, columns[0].TableName, indexName, names...) +} + +// Key returns the key of the index. +func (i Index) Key() string { + names := make([]string, 0, len(i.Columns)) + for _, col := range i.Columns { + names = append(names, col.ColumnName) + } + return fmt.Sprintf("%v.%v(%v)", i.SchemaName, i.TableName, strings.Join(names, ",")) +} + +// PrefixContain returns whether j is a prefix of i. +func (i Index) PrefixContain(j Index) bool { + if i.SchemaName != j.SchemaName || i.TableName != j.TableName || len(i.Columns) < len(j.Columns) { + return false + } + for k := range j.Columns { + if i.Columns[k].ColumnName != j.Columns[k].ColumnName { + return false + } + } + return true +} + +// IndexSetCost is the cost of a index configuration. +type IndexSetCost struct { + TotalWorkloadQueryCost float64 + TotalNumberOfIndexColumns int + IndexKeysStr string // IndexKeysStr is the string representation of the index keys. +} + +// Less returns whether the cost of c is less than the cost of other. +func (c IndexSetCost) Less(other IndexSetCost) bool { + if c.TotalWorkloadQueryCost == 0 { // not initialized + return false + } + if other.TotalWorkloadQueryCost == 0 { // not initialized + return true + } + cc, cOther := c.TotalWorkloadQueryCost, other.TotalWorkloadQueryCost + if math.Abs(cc-cOther) > 10 && math.Abs(cc-cOther)/math.Max(cc, cOther) > 0.001 { + // their cost is very different, then the less cost, the better. + return cc < cOther + } + + if c.TotalNumberOfIndexColumns != other.TotalNumberOfIndexColumns { + // if they have the same cost, then the less columns, the better. + return c.TotalNumberOfIndexColumns < other.TotalNumberOfIndexColumns + } + + // to make the result stable. + return c.IndexKeysStr < other.IndexKeysStr +} diff --git a/pkg/planner/indexadvisor/optimizer.go b/pkg/planner/indexadvisor/optimizer.go new file mode 100644 index 0000000000000..c24036e7da645 --- /dev/null +++ b/pkg/planner/indexadvisor/optimizer.go @@ -0,0 +1,264 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexadvisor + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/ast" + model2 "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/types" +) + +// QueryPlanCostHook is used to calculate the cost of the query plan on this sctx. +// This hook is used to avoid cyclic import. +var QueryPlanCostHook func(sctx sessionctx.Context, stmt ast.StmtNode) (float64, error) + +// Optimizer is the interface of a what-if optimizer. +// This interface encapsulates all methods the Index Advisor needs to interact with the TiDB optimizer. +// This interface is not thread-safe. +type Optimizer interface { + // ColumnType returns the column type of the specified column. + ColumnType(c Column) (*types.FieldType, error) + + // PrefixContainIndex returns whether the specified index is a prefix of an existing index. + PrefixContainIndex(idx Index) (bool, error) + + // PossibleColumns returns the possible columns that match the specified column name. + PossibleColumns(schema, colName string) ([]Column, error) + + // TableColumns returns the columns of the specified table. + TableColumns(schema, table string) ([]Column, error) + + // IndexNameExist returns whether the specified index name exists in the specified table. + IndexNameExist(schema, table, indexName string) (bool, error) + + // EstIndexSize return the estimated index size of the specified table and columns + EstIndexSize(db, table string, cols ...string) (indexSize float64, err error) + + // QueryPlanCost return the cost of the query plan. + QueryPlanCost(sql string, hypoIndexes ...Index) (cost float64, err error) +} + +// optimizerImpl is the implementation of Optimizer. +type optimizerImpl struct { + sctx sessionctx.Context +} + +// NewOptimizer creates a new Optimizer. +func NewOptimizer(sctx sessionctx.Context) Optimizer { + return &optimizerImpl{sctx} +} + +func (opt *optimizerImpl) is() infoschema.InfoSchema { + return opt.sctx.GetDomainInfoSchema().(infoschema.InfoSchema) +} + +// IndexNameExist returns whether the specified index name exists in the specified table. +func (opt *optimizerImpl) IndexNameExist(schema, table, indexName string) (bool, error) { + tbl, err := opt.is().TableByName(context.Background(), model2.NewCIStr(schema), model2.NewCIStr(table)) + if err != nil { + return false, err + } + for _, idx := range tbl.Indices() { + if idx.Meta().Name.L == indexName { + return true, nil + } + } + return false, nil +} + +// TableColumns returns the columns of the specified table. +func (opt *optimizerImpl) TableColumns(schema, table string) ([]Column, error) { + tbl, err := opt.is().TableByName(context.Background(), model2.NewCIStr(schema), model2.NewCIStr(table)) + if err != nil { + return nil, err + } + cols := make([]Column, 0) + for _, col := range tbl.Cols() { + cols = append(cols, Column{ + SchemaName: schema, + TableName: table, + ColumnName: col.Name.L, + }) + } + return cols, nil +} + +// PossibleColumns returns the possible columns that match the specified column name. +func (opt *optimizerImpl) PossibleColumns(schema, colName string) ([]Column, error) { + // filtering system schema + schema = strings.ToLower(schema) + if schema == "information_schema" || schema == "metrics_schema" || + schema == "performance_schema" || schema == "mysql" { + return nil, nil + } + + cols := make([]Column, 0) + tbls, err := opt.is().SchemaTableInfos(context.Background(), model2.NewCIStr(schema)) + if err != nil { + return nil, err + } + for _, tbl := range tbls { + for _, col := range tbl.Cols() { + if strings.ToLower(col.Name.L) == colName { + cols = append(cols, Column{ + SchemaName: schema, + TableName: tbl.Name.L, + ColumnName: col.Name.L, + }) + } + } + } + return cols, nil +} + +// PrefixContainIndex returns whether the specified index is a prefix of an existing index. +func (opt *optimizerImpl) PrefixContainIndex(idx Index) (bool, error) { + tbl, err := opt.is().TableByName(context.Background(), model2.NewCIStr(idx.SchemaName), model2.NewCIStr(idx.TableName)) + if err != nil { + return false, err + } + for _, tblIndex := range tbl.Indices() { + if len(tblIndex.Meta().Columns) < len(idx.Columns) { + continue + } + prefixMatched := true + for i, idxCol := range idx.Columns { + if tblIndex.Meta().Columns[i].Name.L != strings.ToLower(idxCol.ColumnName) { + prefixMatched = false + break + } + } + if prefixMatched { + return true, nil + } + } + return false, nil +} + +// ColumnType returns the column type of the specified column. +func (opt *optimizerImpl) ColumnType(c Column) (*types.FieldType, error) { + tbl, err := opt.is().TableByName(context.Background(), model2.NewCIStr(c.SchemaName), model2.NewCIStr(c.TableName)) + if err != nil { + return nil, err + } + for _, col := range tbl.Cols() { + if col.Name.L == strings.ToLower(c.ColumnName) { + return &col.FieldType, nil + } + } + return nil, fmt.Errorf("column %v not found in table %v.%v", c.ColumnName, c.SchemaName, c.TableName) +} + +func (opt *optimizerImpl) addHypoIndex(hypoIndexes ...Index) error { + for _, h := range hypoIndexes { + tInfo, err := opt.is().TableByName(context.Background(), model2.NewCIStr(h.SchemaName), model2.NewCIStr(h.TableName)) + if err != nil { + return err + } + + var cols []*model.IndexColumn + for _, col := range h.Columns { + colOffset := -1 + for i, tCol := range tInfo.Cols() { + if tCol.Name.L == strings.ToLower(col.ColumnName) { + colOffset = i + break + } + } + if colOffset == -1 { + return fmt.Errorf("column %v not found in table %v.%v", col.ColumnName, h.SchemaName, h.TableName) + } + cols = append(cols, &model.IndexColumn{ + Name: model2.NewCIStr(col.ColumnName), + Offset: colOffset, + Length: types.UnspecifiedLength, + }) + } + idxInfo := &model.IndexInfo{ + Name: model2.NewCIStr(h.IndexName), + Columns: cols, + State: model.StatePublic, + Tp: model2.IndexTypeHypo, + } + + if opt.sctx.GetSessionVars().HypoIndexes == nil { + opt.sctx.GetSessionVars().HypoIndexes = make(map[string]map[string]map[string]*model.IndexInfo) + } + if opt.sctx.GetSessionVars().HypoIndexes[h.SchemaName] == nil { + opt.sctx.GetSessionVars().HypoIndexes[h.SchemaName] = make(map[string]map[string]*model.IndexInfo) + } + if opt.sctx.GetSessionVars().HypoIndexes[h.SchemaName][h.TableName] == nil { + opt.sctx.GetSessionVars().HypoIndexes[h.SchemaName][h.TableName] = make(map[string]*model.IndexInfo) + } + opt.sctx.GetSessionVars().HypoIndexes[h.SchemaName][h.TableName][h.IndexName] = idxInfo + } + return nil +} + +// QueryPlanCost return the cost of the query plan. +func (opt *optimizerImpl) QueryPlanCost(sql string, hypoIndexes ...Index) (cost float64, err error) { + stmt, err := ParseOneSQL(sql) + if err != nil { + return 0, err + } + + originalFix43817 := opt.sctx.GetSessionVars().OptimizerFixControl[fixcontrol.Fix43817] + originalWarns := opt.sctx.GetSessionVars().StmtCtx.GetWarnings() + originalExtraWarns := opt.sctx.GetSessionVars().StmtCtx.GetExtraWarnings() + originalHypoIndexes := opt.sctx.GetSessionVars().HypoIndexes + defer func() { + opt.sctx.GetSessionVars().OptimizerFixControl[fixcontrol.Fix43817] = originalFix43817 + opt.sctx.GetSessionVars().StmtCtx.SetWarnings(originalWarns) + opt.sctx.GetSessionVars().StmtCtx.SetExtraWarnings(originalExtraWarns) + opt.sctx.GetSessionVars().HypoIndexes = originalHypoIndexes + opt.sctx.GetSessionVars().StmtCtx.InExplainStmt = false + }() + opt.sctx.GetSessionVars().OptimizerFixControl[fixcontrol.Fix43817] = "on" + opt.sctx.GetSessionVars().StmtCtx.InExplainStmt = true + opt.sctx.GetSessionVars().HypoIndexes = nil + + if err := opt.addHypoIndex(hypoIndexes...); err != nil { + return 0, err + } + return QueryPlanCostHook(opt.sctx, stmt) +} + +// EstIndexSize return the estimated index size of the specified table and columns +func (opt *optimizerImpl) EstIndexSize(db, table string, cols ...string) (indexSize float64, err error) { + tbl, err := opt.is().TableByName(context.Background(), model2.NewCIStr(db), model2.NewCIStr(table)) + if err != nil { + return 0, err + } + stats := domain.GetDomain(opt.sctx).StatsHandle() + tblStats := stats.GetTableStats(tbl.Meta()) + for _, colName := range cols { + colStats := tblStats.ColumnByName(colName) + if colStats == nil { // might be not loaded + indexSize += float64(8) * float64(tblStats.RealtimeCount) + } else { + indexSize += float64(colStats.TotColSize) + } + } + return indexSize, nil +} diff --git a/pkg/planner/indexadvisor/optimizer_test.go b/pkg/planner/indexadvisor/optimizer_test.go new file mode 100644 index 0000000000000..69af075d041ef --- /dev/null +++ b/pkg/planner/indexadvisor/optimizer_test.go @@ -0,0 +1,262 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexadvisor_test + +import ( + "context" + "fmt" + "sort" + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/indexadvisor" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" +) + +func TestOptimizerColumnType(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b float, c varchar(255))`) + tk.MustExec(`create table t2 (a int, b decimal(10,2), c varchar(1024))`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + tp, err := opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t1", ColumnName: "a"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeLong, tp.GetType()) + + tp, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t1", ColumnName: "b"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeFloat, tp.GetType()) + + tp, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t1", ColumnName: "c"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeVarchar, tp.GetType()) + + tp, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t2", ColumnName: "a"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeLong, tp.GetType()) + + tp, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t2", ColumnName: "b"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeNewDecimal, tp.GetType()) + + tp, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t2", ColumnName: "c"}) + require.NoError(t, err) + require.Equal(t, mysql.TypeVarchar, tp.GetType()) + + _, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t2", ColumnName: "d"}) + require.Error(t, err) + + _, err = opt.ColumnType(indexadvisor.Column{SchemaName: "test", TableName: "t3", ColumnName: "a"}) + require.Error(t, err) +} + +func TestOptimizerPrefixContainIndex(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int, d int, key(a), key(b, c))`) + tk.MustExec(`create table t2 (a int, b int, c int, d int, key(a, b, c, d), key(d, c, b, a))`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + check := func(expected bool, tableName string, columns ...string) { + ok, err := opt.PrefixContainIndex(indexadvisor.NewIndex("test", tableName, "idx", columns...)) + require.NoError(t, err) + require.Equal(t, expected, ok) + } + + check(true, "t1", "a") + check(true, "t1", "b") + check(true, "t1", "b", "c") + check(false, "t1", "c") + check(false, "t1", "a", "b") + check(false, "t1", "b", "c", "a") + check(true, "t2", "a") + check(true, "t2", "a", "b") + check(true, "t2", "a", "b", "c") + check(true, "t2", "a", "b", "c", "d") + check(true, "t2", "d") + check(true, "t2", "d", "c") + check(false, "t2", "b") + check(false, "t2", "b", "a") + check(false, "t2", "b", "a", "c") +} + +func TestOptimizerPossibleColumns(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int, d int)`) + tk.MustExec(`create table t2 (a int, b int, c int, d int)`) + tk.MustExec(`create table t3 (c int, d int, e int, f int)`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + check := func(schema, colName string, expected []string) { + cols, err := opt.PossibleColumns(schema, colName) + require.NoError(t, err) + var tmp []string + for _, col := range cols { + tmp = append(tmp, fmt.Sprintf("%v.%v", col.TableName, col.ColumnName)) + } + sort.Strings(tmp) + require.Equal(t, expected, tmp) + } + + check("test", "a", []string{"t1.a", "t2.a"}) + check("test", "b", []string{"t1.b", "t2.b"}) + check("test", "c", []string{"t1.c", "t2.c", "t3.c"}) + check("test", "d", []string{"t1.d", "t2.d", "t3.d"}) + check("test", "e", []string{"t3.e"}) + check("test", "f", []string{"t3.f"}) + check("test", "g", nil) +} + +func TestOptimizerTableColumns(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int, d int)`) + tk.MustExec(`create table t2 (a int, b int, c int, d int)`) + tk.MustExec(`create table t3 (c int, d int, e int, f int)`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + check := func(schemaName, tableName string, columns []string) { + cols, err := opt.TableColumns(schemaName, tableName) + require.NoError(t, err) + var tmp []string + for _, col := range cols { + require.Equal(t, schemaName, col.SchemaName) + require.Equal(t, tableName, col.TableName) + tmp = append(tmp, col.ColumnName) + } + sort.Strings(tmp) + require.Equal(t, columns, tmp) + } + + check("test", "t1", []string{"a", "b", "c", "d"}) + check("test", "t2", []string{"a", "b", "c", "d"}) + check("test", "t3", []string{"c", "d", "e", "f"}) +} + +func TestOptimizerIndexNameExist(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int, d int, index ka(a), index kbc(b, c))`) + tk.MustExec(`create table t2 (a int, b int, c int, d int, index ka(a), index kbc(b, c))`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + check := func(schema, table, indexName string, expected bool) { + ok, err := opt.IndexNameExist(schema, table, indexName) + require.NoError(t, err) + require.Equal(t, expected, ok) + } + + check("test", "t1", "ka", true) + check("test", "t1", "kbc", true) + check("test", "t1", "kbc2", false) + check("test", "t2", "ka", true) + check("test", "t2", "kbc", true) + check("test", "t2", "kbc2", false) +} + +func TestOptimizerEstIndexSize(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + h := dom.StatsHandle() + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a int, b varchar(64))`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + s, err := opt.EstIndexSize("test", "t", "a") + require.NoError(t, err) + require.Equal(t, float64(0), s) + + s, err = opt.EstIndexSize("test", "t", "b") + require.NoError(t, err) + require.Equal(t, float64(0), s) + + tk.MustExec(`insert into t values (1, space(32))`) + require.NoError(t, h.DumpStatsDeltaToKV(true)) + require.NoError(t, h.Update(context.Background(), dom.InfoSchema())) + s, err = opt.EstIndexSize("test", "t", "a") + require.NoError(t, err) + require.Equal(t, float64(1), s) + + s, err = opt.EstIndexSize("test", "t", "b") + require.NoError(t, err) + require.Equal(t, float64(33), s) // 32 + 1 + + s, err = opt.EstIndexSize("test", "t", "a", "b") + require.NoError(t, err) + require.Equal(t, float64(34), s) // 32 + 1 + 1 + + tk.MustExec(`insert into t values (1, space(64))`) + require.NoError(t, h.DumpStatsDeltaToKV(true)) + require.NoError(t, h.Update(context.Background(), dom.InfoSchema())) + s, err = opt.EstIndexSize("test", "t", "a") + require.NoError(t, err) + require.Equal(t, float64(2), s) // 2 rows + + s, err = opt.EstIndexSize("test", "t", "b") + require.NoError(t, err) + require.Equal(t, float64(99), s) // 32 + 64 + x + + s, err = opt.EstIndexSize("test", "t", "b", "a") + require.NoError(t, err) + require.Equal(t, float64(99+2), s) +} + +func TestOptimizerQueryCost(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int, d int, index ka(a), index kbc(b, c))`) + tk.MustExec(`create table t2 (a int, b int, c int, d int, index ka(a), index kbc(b, c))`) +} + +func TestOptimizerQueryPlanCost(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t0 (a int, b int, c int)`) + + opt := indexadvisor.NewOptimizer(tk.Session()) + cost1, err := opt.QueryPlanCost("select a, b from t0 where a=1 and b=1") + require.NoError(t, err) + + cost2, err := opt.QueryPlanCost("select a, b from t0 where a=1 and b=1", indexadvisor.Index{ + SchemaName: "test", + TableName: "t0", + IndexName: "idx_a", + Columns: []indexadvisor.Column{ + {SchemaName: "test", TableName: "t0", ColumnName: "a"}}, + }) + require.NoError(t, err) + require.True(t, cost2 < cost1) + + cost3, err := opt.QueryPlanCost("select a, b from t0 where a=1 and b=1", indexadvisor.Index{ + SchemaName: "test", + TableName: "t0", + IndexName: "idx_a", + Columns: []indexadvisor.Column{ + {SchemaName: "test", TableName: "t0", ColumnName: "a"}, + {SchemaName: "test", TableName: "t0", ColumnName: "b"}}, + }) + require.NoError(t, err) + require.True(t, cost3 < cost2) +} diff --git a/pkg/planner/indexadvisor/utils.go b/pkg/planner/indexadvisor/utils.go new file mode 100644 index 0000000000000..3d7a8d8126a51 --- /dev/null +++ b/pkg/planner/indexadvisor/utils.go @@ -0,0 +1,473 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexadvisor + +import ( + "fmt" + "strings" + + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/opcode" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/logutil" + parser2 "github.com/pingcap/tidb/pkg/util/parser" + s "github.com/pingcap/tidb/pkg/util/set" + "go.uber.org/zap" +) + +// ParseOneSQL parses the given Query text and returns the AST. +func ParseOneSQL(sqlText string) (ast.StmtNode, error) { + p := parser.New() + return p.ParseOneStmt(sqlText, "", "") +} + +// NormalizeDigest normalizes the given Query text and returns the normalized Query text and its digest. +func NormalizeDigest(sqlText string) (normalizedSQL, digest string) { + norm, d := parser.NormalizeDigest(sqlText) + return norm, d.String() +} + +type nodeVisitor struct { + enter func(n ast.Node) (skip bool) + leave func(n ast.Node) (ok bool) +} + +func (v *nodeVisitor) Enter(n ast.Node) (out ast.Node, skipChildren bool) { + if v.enter != nil { + return n, v.enter(n) + } + return n, false +} + +func (v *nodeVisitor) Leave(n ast.Node) (out ast.Node, ok bool) { + if v.leave != nil { + return n, v.leave(n) + } + return n, true +} + +func visitNode(n ast.Node, enter func(n ast.Node) (skip bool), leave func(n ast.Node) (ok bool)) { + n.Accept(&nodeVisitor{enter, leave}) +} + +// CollectTableNamesFromQuery returns all referenced table names in the given Query text. +// The returned format is []string{"schema.table", "schema.table", ...}. +func CollectTableNamesFromQuery(defaultSchema, query string) ([]string, error) { + node, err := ParseOneSQL(query) + if err != nil { + return nil, err + } + cteNames := make(map[string]struct{}) + var tableNames []string + visitNode(node, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.WithClause: + for _, cte := range x.CTEs { + cteNames[fmt.Sprintf("%v.%v", defaultSchema, cte.Name.String())] = struct{}{} + } + case *ast.TableName: + var tableName string + if x.Schema.L == "" { + tableName = fmt.Sprintf("%v.%v", defaultSchema, x.Name.String()) + } else { + tableName = fmt.Sprintf("%v.%v", x.Schema.L, x.Name.String()) + } + if _, ok := cteNames[tableName]; !ok { + tableNames = append(tableNames, tableName) + } + } + return false + }, nil) + return tableNames, nil +} + +// CollectSelectColumnsFromQuery parses the given Query text and returns the selected columns. +// For example, "select a, b, c from t" returns []string{"a", "b", "c"}. +func CollectSelectColumnsFromQuery(q Query) (s.Set[Column], error) { + names, err := CollectTableNamesFromQuery(q.SchemaName, q.Text) + if err != nil { + return nil, err + } + if len(names) != 1 { // unsupported yet + return nil, nil + } + tmp := strings.Split(names[0], ".") + node, err := ParseOneSQL(q.Text) + if err != nil { + return nil, err + } + underSelectField := false + selectCols := s.NewSet[Column]() + visitNode(node, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.SelectField: + underSelectField = true + case *ast.ColumnNameExpr: + if underSelectField { + selectCols.Add(Column{ + SchemaName: tmp[0], + TableName: tmp[1], + ColumnName: x.Name.Name.O}) + } + } + return false + }, func(n ast.Node) bool { + if _, ok := n.(*ast.SelectField); ok { + underSelectField = false + } + return true + }) + return selectCols, nil +} + +// CollectOrderByColumnsFromQuery parses the given Query text and returns the order-by columns. +// For example, "select a, b from t order by a, b" returns []string{"a", "b"}. +func CollectOrderByColumnsFromQuery(q Query) ([]Column, error) { + names, err := CollectTableNamesFromQuery(q.SchemaName, q.Text) + if err != nil { + return nil, err + } + if len(names) != 1 { // unsupported yet + return nil, nil + } + tmp := strings.Split(names[0], ".") + node, err := ParseOneSQL(q.Text) + if err != nil { + return nil, err + } + var orderByCols []Column + exit := false + visitNode(node, func(n ast.Node) bool { + if exit { + return true + } + if x, ok := n.(*ast.OrderByClause); ok { + for _, byItem := range x.Items { + colExpr, ok := byItem.Expr.(*ast.ColumnNameExpr) + if !ok { + orderByCols = nil + exit = true + return true + } + orderByCols = append(orderByCols, Column{ + SchemaName: tmp[0], + TableName: tmp[1], + ColumnName: colExpr.Name.Name.O}) + } + } + return false + }, nil) + return orderByCols, nil +} + +// CollectDNFColumnsFromQuery parses the given Query text and returns the DNF columns. +// For a query `select ... where c1=1 or c2=2 or c3=3`, the DNF columns are `c1`, `c2` and `c3`. +func CollectDNFColumnsFromQuery(q Query) (s.Set[Column], error) { + names, err := CollectTableNamesFromQuery(q.SchemaName, q.Text) + if err != nil { + return nil, err + } + if len(names) != 1 { // unsupported yet + return nil, nil + } + tmp := strings.Split(names[0], ".") + node, err := ParseOneSQL(q.Text) + if err != nil { + return nil, err + } + dnfColSet := s.NewSet[Column]() + + visitNode(node, func(n ast.Node) bool { + if dnfColSet.Size() > 0 { // already collected + return true + } + if x, ok := n.(*ast.SelectStmt); ok { + cnf := flattenCNF(x.Where) + for _, expr := range cnf { + dnf := flattenDNF(expr) + if len(dnf) <= 1 { + continue + } + // c1=1 or c2=2 or c3=3 + var dnfCols []*ast.ColumnNameExpr + fail := false + for _, dnfExpr := range dnf { + col, _ := flattenColEQConst(dnfExpr) + if col == nil { + fail = true + break + } + dnfCols = append(dnfCols, col) + } + if fail { + continue + } + for _, col := range dnfCols { + dnfColSet.Add(Column{SchemaName: tmp[0], TableName: tmp[1], ColumnName: col.Name.Name.O}) + } + } + } + return false + }, nil) + + return dnfColSet, nil +} + +func flattenColEQConst(expr ast.ExprNode) (*ast.ColumnNameExpr, *driver.ValueExpr) { + if _, ok := expr.(*ast.ParenthesesExpr); ok { + return flattenColEQConst(expr.(*ast.ParenthesesExpr).Expr) + } + + if op, ok := expr.(*ast.BinaryOperationExpr); ok && op.Op == opcode.EQ { + l, r := op.L, op.R + _, lIsCol := l.(*ast.ColumnNameExpr) + _, lIsCon := l.(*driver.ValueExpr) + _, rIsCol := r.(*ast.ColumnNameExpr) + _, rIsCon := r.(*driver.ValueExpr) + if lIsCol && rIsCon { + return l.(*ast.ColumnNameExpr), r.(*driver.ValueExpr) + } + if lIsCon && rIsCol { + return r.(*ast.ColumnNameExpr), l.(*driver.ValueExpr) + } + } + return nil, nil +} + +func flattenCNF(expr ast.ExprNode) []ast.ExprNode { + if _, ok := expr.(*ast.ParenthesesExpr); ok { + return flattenCNF(expr.(*ast.ParenthesesExpr).Expr) + } + + var cnf []ast.ExprNode + if op, ok := expr.(*ast.BinaryOperationExpr); ok && op.Op == opcode.LogicAnd { + cnf = append(cnf, flattenCNF(op.L)...) + cnf = append(cnf, flattenCNF(op.R)...) + } else { + cnf = append(cnf, expr) + } + return cnf +} + +func flattenDNF(expr ast.ExprNode) []ast.ExprNode { + if _, ok := expr.(*ast.ParenthesesExpr); ok { + return flattenDNF(expr.(*ast.ParenthesesExpr).Expr) + } + + var cnf []ast.ExprNode + if op, ok := expr.(*ast.BinaryOperationExpr); ok && op.Op == opcode.LogicOr { + cnf = append(cnf, flattenDNF(op.L)...) + cnf = append(cnf, flattenDNF(op.R)...) + } else { + cnf = append(cnf, expr) + } + return cnf +} + +// RestoreSchemaName restores the schema name of the given Query set. +func RestoreSchemaName(defaultSchema string, sqls s.Set[Query], ignoreErr bool) (s.Set[Query], error) { + s := s.NewSet[Query]() + for _, sql := range sqls.ToList() { + if sql.SchemaName == "" { + sql.SchemaName = defaultSchema + } + stmt, err := ParseOneSQL(sql.Text) + if err != nil { + if ignoreErr { + continue + } + return nil, fmt.Errorf("invalid query: %v, err: %v", sql.Text, err) + } + sql.Text = parser2.RestoreWithDefaultDB(stmt, sql.SchemaName, sql.Text) + s.Add(sql) + } + return s, nil +} + +// FilterInvalidQueries filters out invalid queries from the given query set. +// some queries might be forbidden by the fix-control 43817. +func FilterInvalidQueries(opt Optimizer, sqls s.Set[Query], ignoreErr bool) (s.Set[Query], error) { + s := s.NewSet[Query]() + for _, sql := range sqls.ToList() { + _, err := opt.QueryPlanCost(sql.Text) + if err != nil { + if ignoreErr { + continue + } + return nil, err + } + s.Add(sql) + } + return s, nil +} + +// FilterSQLAccessingSystemTables filters out queries that access system tables. +func FilterSQLAccessingSystemTables(sqls s.Set[Query], ignoreErr bool) (s.Set[Query], error) { + s := s.NewSet[Query]() + for _, sql := range sqls.ToList() { + accessSystemTable := false + names, err := CollectTableNamesFromQuery(sql.SchemaName, sql.Text) + if err != nil { + if ignoreErr { + continue + } + return nil, err + } + if len(names) == 0 { + // `select @@some_var` or `select some_func()` + continue + } + for _, name := range names { + schemaName := strings.ToLower(strings.Split(name, ".")[0]) + if schemaName == "information_schema" || schemaName == "metrics_schema" || + schemaName == "performance_schema" || schemaName == "mysql" { + accessSystemTable = true + break + } + } + if !accessSystemTable { + s.Add(sql) + } + } + return s, nil +} + +// CollectIndexableColumnsForQuerySet finds all columns that appear in any range-filter, order-by, or group-by clause. +func CollectIndexableColumnsForQuerySet(opt Optimizer, querySet s.Set[Query]) (s.Set[Column], error) { + indexableColumnSet := s.NewSet[Column]() + queryList := querySet.ToList() + for _, q := range queryList { + cols, err := CollectIndexableColumnsFromQuery(q, opt) + if err != nil { + return nil, err + } + querySet.Add(q) + indexableColumnSet.Add(cols.ToList()...) + } + return indexableColumnSet, nil +} + +// CollectIndexableColumnsFromQuery parses the given Query text and returns the indexable columns. +func CollectIndexableColumnsFromQuery(q Query, opt Optimizer) (s.Set[Column], error) { + tableNames, err := CollectTableNamesFromQuery(q.SchemaName, q.Text) + if err != nil { + return nil, err + } + possibleSchemas := make(map[string]bool) + possibleSchemas[q.SchemaName] = true + for _, name := range tableNames { + schemaName := strings.Split(name, ".")[0] + possibleSchemas[strings.ToLower(schemaName)] = true + } + + stmt, err := ParseOneSQL(q.Text) + if err != nil { + return nil, err + } + cols := s.NewSet[Column]() + var collectColumn func(n ast.Node) + collectColumn = func(n ast.Node) { + switch x := n.(type) { + case *ast.ColumnNameExpr: + collectColumn(x.Name) + case *ast.ColumnName: + var schemaNames []string + if x.Schema.L != "" { + schemaNames = append(schemaNames, x.Schema.L) + } else { + for schemaName := range possibleSchemas { + schemaNames = append(schemaNames, schemaName) + } + } + + var possibleColumns []Column + for _, schemaName := range schemaNames { + cols, err := opt.PossibleColumns(schemaName, x.Name.L) + if err != nil { + l().Warn("failed to get possible columns", + zap.String("schema", schemaName), + zap.String("column", x.Name.L)) + continue + } + possibleColumns = append(possibleColumns, cols...) + } + + for _, c := range possibleColumns { + colType, err := opt.ColumnType(c) + if err != nil { + l().Warn("failed to get column type", + zap.String("schema", c.SchemaName), + zap.String("table", c.TableName), + zap.String("column", c.ColumnName)) + continue + } + if !isIndexableColumnType(colType) { + continue + } + cols.Add(c) + } + } + } + + visitNode(stmt, func(n ast.Node) (skip bool) { + switch x := n.(type) { + case *ast.GroupByClause: // group by {col} + for _, item := range x.Items { + collectColumn(item.Expr) + } + return true + case *ast.OrderByClause: // order by {col} + for _, item := range x.Items { + collectColumn(item.Expr) + } + return true + case *ast.BetweenExpr: // {col} between ? and ? + collectColumn(x.Expr) + case *ast.PatternInExpr: // {col} in (?, ?, ...) + collectColumn(x.Expr) + case *ast.BinaryOperationExpr: // range predicates like `{col} > ?` + switch x.Op { + case opcode.EQ, opcode.LT, opcode.LE, opcode.GT, opcode.GE: // {col} = ? + collectColumn(x.L) + collectColumn(x.R) + } + default: + } + return false + }, nil) + return cols, nil +} + +func isIndexableColumnType(tp *types.FieldType) bool { + if tp == nil { + return false + } + switch tp.GetType() { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, + mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal, + mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + return true + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + return tp.GetFlen() <= 512 + } + return false +} + +func l() *zap.Logger { + return logutil.BgLogger().With(zap.String("component", "index_advisor")) +} diff --git a/pkg/planner/indexadvisor/utils_test.go b/pkg/planner/indexadvisor/utils_test.go new file mode 100644 index 0000000000000..6fe1893f4e5d5 --- /dev/null +++ b/pkg/planner/indexadvisor/utils_test.go @@ -0,0 +1,211 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexadvisor_test + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/planner/indexadvisor" + "github.com/pingcap/tidb/pkg/testkit" + s "github.com/pingcap/tidb/pkg/util/set" + "github.com/stretchr/testify/require" +) + +func TestCollectTableFromQuery(t *testing.T) { + names, err := indexadvisor.CollectTableNamesFromQuery("test", "select * from t where a = 1") + require.NoError(t, err) + require.Equal(t, names[0], "test.t") + + names, err = indexadvisor.CollectTableNamesFromQuery("test", "select * from t1, t2") + require.NoError(t, err) + require.Equal(t, names[0], "test.t1") + require.Equal(t, names[1], "test.t2") + + names, err = indexadvisor.CollectTableNamesFromQuery("test", "select * from t1 where t1.a < (select max(b) from t2)") + require.NoError(t, err) + require.Equal(t, names[0], "test.t1") + require.Equal(t, names[1], "test.t2") + + names, err = indexadvisor.CollectTableNamesFromQuery("test", "select * from t1 where t1.a < (select max(b) from db2.t2)") + require.NoError(t, err) + require.Equal(t, names[0], "test.t1") + require.Equal(t, names[1], "db2.t2") +} + +func TestCollectSelectColumnsFromQuery(t *testing.T) { + names, err := indexadvisor.CollectSelectColumnsFromQuery(indexadvisor.Query{Text: "select a, b from test.t"}) + require.NoError(t, err) + require.True(t, names.String() == "{test.t.a, test.t.b}") + + names, err = indexadvisor.CollectSelectColumnsFromQuery(indexadvisor.Query{Text: "select a, b, c from test.t"}) + require.NoError(t, err) + require.True(t, names.String() == "{test.t.a, test.t.b, test.t.c}") +} + +func TestCollectOrderByColumnsFromQuery(t *testing.T) { + cols, err := indexadvisor.CollectOrderByColumnsFromQuery(indexadvisor.Query{Text: "select a, b from test.t order by a"}) + require.NoError(t, err) + require.Equal(t, len(cols), 1) + require.Equal(t, cols[0].Key(), "test.t.a") + + cols, err = indexadvisor.CollectOrderByColumnsFromQuery(indexadvisor.Query{Text: "select a, b from test.t order by a, b"}) + require.NoError(t, err) + require.Equal(t, len(cols), 2) + require.Equal(t, cols[0].Key(), "test.t.a") + require.Equal(t, cols[1].Key(), "test.t.b") +} + +func TestCollectDNFColumnsFromQuery(t *testing.T) { + cols, err := indexadvisor.CollectDNFColumnsFromQuery(indexadvisor.Query{Text: "select a, b from test.t where a = 1 or b = 2"}) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t.a, test.t.b}") + + cols, err = indexadvisor.CollectDNFColumnsFromQuery(indexadvisor.Query{Text: "select a, b from test.t where a = 1 or b = 2 or c=3"}) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t.a, test.t.b, test.t.c}") +} + +func TestRestoreSchemaName(t *testing.T) { + q1 := indexadvisor.Query{Text: "select * from t1"} + q2 := indexadvisor.Query{Text: "select * from t2", SchemaName: "test2"} + q3 := indexadvisor.Query{Text: "select * from t3"} + q4 := indexadvisor.Query{Text: "wrong"} + set1 := s.NewSet[indexadvisor.Query]() + set1.Add(q1, q2, q3, q4) + + set2, err := indexadvisor.RestoreSchemaName("test", set1, true) + require.NoError(t, err) + require.Equal(t, set2.String(), "{SELECT * FROM `test2`.`t2`, SELECT * FROM `test`.`t1`, SELECT * FROM `test`.`t3`}") + + _, err = indexadvisor.RestoreSchemaName("test", set1, false) + require.Error(t, err) +} + +func TestFilterSQLAccessingSystemTables(t *testing.T) { + set1 := s.NewSet[indexadvisor.Query]() + set1.Add(indexadvisor.Query{Text: "select * from mysql.stats_meta"}) + set1.Add(indexadvisor.Query{Text: "select * from information_schema.test"}) + set1.Add(indexadvisor.Query{Text: "select * from metrics_schema.test"}) + set1.Add(indexadvisor.Query{Text: "select * from performance_schema.test"}) + set1.Add(indexadvisor.Query{Text: "select * from mysql.stats_meta", SchemaName: "test"}) + set1.Add(indexadvisor.Query{Text: "select * from mysql.stats_meta, test.t1", SchemaName: "test"}) + set1.Add(indexadvisor.Query{Text: "select * from test.t1", SchemaName: "mysql"}) + set1.Add(indexadvisor.Query{Text: "select @@var", SchemaName: "test"}) + set1.Add(indexadvisor.Query{Text: "select sleep(1)", SchemaName: "test"}) + set1.Add(indexadvisor.Query{Text: "wrong", SchemaName: "information_schema"}) + + set2, err := indexadvisor.FilterSQLAccessingSystemTables(set1, true) + require.NoError(t, err) + require.Equal(t, set2.String(), "{select * from test.t1}") + + _, err = indexadvisor.FilterSQLAccessingSystemTables(set1, false) + require.Error(t, err) +} + +func TestFilterInvalidQueries(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int, b int, c int)`) + tk.MustExec(`create table t2 (a int, b int, c int)`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + set1 := s.NewSet[indexadvisor.Query]() + set1.Add(indexadvisor.Query{Text: "select * from test.t1"}) + set1.Add(indexadvisor.Query{Text: "select * from test.t3"}) // table t3 does not exist + set1.Add(indexadvisor.Query{Text: "select d from t1"}) // column d does not exist + set1.Add(indexadvisor.Query{Text: "select * from t1 where a<(select max(b) from t2)"}) // Fix43817 + set1.Add(indexadvisor.Query{Text: "wrong"}) // invalid query + + set2, err := indexadvisor.FilterInvalidQueries(opt, set1, true) + require.NoError(t, err) + require.Equal(t, set2.String(), "{select * from test.t1}") + + _, err = indexadvisor.FilterInvalidQueries(opt, set1, false) + require.Error(t, err) +} + +func TestCollectIndexableColumnsForQuerySet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a int, b int, c int, d int, e int, f int)`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + set1 := s.NewSet[indexadvisor.Query]() + set1.Add(indexadvisor.Query{Text: "select * from test.t where a=1 and b=1 and e like 'abc'"}) + set1.Add(indexadvisor.Query{Text: "select * from test.t where a<1 and b>1 and e like 'abc'"}) + set1.Add(indexadvisor.Query{Text: "select * from test.t where c in (1, 2, 3) order by d"}) + set1.Add(indexadvisor.Query{Text: "select 1 from test.t where c in (1, 2, 3) group by e"}) + + set2, err := indexadvisor.CollectIndexableColumnsForQuerySet(opt, set1) + require.NoError(t, err) + require.Equal(t, "{test.t.a, test.t.b, test.t.c, test.t.d, test.t.e}", set2.String()) +} + +func TestCollectIndexableColumnsFromQuery(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a int, b int, c int, d int, e int)`) + opt := indexadvisor.NewOptimizer(tk.Session()) + + cols, err := indexadvisor.CollectIndexableColumnsFromQuery( + indexadvisor.Query{SchemaName: "test", Text: "select * from t where a<1 and b>1 and e like 'abc'"}, opt) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t.a, test.t.b}") + + cols, err = indexadvisor.CollectIndexableColumnsFromQuery( + indexadvisor.Query{SchemaName: "test", Text: "select * from t where c in (1, 2, 3) order by d"}, opt) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t.c, test.t.d}") + + cols, err = indexadvisor.CollectIndexableColumnsFromQuery( + indexadvisor.Query{SchemaName: "test", Text: "select 1 from t where c in (1, 2, 3) group by d"}, opt) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t.c, test.t.d}") + + tk.MustExec("drop table t") + + tk.MustExec(`create table t1 (a int)`) + tk.MustExec(`create table t2 (a int)`) + cols, err = indexadvisor.CollectIndexableColumnsFromQuery( + indexadvisor.Query{SchemaName: "test", Text: "select * from t2 tx where a<1"}, opt) + require.NoError(t, err) + require.Equal(t, cols.String(), "{test.t1.a, test.t2.a}") + tk.MustExec("drop table t1") + tk.MustExec("drop table t2") + + tk.MustExec(`create database tpch`) + tk.MustExec(`use tpch`) + tk.MustExec(`CREATE TABLE tpch.nation ( N_NATIONKEY bigint(20) NOT NULL, + N_NAME char(25) NOT NULL, N_REGIONKEY bigint(20) NOT NULL, N_COMMENT varchar(152) DEFAULT NULL, + PRIMARY KEY (N_NATIONKEY) /*T![clustered_index] CLUSTERED */)`) + q := ` select supp_nation, cust_nation, l_year, sum(volume) as revenue from + ( select n1.n_name as supp_nation, n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume + from supplier, lineitem, orders, customer, nation n1, nation n2 + where s_suppkey = l_suppkey and o_orderkey = l_orderkey and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey and c_nationkey = n2.n_nationkey + and ( (n1.n_name = 'MOZAMBIQUE' and n2.n_name = 'UNITED KINGDOM') + or (n1.n_name = 'UNITED KINGDOM' and n2.n_name = 'MOZAMBIQUE') + ) and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping group by supp_nation, cust_nation, l_year + order by supp_nation, cust_nation, l_year` + cols, err = indexadvisor.CollectIndexableColumnsFromQuery( + indexadvisor.Query{SchemaName: "tpch", Text: q}, opt) + require.NoError(t, err) + require.Equal(t, cols.String(), "{tpch.nation.n_name, tpch.nation.n_nationkey}") +} diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index ac596e7491112..808d467a76bb9 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -35,7 +35,10 @@ import ( "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/resolve" + "github.com/pingcap/tidb/pkg/planner/indexadvisor" + "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" "github.com/pingcap/tidb/pkg/privilege" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -623,9 +626,24 @@ func hypoIndexChecker(ctx context.Context, is infoschema.InfoSchema) func(db, tb } } +// queryPlanCost returns the plan cost of this node, which is mainly for the Index Advisor. +func queryPlanCost(sctx sessionctx.Context, stmt ast.StmtNode) (float64, error) { + nodeW := resolve.NewNodeW(stmt) + plan, _, err := Optimize(context.Background(), sctx, nodeW, sctx.GetDomainInfoSchema().(infoschema.InfoSchema)) + if err != nil { + return 0, err + } + pp, ok := plan.(base.PhysicalPlan) + if !ok { + return 0, errors.Errorf("plan is not a physical plan: %T", plan) + } + return core.GetPlanCost(pp, property.RootTaskType, optimizetrace.NewDefaultPlanCostOption()) +} + func init() { core.OptimizeAstNode = Optimize core.IsReadOnly = IsReadOnly + indexadvisor.QueryPlanCostHook = queryPlanCost bindinfo.GetGlobalBindingHandle = func(sctx sessionctx.Context) bindinfo.GlobalBindingHandle { return domain.GetDomain(sctx).BindHandle() } diff --git a/pkg/util/set/BUILD.bazel b/pkg/util/set/BUILD.bazel index 443f3c384023a..5e3b1d0343d1a 100644 --- a/pkg/util/set/BUILD.bazel +++ b/pkg/util/set/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "float64_set.go", "int_set.go", "mem_aware_map.go", + "set.go", "set_with_memory_usage.go", "string_set.go", ], @@ -26,6 +27,7 @@ go_test( "int_set_test.go", "main_test.go", "mem_aware_map_test.go", + "set_test.go", "set_with_memory_usage_test.go", "string_set_test.go", ], @@ -34,6 +36,7 @@ go_test( deps = [ "//pkg/testkit/testsetup", "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", "@org_uber_go_goleak//:goleak", ], ) diff --git a/pkg/util/set/set.go b/pkg/util/set/set.go new file mode 100644 index 0000000000000..65777b72481a0 --- /dev/null +++ b/pkg/util/set/set.go @@ -0,0 +1,184 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +import ( + "fmt" + "sort" + "strings" +) + +// Key is the interface for the key of a set item. +type Key interface { + Key() string +} + +// Set is the interface for a set. +type Set[T Key] interface { + Add(items ...T) + Contains(item T) bool + Remove(item T) + ToList() []T + Size() int + Clone() Set[T] + String() string +} + +type setImpl[T Key] struct { + s map[string]T +} + +// NewSet creates a new set. +func NewSet[T Key]() Set[T] { + return new(setImpl[T]) +} + +func (s *setImpl[T]) Add(items ...T) { + if s.s == nil { + s.s = make(map[string]T) + } + for _, item := range items { + s.s[item.Key()] = item + } +} + +func (s *setImpl[T]) Contains(item T) bool { + if s.s == nil { + return false + } + _, ok := s.s[item.Key()] + return ok +} + +func (s *setImpl[T]) ToList() []T { + if s == nil { + return nil + } + list := make([]T, 0, len(s.s)) + for _, v := range s.s { + list = append(list, v) + } + sort.Slice(list, func(i, j int) bool { + return list[i].Key() < list[j].Key() + }) // to make the result stable + return list +} + +func (s *setImpl[T]) Remove(item T) { + delete(s.s, item.Key()) +} + +func (s *setImpl[T]) Size() int { + if s == nil { + return 0 + } + return len(s.s) +} + +func (s *setImpl[T]) Clone() Set[T] { + clone := NewSet[T]() + clone.Add(s.ToList()...) + return clone +} + +func (s *setImpl[T]) String() string { + items := make([]string, 0, len(s.s)) + for _, item := range s.s { + items = append(items, item.Key()) + } + sort.Strings(items) + return fmt.Sprintf("{%v}", strings.Join(items, ", ")) +} + +// ListToSet converts a list to a set. +func ListToSet[T Key](items ...T) Set[T] { + s := NewSet[T]() + for _, item := range items { + s.Add(item) + } + return s +} + +// UnionSet returns the union set of the given sets. +func UnionSet[T Key](ss ...Set[T]) Set[T] { + if len(ss) == 0 { + return NewSet[T]() + } + if len(ss) == 1 { + return ss[0].Clone() + } + s := NewSet[T]() + for _, set := range ss { + s.Add(set.ToList()...) + } + return s +} + +// AndSet returns the intersection set of the given sets. +func AndSet[T Key](ss ...Set[T]) Set[T] { + if len(ss) == 0 { + return NewSet[T]() + } + if len(ss) == 1 { + return ss[0].Clone() + } + s := NewSet[T]() + for _, item := range ss[0].ToList() { + contained := true + for _, set := range ss[1:] { + if !set.Contains(item) { + contained = false + break + } + } + if contained { + s.Add(item) + } + } + return s +} + +// DiffSet returns a set of items that are in s1 but not in s2. +// DiffSet({1, 2, 3, 4}, {2, 3}) = {1, 4} +func DiffSet[T Key](s1, s2 Set[T]) Set[T] { + s := NewSet[T]() + for _, item := range s1.ToList() { + if !s2.Contains(item) { + s.Add(item) + } + } + return s +} + +// CombSet returns all combinations of `numberOfItems` items in the given set. +// For example ({a, b, c}, 2) returns {ab, ac, bc}. +func CombSet[T Key](s Set[T], numberOfItems int) []Set[T] { + return combSetIterate(s.ToList(), NewSet[T](), 0, numberOfItems) +} + +func combSetIterate[T Key](itemList []T, currSet Set[T], depth, numberOfItems int) []Set[T] { + if currSet.Size() == numberOfItems { + return []Set[T]{currSet.Clone()} + } + if depth == len(itemList) || currSet.Size() > numberOfItems { + return nil + } + var res []Set[T] + currSet.Add(itemList[depth]) + res = append(res, combSetIterate(itemList, currSet, depth+1, numberOfItems)...) + currSet.Remove(itemList[depth]) + res = append(res, combSetIterate(itemList, currSet, depth+1, numberOfItems)...) + return res +} diff --git a/pkg/util/set/set_test.go b/pkg/util/set/set_test.go new file mode 100644 index 0000000000000..270a21cad6b66 --- /dev/null +++ b/pkg/util/set/set_test.go @@ -0,0 +1,96 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type item struct { + Text string +} + +func (i item) Key() string { + return i.Text +} + +func TestSetBasic(t *testing.T) { + s := NewSet[item]() + s.Add(item{Text: "q1"}, item{Text: "q2"}, item{Text: "q3"}) + require.True(t, s.Contains(item{Text: "q1"})) + require.True(t, s.Contains(item{Text: "q2"})) + require.True(t, s.Contains(item{Text: "q3"})) + require.False(t, s.Contains(item{Text: "q4"})) + require.Equal(t, 3, s.Size()) + require.Equal(t, []item{{Text: "q1"}, {Text: "q2"}, {Text: "q3"}}, s.ToList()) + s.Remove(item{Text: "q2"}) + require.False(t, s.Contains(item{Text: "q2"})) + require.Equal(t, 2, s.Size()) + + clonedS := s.Clone() + require.True(t, clonedS.Contains(item{Text: "q1"})) + s.Remove(item{Text: "q1"}) + require.False(t, s.Contains(item{Text: "q1"})) + require.True(t, clonedS.Contains(item{Text: "q1"})) + require.Equal(t, 2, clonedS.Size()) +} + +func TestSetOperation(t *testing.T) { + s1 := NewSet[item]() + s1.Add(item{Text: "q1"}, item{Text: "q2"}, item{Text: "q3"}) + s2 := NewSet[item]() + s2.Add(item{Text: "q2"}, item{Text: "q3"}, item{Text: "q4"}) + unionSet := UnionSet(s1, s2) + require.Equal(t, []item{{Text: "q1"}, {Text: "q2"}, {Text: "q3"}, {Text: "q4"}}, unionSet.ToList()) + + andSet := AndSet(s1, s2) + require.Equal(t, []item{{Text: "q2"}, {Text: "q3"}}, andSet.ToList()) + + diffSet := DiffSet(s1, s2) + require.Equal(t, []item{{Text: "q1"}}, diffSet.ToList()) + diffSet = DiffSet(s2, s1) + require.Equal(t, []item{{Text: "q4"}}, diffSet.ToList()) +} + +func TestSetCombination(t *testing.T) { + s := NewSet[item]() + s.Add(item{Text: "q1"}, item{Text: "q2"}, item{Text: "q3"}, item{Text: "q4"}) + + setListStr := func(setList []Set[item]) string { + var tmp []string + for _, set := range setList { + tmp = append(tmp, set.String()) + } + return strings.Join(tmp, ", ") + } + + s1 := CombSet(s, 1) + require.Equal(t, "{q1}, {q2}, {q3}, {q4}", setListStr(s1)) + + s2 := CombSet(s, 2) + require.Equal(t, "{q1, q2}, {q1, q3}, {q1, q4}, {q2, q3}, {q2, q4}, {q3, q4}", setListStr(s2)) + + s3 := CombSet(s, 3) + require.Equal(t, "{q1, q2, q3}, {q1, q2, q4}, {q1, q3, q4}, {q2, q3, q4}", setListStr(s3)) + + s4 := CombSet(s, 4) + require.Equal(t, "{q1, q2, q3, q4}", setListStr(s4)) + + s5 := CombSet(s, 5) + require.Equal(t, "", setListStr(s5)) +}