planner: fix column evaluator can not detect input's column-ref and thus swapping and destroying later column ref projection logic (#53794) (#56199)

close #53713
ti-chi-bot authored Sep 26, 2024
1 parent e329890 commit 8e57797
Showing 6 changed files with 229 additions and 2 deletions.
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
@@ -97,6 +97,7 @@ go_library(
"//pkg/util/encrypt",
"//pkg/util/generatedexpr",
"//pkg/util/hack",
"//pkg/util/intest",
"//pkg/util/intset",
"//pkg/util/logutil",
"//pkg/util/mathutil",
99 changes: 98 additions & 1 deletion pkg/expression/evaluator.go
@@ -15,20 +15,30 @@
package expression

import (
"sync/atomic"

"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/disjointset"
"github.com/pingcap/tidb/pkg/util/intest"
)

type columnEvaluator struct {
inputIdxToOutputIdxes map[int][]int
// mergedInputIdxToOutputIdxes can only be determined at runtime, once the input chunk is seen.
mergedInputIdxToOutputIdxes atomic.Pointer[map[int][]int]
}

// run evaluates "Column" expressions.
// NOTE: It should be called after all the other expressions are evaluated
//
// since it will change the content of the input Chunk.
func (e *columnEvaluator) run(ctx sessionctx.Context, input, output *chunk.Chunk) error {
for inputIdx, outputIdxes := range e.inputIdxToOutputIdxes {
// mergedInputIdxToOutputIdxes can only be determined at runtime, once we have seen the input chunk structure.
if e.mergedInputIdxToOutputIdxes.Load() == nil {
e.mergeInputIdxToOutputIdxes(input, e.inputIdxToOutputIdxes)
}
for inputIdx, outputIdxes := range *e.mergedInputIdxToOutputIdxes.Load() {
if err := output.SwapColumn(outputIdxes[0], input, inputIdx); err != nil {
return err
}
@@ -39,6 +49,93 @@ func (e *columnEvaluator) run(ctx sessionctx.Context, input, output *chunk.Chunk
return nil
}

// mergeInputIdxToOutputIdxes merges separate inputIdxToOutputIdxes entries when column references
// are detected within the input chunk. This process ensures consistent handling of columns derived
// from the same original source.
//
// Consider the following scenario:
//
// Initial scan operation produces a column 'a':
//
// scan: a (addr: ???)
//
// This column 'a' is used in the first projection (proj1) to create two columns a1 and a2, both referencing 'a':
//
// proj1
// / \
// / \
// / \
// a1 (addr: 0xe) a2 (addr: 0xe)
// / \
// / \
// / \
// proj2 proj2
// / \ / \
// / \ / \
// a3 a4 a5 a6
//
// (addr: 0xe) (addr: 0xe) (addr: 0xe) (addr: 0xe)
//
// Here, a1 and a2 share the same address (0xe), indicating they reference the same data from the original 'a'.
//
// When moving to the second projection (proj2), the system tries to project these columns further:
// - The first set (left side) consists of a3 and a4, derived from a1, both retaining the address (0xe).
// - The second set (right side) consists of a5 and a6, derived from a2, also starting with address (0xe).
//
// When proj1 is complete, the output chunk contains two columns [a1, a2], both derived from the single column 'a' from the scan.
// Since both a1 and a2 are column references with the same address (0xe), they are treated as referencing the same data.
//
// In proj2, two separate <inputIdx, []outputIdxes> items are created:
// - <0, [0,1]>: This means the 0th input column (a1) is projected twice, into the 0th and 1st columns of the output chunk.
// - <1, [2,3]>: This means the 1st input column (a2) is projected twice, into the 2nd and 3rd columns of the output chunk.
//
// Due to the column swapping logic in each projection, after applying the <0, [0,1]> projection,
// the addresses for a1 and a2 may become swapped or invalid:
//
// proj1: a1 (addr: invalid) a2 (addr: invalid)
//
// This can lead to issues in proj2, where further operations on these columns may be unsafe:
//
// proj2: a3 (addr: 0xe) a4 (addr: 0xe) a5 (addr: ???) a6 (addr: ???)
//
// Therefore, it's crucial to identify and merge the original column references early, ensuring
// the final inputIdxToOutputIdxes mapping accurately reflects the shared origins of the data.
// For instance, <0, [0,1,2,3]> indicates that the 0th input column (original 'a') is referenced
// by all four output columns in the final output.
//
// mergeInputIdxToOutputIdxes merges inputIdxToOutputIdxes based on detected column references.
// This ensures that columns with the same reference are correctly handled in the output chunk.
func (e *columnEvaluator) mergeInputIdxToOutputIdxes(input *chunk.Chunk, inputIdxToOutputIdxes map[int][]int) {
originalDJSet := disjointset.NewSet[int](4)
flag := make([]bool, input.NumCols())
// Detect self column-references inside the input chunk by comparing column addresses
for i := 0; i < input.NumCols(); i++ {
if flag[i] {
continue
}
for j := i + 1; j < input.NumCols(); j++ {
if input.Column(i) == input.Column(j) {
flag[j] = true
originalDJSet.Union(i, j)
}
}
}
// Merge inputIdxToOutputIdxes based on the detected column references.
newInputIdxToOutputIdxes := make(map[int][]int, len(inputIdxToOutputIdxes))
for inputIdx := range inputIdxToOutputIdxes {
// The root idx is an internal offset inside the disjoint set, not the real column index, so map it back via FindVal.
originalRootIdx := originalDJSet.FindRoot(inputIdx)
originalVal, ok := originalDJSet.FindVal(originalRootIdx)
intest.Assert(ok)
mergedOutputIdxes := newInputIdxToOutputIdxes[originalVal]
mergedOutputIdxes = append(mergedOutputIdxes, inputIdxToOutputIdxes[inputIdx]...)
newInputIdxToOutputIdxes[originalVal] = mergedOutputIdxes
}
// Store the merged inputIdxToOutputIdxes atomically.
// If the CAS fails, another worker has already done this job in the meantime.
e.mergedInputIdxToOutputIdxes.CompareAndSwap(nil, &newInputIdxToOutputIdxes)
}

type defaultEvaluator struct {
outputIdxes []int
exprs []Expression
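A minimal standalone sketch (not part of this commit) of the merge idea described in mergeInputIdxToOutputIdxes above: aliased input columns are detected by pointer equality and their output index lists are grouped under a single input index. The column type and helper names below are illustrative only, not TiDB APIs.

// sketch: merge output-index lists for input columns that share the same underlying column.
package main

import "fmt"

// column stands in for *chunk.Column; aliased columns share the same pointer.
type column struct{ data []int64 }

// mergeIdxes groups output indexes by the first input index that holds each distinct column pointer.
func mergeIdxes(cols []*column, inputIdxToOutputIdxes map[int][]int) map[int][]int {
	firstIdx := make(map[*column]int, len(cols)) // column pointer -> first input index holding it
	root := make([]int, len(cols))               // input index -> representative input index
	for i, c := range cols {
		if r, ok := firstIdx[c]; ok {
			root[i] = r // column i aliases an earlier column r
		} else {
			firstIdx[c] = i
			root[i] = i
		}
	}
	merged := make(map[int][]int, len(inputIdxToOutputIdxes))
	for inputIdx, outputIdxes := range inputIdxToOutputIdxes {
		r := root[inputIdx]
		merged[r] = append(merged[r], outputIdxes...)
	}
	return merged
}

func main() {
	shared := &column{data: []int64{99}}
	cols := []*column{shared, shared} // input columns 0 and 1 reference the same data

	idxes := map[int][]int{0: {0, 1}, 1: {2, 3}}
	fmt.Println(mergeIdxes(cols, idxes)) // map[0:[0 1 2 3]] (slice order may vary with map iteration)
}

The real code computes this once per evaluator and caches the result in an atomic.Pointer via CompareAndSwap, so concurrent projection workers all agree on one merged mapping.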
40 changes: 40 additions & 0 deletions pkg/expression/evaluator_test.go
@@ -15,6 +15,7 @@
package expression

import (
"slices"
"sync/atomic"
"testing"
"time"
@@ -593,3 +594,42 @@ func TestMod(t *testing.T) {
require.NoError(t, err)
require.Equal(t, types.NewDatum(1.5), r)
}

func TestMergeInputIdxToOutputIdxes(t *testing.T) {
ctx := createContext(t)
inputIdxToOutputIdxes := make(map[int][]int)
// input column 0 should be referenced by output columns 0 and 1.
inputIdxToOutputIdxes[0] = []int{0, 1}
// input column 1 should be referenced by output columns 2 and 3.
inputIdxToOutputIdxes[1] = []int{2, 3}
columnEval := columnEvaluator{inputIdxToOutputIdxes: inputIdxToOutputIdxes}

input := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, 2)
input.AppendInt64(0, 99)
// make the input chunk's column 1 a reference to column 0, so both share the same underlying column.
input.MakeRef(0, 1)

// chunk: col1 <---(ref) col2
// ____________/ \___________/ \___
// proj: col1 col2 col3 col4
//
// In the original code, after handling inputIdxToOutputIdxes[0], the original col2 becomes a nil
// pointer, which makes the consecutive ref projections for col3 and col4 invalid.
//
// After the fix, the merged mapping should be inputIdxToOutputIdxes[0]: {0, 1, 2, 3}.

output := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong),
types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, 2)

err := columnEval.run(ctx, input, output)
require.NoError(t, err)
// all four output columns are column references pointing to the same underlying column.
require.Equal(t, output.Column(0), output.Column(1))
require.Equal(t, output.Column(1), output.Column(2))
require.Equal(t, output.Column(2), output.Column(3))
require.Equal(t, output.GetRow(0).GetInt64(0), int64(99))

require.Equal(t, len(*columnEval.mergedInputIdxToOutputIdxes.Load()), 1)
slices.Sort((*columnEval.mergedInputIdxToOutputIdxes.Load())[0])
require.Equal(t, (*columnEval.mergedInputIdxToOutputIdxes.Load())[0], []int{0, 1, 2, 3})
}
5 changes: 4 additions & 1 deletion pkg/util/disjointset/BUILD.bazel
@@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "disjointset",
srcs = ["int_set.go"],
srcs = [
"int_set.go",
"set.go",
],
importpath = "github.com/pingcap/tidb/pkg/util/disjointset",
visibility = ["//visibility:public"],
)
1 change: 1 addition & 0 deletions pkg/util/disjointset/int_set.go
@@ -38,6 +38,7 @@ func (m *IntSet) FindRoot(a int) int {
if a == m.parent[a] {
return a
}
// Path compression, which brings the amortized time complexity down to the inverse Ackermann function.
m.parent[a] = m.FindRoot(m.parent[a])
return m.parent[a]
}
85 changes: 85 additions & 0 deletions pkg/util/disjointset/set.go
@@ -0,0 +1,85 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package disjointset

// Set is the universal implementation of a disjoint set.
// It's designed for sparse cases or non-integer types.
// If you are dealing with continuous integers, you should use SimpleIntSet to avoid the cost of a hash map.
// We map each original value to an integer index and then apply the core disjoint-set algorithm on the indexes.
// Time complexity: the union operation has inverse-Ackermann amortized time complexity, which is very close to O(1).
type Set[T comparable] struct {
parent []int
val2Idx map[T]int
idx2Val map[int]T
tailIdx int
}

// NewSet creates a disjoint set.
func NewSet[T comparable](size int) *Set[T] {
return &Set[T]{
parent: make([]int, 0, size),
val2Idx: make(map[T]int, size),
idx2Val: make(map[int]T, size),
tailIdx: 0,
}
}

// findRootOriginalVal returns the root index of the group containing a,
// inserting a as a new singleton group if it is not in the set yet.
func (s *Set[T]) findRootOriginalVal(a T) int {
idx, ok := s.val2Idx[a]
if !ok {
s.parent = append(s.parent, s.tailIdx)
s.val2Idx[a] = s.tailIdx
s.tailIdx++
s.idx2Val[s.tailIdx-1] = a
return s.tailIdx - 1
}
return s.findRootInternal(idx)
}

// findRootInternal is the internal find implementation that works on indexes.
// The index passed in must already exist in the set.
func (s *Set[T]) findRootInternal(a int) int {
if s.parent[a] != a {
// Path compression, which brings the amortized time complexity down to the inverse Ackermann function.
s.parent[a] = s.findRootInternal(s.parent[a])
}
return s.parent[a]
}

// InSameGroup checks whether a and b are in the same group.
func (s *Set[T]) InSameGroup(a, b T) bool {
return s.findRootOriginalVal(a) == s.findRootOriginalVal(b)
}

// Union joins two sets in the disjoint set.
func (s *Set[T]) Union(a, b T) {
rootA := s.findRootOriginalVal(a)
rootB := s.findRootOriginalVal(b)
// Attach rootB under rootA, so rootA remains the root of the merged set.
if rootA != rootB {
s.parent[rootB] = rootA
}
}

// FindRoot finds the root of the set that contains a.
func (s *Set[T]) FindRoot(a T) int {
// if a is not in the set, assign a new index to it.
return s.findRootOriginalVal(a)
}

// FindVal returns the original value stored at the root of the group containing the given index.
func (s *Set[T]) FindVal(idx int) (T, bool) {
v, ok := s.idx2Val[s.findRootInternal(idx)]
return v, ok
}
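
A short usage sketch (not part of the diff) for the new generic disjointset.Set shown above, assuming the import path from the BUILD file; values are added lazily the first time they are seen.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/util/disjointset"
)

func main() {
	s := disjointset.NewSet[string](4)
	s.Union("a1", "a2") // a1 and a2 come from the same source column
	s.Union("a2", "a3")

	fmt.Println(s.InSameGroup("a1", "a3")) // true
	fmt.Println(s.InSameGroup("a1", "b1")) // false; b1 becomes its own singleton group on first use

	rootIdx := s.FindRoot("a3")   // internal index of the group's root
	val, ok := s.FindVal(rootIdx) // original value stored at that root
	fmt.Println(val, ok)          // a1 true: the first value seen by Union stays the root
}

columnEvaluator instantiates this as Set[int] over input column indexes, then uses FindRoot plus FindVal to map each group of aliased columns back to its original input index.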
