Skip to content

Commit

Permalink
planner: Move the Selectivity function from the stats package into ca…
Browse files Browse the repository at this point in the history
…rdinality package (#46438)

ref #46358
  • Loading branch information
qw4990 authored Aug 28, 2023
1 parent c11a999 commit 2163271
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 105 deletions.
3 changes: 2 additions & 1 deletion planner/cardinality/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@ go_test(
timeout = "short",
srcs = [
"main_test.go",
"row_count_test.go",
"selectivity_test.go",
"trace_test.go",
],
data = glob(["testdata/**"]),
embed = [":cardinality"],
flaky = True,
shard_count = 30,
shard_count = 31,
deps = [
"//config",
"//domain",
Expand Down
59 changes: 57 additions & 2 deletions planner/cardinality/row_count_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ func init() {
statistics.GetRowCountByColumnRanges = GetRowCountByColumnRanges
statistics.GetRowCountByIntColumnRanges = GetRowCountByIntColumnRanges
statistics.GetRowCountByIndexRanges = GetRowCountByIndexRanges
statistics.EqualRowCountOnColumn = equalRowCountOnColumn
statistics.BetweenRowCountOnColumn = betweenRowCountOnColumn
}

// GetRowCountByColumnRanges estimates the row count by a slice of Range.
Expand Down Expand Up @@ -306,3 +304,60 @@ func betweenRowCountOnColumn(sctx sessionctx.Context, c *statistics.Column, l, r
}
return float64(c.TopN.BetweenCount(sctx, lowEncoded, highEncoded)) + histBetweenCnt
}

// functions below are mainly for testing.

// ColumnGreaterRowCount estimates the row count where the column greater than value.
func ColumnGreaterRowCount(sctx sessionctx.Context, t *statistics.Table, value types.Datum, colID int64) float64 {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoLessRate
}
return c.GreaterRowCount(value) * c.GetIncreaseFactor(t.RealtimeCount)
}

// ColumnLessRowCount estimates the row count where the column less than value. Note that null values are not counted.
func ColumnLessRowCount(sctx sessionctx.Context, t *statistics.Table, value types.Datum, colID int64) float64 {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoLessRate
}
return c.LessRowCount(sctx, value) * c.GetIncreaseFactor(t.RealtimeCount)
}

// ColumnBetweenRowCount estimates the row count where column greater or equal to a and less than b.
func ColumnBetweenRowCount(sctx sessionctx.Context, t *statistics.Table, a, b types.Datum, colID int64) (float64, error) {
sc := sctx.GetSessionVars().StmtCtx
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoBetweenRate, nil
}
aEncoded, err := codec.EncodeKey(sc, nil, a)
if err != nil {
return 0, err
}
bEncoded, err := codec.EncodeKey(sc, nil, b)
if err != nil {
return 0, err
}
count := betweenRowCountOnColumn(sctx, c, a, b, aEncoded, bEncoded)
if a.IsNull() {
count += float64(c.NullCount)
}
return count * c.GetIncreaseFactor(t.RealtimeCount), nil
}

// ColumnEqualRowCount estimates the row count where the column equals to value.
func ColumnEqualRowCount(sctx sessionctx.Context, t *statistics.Table, value types.Datum, colID int64) (float64, error) {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoEqualRate, nil
}
encodedVal, err := codec.EncodeKey(sctx.GetSessionVars().StmtCtx, nil, value)
if err != nil {
return 0, err
}
result, err := equalRowCountOnColumn(sctx, c, value, encodedVal, t.ModifyCount)
result *= c.GetIncreaseFactor(t.RealtimeCount)
return result, errors.Trace(err)
}
56 changes: 56 additions & 0 deletions planner/cardinality/row_count_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cardinality

import (
"testing"

"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
)

func TestPseudoTable(t *testing.T) {
ti := &model.TableInfo{}
colInfo := &model.ColumnInfo{
ID: 1,
FieldType: *types.NewFieldType(mysql.TypeLonglong),
State: model.StatePublic,
}
ti.Columns = append(ti.Columns, colInfo)
tbl := statistics.PseudoTable(ti)
require.Len(t, tbl.Columns, 1)
require.Greater(t, tbl.RealtimeCount, int64(0))
sctx := mock.NewContext()
count := ColumnLessRowCount(sctx, tbl, types.NewIntDatum(100), colInfo.ID)
require.Equal(t, 3333, int(count))
count, err := ColumnEqualRowCount(sctx, tbl, types.NewIntDatum(1000), colInfo.ID)
require.NoError(t, err)
require.Equal(t, 10, int(count))
count, _ = ColumnBetweenRowCount(sctx, tbl, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID)
require.Equal(t, 250, int(count))
ti.Columns = append(ti.Columns, &model.ColumnInfo{
ID: 2,
FieldType: *types.NewFieldType(mysql.TypeLonglong),
Hidden: true,
State: model.StatePublic,
})
tbl = statistics.PseudoTable(ti)
// We added a hidden column. The pseudo table still only have one column.
require.Equal(t, len(tbl.Columns), 1)
}
2 changes: 1 addition & 1 deletion statistics/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ go_test(
data = glob(["testdata/**"]),
embed = [":statistics"],
flaky = True,
shard_count = 41,
shard_count = 40,
deps = [
"//config",
"//parser/ast",
Expand Down
1 change: 1 addition & 0 deletions statistics/handle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ go_test(
"//config",
"//domain",
"//parser/model",
"//planner/cardinality",
"//sessionctx/stmtctx",
"//sessionctx/variable",
"//statistics",
Expand Down
9 changes: 5 additions & 4 deletions statistics/handle/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"

"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/planner/cardinality"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
Expand Down Expand Up @@ -51,9 +52,9 @@ func TestDDLAfterLoad(t *testing.T) {
tableInfo = tbl.Meta()

sctx := mock.NewContext()
count := statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
count := cardinality.ColumnGreaterRowCount(sctx, statsTbl, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
require.Equal(t, 0.0, count)
count = statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
count = cardinality.ColumnGreaterRowCount(sctx, statsTbl, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
require.Equal(t, 333, int(count))
}

Expand Down Expand Up @@ -133,10 +134,10 @@ func TestDDLHistogram(t *testing.T) {
require.False(t, statsTbl.Pseudo)
require.True(t, statsTbl.Columns[tableInfo.Columns[3].ID].IsStatsInitialized())
sctx := mock.NewContext()
count, err := statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(0), tableInfo.Columns[3].ID)
count, err := cardinality.ColumnEqualRowCount(sctx, statsTbl, types.NewIntDatum(0), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(2), count)
count, err = statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1), tableInfo.Columns[3].ID)
count, err = cardinality.ColumnEqualRowCount(sctx, statsTbl, types.NewIntDatum(1), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(0), count)

Expand Down
2 changes: 1 addition & 1 deletion statistics/handle/handletest/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestEmptyTable(t *testing.T) {
require.NoError(t, err)
tableInfo := tbl.Meta()
statsTbl := do.StatsHandle().GetTableStats(tableInfo)
count := statsTbl.ColumnGreaterRowCount(mock.NewContext(), types.NewDatum(1), tableInfo.Columns[0].ID)
count := cardinality.ColumnGreaterRowCount(mock.NewContext(), statsTbl, types.NewDatum(1), tableInfo.Columns[0].ID)
require.Equal(t, 0.0, count)
}

Expand Down
1 change: 1 addition & 0 deletions statistics/handle/handletest/statstest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ go_test(
deps = [
"//config",
"//parser/model",
"//planner/cardinality",
"//statistics/handle/internal",
"//testkit",
"//testkit/testsetup",
Expand Down
5 changes: 3 additions & 2 deletions statistics/handle/handletest/statstest/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/planner/cardinality"
"github.com/pingcap/tidb/statistics/handle/internal"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -335,9 +336,9 @@ func TestLoadStats(t *testing.T) {
require.Nil(t, cms)

// Column stats are loaded after they are needed.
_, err = stat.ColumnEqualRowCount(testKit.Session(), types.NewIntDatum(1), colAID)
_, err = cardinality.ColumnEqualRowCount(testKit.Session(), stat, types.NewIntDatum(1), colAID)
require.NoError(t, err)
_, err = stat.ColumnEqualRowCount(testKit.Session(), types.NewIntDatum(1), colCID)
_, err = cardinality.ColumnEqualRowCount(testKit.Session(), stat, types.NewIntDatum(1), colCID)
require.NoError(t, err)
require.NoError(t, h.LoadNeededHistograms())
stat = h.GetTableStats(tableInfo)
Expand Down
1 change: 1 addition & 0 deletions statistics/handle/updatetest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_test(
deps = [
"//parser/model",
"//parser/mysql",
"//planner/cardinality",
"//sessionctx/variable",
"//statistics",
"//statistics/handle",
Expand Down
3 changes: 2 additions & 1 deletion statistics/handle/updatetest/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/cardinality"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/statistics/handle"
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestSingleSessionInsert(t *testing.T) {
require.Equal(t, int64(rowCount1*2), stats1.RealtimeCount)

// Test IncreaseFactor.
count, err := stats1.ColumnEqualRowCount(testKit.Session(), types.NewIntDatum(1), tableInfo1.Columns[0].ID)
count, err := cardinality.ColumnEqualRowCount(testKit.Session(), stats1, types.NewIntDatum(1), tableInfo1.Columns[0].ID)
require.NoError(t, err)
require.Equal(t, float64(rowCount1*2), count)

Expand Down
30 changes: 0 additions & 30 deletions statistics/statistics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,36 +190,6 @@ func TestMergeHistogram(t *testing.T) {
}
}

func TestPseudoTable(t *testing.T) {
ti := &model.TableInfo{}
colInfo := &model.ColumnInfo{
ID: 1,
FieldType: *types.NewFieldType(mysql.TypeLonglong),
State: model.StatePublic,
}
ti.Columns = append(ti.Columns, colInfo)
tbl := PseudoTable(ti)
require.Len(t, tbl.Columns, 1)
require.Greater(t, tbl.RealtimeCount, int64(0))
sctx := mock.NewContext()
count := tbl.ColumnLessRowCount(sctx, types.NewIntDatum(100), colInfo.ID)
require.Equal(t, 3333, int(count))
count, err := tbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1000), colInfo.ID)
require.NoError(t, err)
require.Equal(t, 10, int(count))
count, _ = tbl.ColumnBetweenRowCount(sctx, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID)
require.Equal(t, 250, int(count))
ti.Columns = append(ti.Columns, &model.ColumnInfo{
ID: 2,
FieldType: *types.NewFieldType(mysql.TypeLonglong),
Hidden: true,
State: model.StatePublic,
})
tbl = PseudoTable(ti)
// We added a hidden column. The pseudo table still only have one column.
require.Equal(t, len(tbl.Columns), 1)
}

func buildCMSketch(values []types.Datum) *CMSketch {
cms := NewCMSketch(8, 2048)
for _, val := range values {
Expand Down
63 changes: 0 additions & 63 deletions statistics/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"strings"
"sync"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/model"
Expand All @@ -30,7 +29,6 @@ import (
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/ranger"
"go.uber.org/atomic"
)
Expand Down Expand Up @@ -65,12 +63,6 @@ var (

// GetRowCountByColumnRanges is a function type to get row count by column ranges.
GetRowCountByColumnRanges func(sctx sessionctx.Context, coll *HistColl, colID int64, colRanges []*ranger.Range) (result float64, err error)

// EqualRowCountOnColumn is a function type to get the row count by equal condition on column.
EqualRowCountOnColumn func(sctx sessionctx.Context, c *Column, val types.Datum, encodedVal []byte, realtimeRowCount int64) (result float64, err error)

// BetweenRowCountOnColumn is a function type to get the row count by between condition on column.
BetweenRowCountOnColumn func(sctx sessionctx.Context, c *Column, l, r types.Datum, lowEncoded, highEncoded []byte) float64
)

// Table represents statistics for a table.
Expand Down Expand Up @@ -488,61 +480,6 @@ func (t *Table) IsOutdated() bool {
return false
}

// ColumnGreaterRowCount estimates the row count where the column greater than value.
func (t *Table) ColumnGreaterRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoLessRate
}
return c.GreaterRowCount(value) * c.GetIncreaseFactor(t.RealtimeCount)
}

// ColumnLessRowCount estimates the row count where the column less than value. Note that null values are not counted.
func (t *Table) ColumnLessRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoLessRate
}
return c.LessRowCount(sctx, value) * c.GetIncreaseFactor(t.RealtimeCount)
}

// ColumnBetweenRowCount estimates the row count where column greater or equal to a and less than b.
func (t *Table) ColumnBetweenRowCount(sctx sessionctx.Context, a, b types.Datum, colID int64) (float64, error) {
sc := sctx.GetSessionVars().StmtCtx
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoBetweenRate, nil
}
aEncoded, err := codec.EncodeKey(sc, nil, a)
if err != nil {
return 0, err
}
bEncoded, err := codec.EncodeKey(sc, nil, b)
if err != nil {
return 0, err
}
count := BetweenRowCountOnColumn(sctx, c, a, b, aEncoded, bEncoded)
if a.IsNull() {
count += float64(c.NullCount)
}
return count * c.GetIncreaseFactor(t.RealtimeCount), nil
}

// ColumnEqualRowCount estimates the row count where the column equals to value.
func (t *Table) ColumnEqualRowCount(sctx sessionctx.Context, value types.Datum, colID int64) (float64, error) {
c, ok := t.Columns[colID]
if !ok || c.IsInvalid(sctx, t.Pseudo) {
return float64(t.RealtimeCount) / pseudoEqualRate, nil
}
encodedVal, err := codec.EncodeKey(sctx.GetSessionVars().StmtCtx, nil, value)
if err != nil {
return 0, err
}
result, err := EqualRowCountOnColumn(sctx, c, value, encodedVal, t.ModifyCount)
result *= c.GetIncreaseFactor(t.RealtimeCount)
return result, errors.Trace(err)
}

// PseudoAvgCountPerValue gets a pseudo average count if histogram not exists.
func (t *Table) PseudoAvgCountPerValue() float64 {
return float64(t.RealtimeCount) / pseudoEqualRate
Expand Down

0 comments on commit 2163271

Please sign in to comment.