From 8832684721934a4d5bca73ced7fd2e2f1e7eb10d Mon Sep 17 00:00:00 2001 From: Shenghui Wu <793703860@qq.com> Date: Wed, 13 Nov 2024 14:46:32 +0800 Subject: [PATCH] executor: support left outer semi join for hash join v2 (#57053) ref pingcap/tidb#53127 --- pkg/executor/join/BUILD.bazel | 3 + pkg/executor/join/base_join_probe.go | 5 + pkg/executor/join/inner_join_probe_test.go | 7 + pkg/executor/join/inner_join_spill_test.go | 3 + .../join/left_outer_semi_join_probe.go | 283 ++++++++++ .../join/left_outer_semi_join_probe_test.go | 487 ++++++++++++++++++ pkg/expression/expression.go | 3 +- pkg/planner/core/physical_plans.go | 2 +- pkg/util/queue/BUILD.bazel | 17 + pkg/util/queue/queue.go | 85 +++ pkg/util/queue/queue_test.go | 87 ++++ 11 files changed, 980 insertions(+), 2 deletions(-) create mode 100644 pkg/executor/join/left_outer_semi_join_probe.go create mode 100644 pkg/executor/join/left_outer_semi_join_probe_test.go create mode 100644 pkg/util/queue/BUILD.bazel create mode 100644 pkg/util/queue/queue.go create mode 100644 pkg/util/queue/queue_test.go diff --git a/pkg/executor/join/BUILD.bazel b/pkg/executor/join/BUILD.bazel index d93d83563288e..73e52ee6e1437 100644 --- a/pkg/executor/join/BUILD.bazel +++ b/pkg/executor/join/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "join_row_table.go", "join_table_meta.go", "joiner.go", + "left_outer_semi_join_probe.go", "merge_join.go", "outer_join_probe.go", "row_table_builder.go", @@ -58,6 +59,7 @@ go_library( "//pkg/util/logutil", "//pkg/util/memory", "//pkg/util/mvmap", + "//pkg/util/queue", "//pkg/util/ranger", "//pkg/util/serialization", "//pkg/util/sqlkiller", @@ -86,6 +88,7 @@ go_test( "join_table_meta_test.go", "joiner_test.go", "left_outer_join_probe_test.go", + "left_outer_semi_join_probe_test.go", "merge_join_test.go", "outer_join_spill_test.go", "right_outer_join_probe_test.go", diff --git a/pkg/executor/join/base_join_probe.go b/pkg/executor/join/base_join_probe.go index 71eb53e060813..618374dfa906f 100644 --- 
a/pkg/executor/join/base_join_probe.go +++ b/pkg/executor/join/base_join_probe.go @@ -747,6 +747,11 @@ func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType logicalop.JoinType, return newOuterJoinProbe(base, !rightAsBuildSide, rightAsBuildSide) case logicalop.RightOuterJoin: return newOuterJoinProbe(base, rightAsBuildSide, rightAsBuildSide) + case logicalop.LeftOuterSemiJoin: + if rightAsBuildSide { + return newLeftOuterSemiJoinProbe(base) + } + fallthrough default: panic("unsupported join type") } diff --git a/pkg/executor/join/inner_join_probe_test.go b/pkg/executor/join/inner_join_probe_test.go index 88c102b2a9cf2..3362dbfa5b1d2 100644 --- a/pkg/executor/join/inner_join_probe_test.go +++ b/pkg/executor/join/inner_join_probe_test.go @@ -301,6 +301,9 @@ func testJoinProbe(t *testing.T, withSel bool, leftKeyIndex []int, rightKeyIndex resultTypes[len(resultTypes)-1].DelFlag(mysql.NotNullFlag) } } + if joinType == logicalop.LeftOuterSemiJoin { + resultTypes = append(resultTypes, types.NewFieldType(mysql.TypeTiny)) + } meta := newTableMeta(buildKeyIndex, buildTypes, buildKeyTypes, probeKeyTypes, buildUsedByOtherCondition, buildUsed, needUsedFlag) hashJoinCtx := &HashJoinCtxV2{ @@ -458,6 +461,10 @@ func testJoinProbe(t *testing.T, withSel bool, leftKeyIndex []int, rightKeyIndex expectedChunks := genRightOuterJoinResult(t, hashJoinCtx.SessCtx, rightFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes, rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes) checkChunksEqual(t, expectedChunks, resultChunks, resultTypes) + case logicalop.LeftOuterSemiJoin: + expectedChunks := genLeftOuterSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes, + rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes) + checkChunksEqual(t, expectedChunks, resultChunks, resultTypes) default: require.NoError(t, errors.New("not supported join 
type")) } diff --git a/pkg/executor/join/inner_join_spill_test.go b/pkg/executor/join/inner_join_spill_test.go index 575c29115f9f3..3451376e4e647 100644 --- a/pkg/executor/join/inner_join_spill_test.go +++ b/pkg/executor/join/inner_join_spill_test.go @@ -194,6 +194,9 @@ func getReturnTypes(joinType logicalop.JoinType, param spillTestParam) []*types. resultTypes[len(resultTypes)-1].DelFlag(mysql.NotNullFlag) } } + if joinType == logicalop.LeftOuterSemiJoin { + resultTypes = append(resultTypes, types.NewFieldType(mysql.TypeTiny)) + } return resultTypes } diff --git a/pkg/executor/join/left_outer_semi_join_probe.go b/pkg/executor/join/left_outer_semi_join_probe.go new file mode 100644 index 0000000000000..552bfc65f4abf --- /dev/null +++ b/pkg/executor/join/left_outer_semi_join_probe.go @@ -0,0 +1,283 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package join + +import ( + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/queue" + "github.com/pingcap/tidb/pkg/util/sqlkiller" +) + +// leftOuterSemiJoinProbe is the probe NewJoinProbe returns for a left outer semi join whose right side is the build side. +// For every probe row it outputs the used left columns followed by one extra column holding 1, NULL or 0. +type leftOuterSemiJoinProbe struct { + baseJoinProbe + + // isMatchedRows marks whether the left side row is matched + isMatchedRows []bool + // isNullRows marks whether the other condition evaluated to NULL for some joined row of the left side row + isNullRows []bool + + // isNulls is the NULL-flag buffer passed to and returned by expression.VecEvalBool for the other condition + isNulls []bool + + // unFinishedProbeRowIdxQueue holds the indexes of probe rows that still have candidate build rows to check; + // it is only allocated and used when the join has an other condition + unFinishedProbeRowIdxQueue *queue.Queue[int] +} + +// Compile-time check that *leftOuterSemiJoinProbe can be used as a ProbeV2. +var _ ProbeV2 = &leftOuterSemiJoinProbe{} + +// newLeftOuterSemiJoinProbe wraps base in a leftOuterSemiJoinProbe. +// The unfinished-row queue is created only when the join has an other condition. +func newLeftOuterSemiJoinProbe(base baseJoinProbe) *leftOuterSemiJoinProbe { + probe := &leftOuterSemiJoinProbe{ + baseJoinProbe: base, + } + if base.ctx.hasOtherCondition() { + probe.unFinishedProbeRowIdxQueue = queue.NewQueue[int](32) + } + return probe +} + +// SetChunkForProbe passes chunk to the embedded baseJoinProbe and, if that succeeds, resets the per-chunk probe state. +func (j *leftOuterSemiJoinProbe) SetChunkForProbe(chunk *chunk.Chunk) (err error) { + err = j.baseJoinProbe.SetChunkForProbe(chunk) + if err != nil { + return err + } + j.resetProbeState() + return nil +} + +// SetRestoredChunkForProbe is the SetRestoredChunkForProbe counterpart of SetChunkForProbe: +// it delegates to the embedded baseJoinProbe and then resets the per-chunk probe state. +func (j *leftOuterSemiJoinProbe) SetRestoredChunkForProbe(chunk *chunk.Chunk) (err error) { + err = j.baseJoinProbe.SetRestoredChunkForProbe(chunk) + if err != nil { + return err + } + j.resetProbeState() + return nil +} + +// resetProbeState refills isMatchedRows and isNullRows with chunkRows false entries, reusing their backing arrays. +// When the join has an other condition it also clears the queue and pushes every row index whose matchedRowsHeaders entry is non-zero. +func (j *leftOuterSemiJoinProbe) resetProbeState() { + j.isMatchedRows = j.isMatchedRows[:0] + for i := 0; i < j.chunkRows; i++ { + j.isMatchedRows = append(j.isMatchedRows, false) + } + j.isNullRows = j.isNullRows[:0] + for i := 0; i < j.chunkRows; i++ { + j.isNullRows = append(j.isNullRows, false) + } + if j.ctx.hasOtherCondition() { + j.unFinishedProbeRowIdxQueue.Clear() + for i := 0; i < j.chunkRows; i++ { + if j.matchedRowsHeaders[i] != 0 { + j.unFinishedProbeRowIdxQueue.Push(i) + } + } + } +} + +// NeedScanRowTable always returns false for this probe. +func (*leftOuterSemiJoinProbe) NeedScanRowTable() bool { + return false +} + +// IsScanRowTableDone panics when called; NeedScanRowTable returns false, so presumably callers never ask for a row table scan. +func (*leftOuterSemiJoinProbe) 
IsScanRowTableDone() bool { + panic("should not reach here") +} + +// InitForScanRowTable panics when called, for the same reason as IsScanRowTableDone. +func (*leftOuterSemiJoinProbe) InitForScanRowTable() { + panic("should not reach here") +} + +// ScanRowTable does no work and returns joinResult unchanged. +func (*leftOuterSemiJoinProbe) ScanRowTable(joinResult *hashjoinWorkerResult, _ *sqlkiller.SQLKiller) *hashjoinWorkerResult { + return joinResult +} + +// Probe runs one probe step for the current chunk and writes output rows into joinResult.chk. +// It picks the code path by whether the join has an other condition. +// On failure it stores the error in joinResult.err and returns false; otherwise it returns true. The same joinResult is returned in both cases. +func (j *leftOuterSemiJoinProbe) Probe(joinResult *hashjoinWorkerResult, sqlKiller *sqlkiller.SQLKiller) (ok bool, _ *hashjoinWorkerResult) { + joinedChk, remainCap, err := j.prepareForProbe(joinResult.chk) + if err != nil { + joinResult.err = err + return false, joinResult + } + + if j.ctx.hasOtherCondition() { + err = j.probeWithOtherCondition(joinResult.chk, joinedChk, remainCap, sqlKiller) + } else { + err = j.probeWithoutOtherCondition(joinResult.chk, joinedChk, remainCap, sqlKiller) + } + if err != nil { + joinResult.err = err + return false, joinResult + } + return true, joinResult +} + +// probeWithOtherCondition works in two phases. While the queue is non-empty it only evaluates pending rows through produceResult. +// Once the queue is empty it emits up to remainCap rows of the current chunk into chk through buildResult. +func (j *leftOuterSemiJoinProbe) probeWithOtherCondition(chk, joinedChk *chunk.Chunk, remainCap int, sqlKiller *sqlkiller.SQLKiller) (err error) { + if !j.unFinishedProbeRowIdxQueue.IsEmpty() { + err = j.produceResult(joinedChk, sqlKiller) + if err != nil { + return err + } + // concatenateProbeAndBuildRows moved currentProbeRow to the rows it popped; output starts again from row 0 + j.currentProbeRow = 0 + } + + if j.unFinishedProbeRowIdxQueue.IsEmpty() { + startProbeRow := j.currentProbeRow + j.currentProbeRow = min(startProbeRow+remainCap, j.chunkRows) + j.buildResult(chk, startProbeRow) + } + return +} + +// produceResult joins pending probe rows with candidate build rows into joinedChk and evaluates OtherCondition over joinedChk. +// For each joined row it sets isMatchedRows of the owning probe row when the row is selected and isNullRows when the result is NULL. +func (j *leftOuterSemiJoinProbe) produceResult(joinedChk *chunk.Chunk, sqlKiller *sqlkiller.SQLKiller) (err error) { + err = j.concatenateProbeAndBuildRows(joinedChk, sqlKiller) + if err != nil { + return err + } + + if joinedChk.NumRows() > 0 { + j.selected, j.isNulls, err = expression.VecEvalBool(j.ctx.SessCtx.GetExprCtx().GetEvalCtx(), j.ctx.SessCtx.GetSessionVars().EnableVectorizedExpression, j.ctx.OtherCondition, joinedChk, j.selected, j.isNulls) + if err != nil { + return err + } + + for i := 0; i < joinedChk.NumRows(); i++ { + if j.selected[i] { + 
j.isMatchedRows[j.rowIndexInfos[i].probeRowIndex] = true + } + if j.isNulls[i] { + j.isNullRows[j.rowIndexInfos[i].probeRowIndex] = true + } + } + } + return nil +} + +// probeWithoutOtherCondition decides, for at most remainCap probe rows starting at currentProbeRow, whether a build row with the same key exists. +// It then writes the rows handled in this call to joinedChk through buildResult. The first chunk parameter is ignored. +func (j *leftOuterSemiJoinProbe) probeWithoutOtherCondition(_, joinedChk *chunk.Chunk, remainCap int, sqlKiller *sqlkiller.SQLKiller) (err error) { + meta := j.ctx.hashTableMeta + startProbeRow := j.currentProbeRow + tagHelper := j.ctx.hashTableContext.tagHelper + + for remainCap > 0 && j.currentProbeRow < j.chunkRows { + if j.matchedRowsHeaders[j.currentProbeRow] != 0 { + candidateRow := tagHelper.toUnsafePointer(j.matchedRowsHeaders[j.currentProbeRow]) + if !isKeyMatched(meta.keyMode, j.serializedKeys[j.currentProbeRow], candidateRow, meta) { + // key differs: move to the next candidate of the same probe row without consuming remainCap + j.probeCollision++ + j.matchedRowsHeaders[j.currentProbeRow] = getNextRowAddress(candidateRow, tagHelper, j.matchedRowsHashValue[j.currentProbeRow]) + continue + } + j.isMatchedRows[j.currentProbeRow] = true + } + // the row is decided (matched, or no candidate left), so its remaining candidates are dropped + j.matchedRowsHeaders[j.currentProbeRow] = 0 + remainCap-- + j.currentProbeRow++ + } + + err = checkSQLKiller(sqlKiller, "killedDuringProbe") + + if err != nil { + return err + } + + j.buildResult(joinedChk, startProbeRow) + return nil +} + +// buildResult appends the probe rows in [startProbeRow, currentProbeRow) to chk. +// Each output row has the lUsed columns of currentChunk plus a last column: 1 if the row matched, NULL if it did not match but isNullRows is set, 0 otherwise. +// Rows listed in spilledIdx are left out. +func (j *leftOuterSemiJoinProbe) buildResult(chk *chunk.Chunk, startProbeRow int) { + var selected []bool + if startProbeRow == 0 && j.currentProbeRow == j.chunkRows && j.currentChunk.Sel() == nil && chk.NumRows() == 0 && len(j.spilledIdx) == 0 { + // TODO: Can do a shallow copy by directly copying the Column pointers + for index, colIndex := range j.lUsed { + j.currentChunk.Column(colIndex).CopyConstruct(chk.Column(index)) + } + } else { + selected = make([]bool, j.chunkRows) + for i := startProbeRow; i < j.currentProbeRow; i++ { + selected[i] = true + } + for _, spilledIdx := range j.spilledIdx { + selected[spilledIdx] = false // ignore spilled rows + } + for index, colIndex := range j.lUsed { + dstCol := chk.Column(index) + srcCol := j.currentChunk.Column(colIndex) + 
chunk.CopySelectedRowsWithRowIDFunc(dstCol, srcCol, selected, 0, len(selected), func(i int) int { + return j.usedRows[i] + }) + } + } + + // selected stays nil on the whole-chunk path above, where no row is filtered out + for i := startProbeRow; i < j.currentProbeRow; i++ { + if selected != nil && !selected[i] { + continue + } + if j.isMatchedRows[i] { + chk.AppendInt64(len(j.lUsed), 1) + } else if j.isNullRows[i] { + chk.AppendNull(len(j.lUsed)) + } else { + chk.AppendInt64(len(j.lUsed), 0) + } + } + chk.SetNumVirtualRows(chk.NumRows()) +} + +// maxMatchedRowNum is the bound that matchMultiBuildRows compares matchedRowsForCurrentProbeRow against. +var maxMatchedRowNum = 4 + +// matchMultiBuildRows walks the candidate build rows of currentProbeRow and appends the key-matched ones to the cached build rows. +// It stops when no candidate is left, when joinedChkRemainCap reaches zero, or when matchedRowsForCurrentProbeRow reaches maxMatchedRowNum. +// A candidate whose key differs only increments probeCollision. +func (j *leftOuterSemiJoinProbe) matchMultiBuildRows(joinedChk *chunk.Chunk, joinedChkRemainCap *int) { + tagHelper := j.ctx.hashTableContext.tagHelper + meta := j.ctx.hashTableMeta + for j.matchedRowsHeaders[j.currentProbeRow] != 0 && *joinedChkRemainCap > 0 && j.matchedRowsForCurrentProbeRow < maxMatchedRowNum { + candidateRow := tagHelper.toUnsafePointer(j.matchedRowsHeaders[j.currentProbeRow]) + if isKeyMatched(meta.keyMode, j.serializedKeys[j.currentProbeRow], candidateRow, meta) { + j.appendBuildRowToCachedBuildRowsV1(j.currentProbeRow, candidateRow, joinedChk, 0, true) + j.matchedRowsForCurrentProbeRow++ + *joinedChkRemainCap-- + } else { + j.probeCollision++ + } + j.matchedRowsHeaders[j.currentProbeRow] = getNextRowAddress(candidateRow, tagHelper, j.matchedRowsHashValue[j.currentProbeRow]) + } + + j.finishLookupCurrentProbeRow() +} + +// concatenateProbeAndBuildRows pops probe row indexes from the queue while joinedChk has spare capacity. +// Rows already marked matched are dropped; the others are joined with their candidates by matchMultiBuildRows. +// A row that still has candidates afterwards is pushed back so that a later call continues with it. +func (j *leftOuterSemiJoinProbe) concatenateProbeAndBuildRows(joinedChk *chunk.Chunk, sqlKiller *sqlkiller.SQLKiller) error { + joinedChkRemainCap := joinedChk.Capacity() - joinedChk.NumRows() + + for joinedChkRemainCap > 0 && !j.unFinishedProbeRowIdxQueue.IsEmpty() { + probeRowIdx := j.unFinishedProbeRowIdxQueue.Pop() + if j.isMatchedRows[probeRowIdx] { + continue + } + j.currentProbeRow = probeRowIdx + j.matchMultiBuildRows(joinedChk, &joinedChkRemainCap) + if j.matchedRowsHeaders[probeRowIdx] == 0 { + continue + } + j.unFinishedProbeRowIdxQueue.Push(probeRowIdx) + } + + err := checkSQLKiller(sqlKiller, "killedDuringProbe") + if err != nil 
{ + return err + } + + j.finishCurrentLookupLoop(joinedChk) + return nil +} + +// IsCurrentChunkProbeDone returns false while the join has an other condition and rows are still queued for it. +// Otherwise it returns what the embedded baseJoinProbe reports. +func (j *leftOuterSemiJoinProbe) IsCurrentChunkProbeDone() bool { + if j.ctx.hasOtherCondition() && !j.unFinishedProbeRowIdxQueue.IsEmpty() { + return false + } + return j.baseJoinProbe.IsCurrentChunkProbeDone() +} diff --git a/pkg/executor/join/left_outer_semi_join_probe_test.go b/pkg/executor/join/left_outer_semi_join_probe_test.go new file mode 100644 index 0000000000000..f488628a6e356 --- /dev/null +++ b/pkg/executor/join/left_outer_semi_join_probe_test.go @@ -0,0 +1,487 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package join + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/stretchr/testify/require" +) + +// generate left outer semi join result using nested loop +func genLeftOuterSemiJoinResult(t *testing.T, sessCtx sessionctx.Context, leftFilter expression.CNFExprs, leftChunks []*chunk.Chunk, rightChunks []*chunk.Chunk, leftKeyIndex []int, rightKeyIndex []int, + leftTypes []*types.FieldType, rightTypes []*types.FieldType, leftKeyTypes []*types.FieldType, rightKeyTypes []*types.FieldType, leftUsedColumns []int, rightUsedColumns []int, + otherConditions expression.CNFExprs, resultTypes []*types.FieldType) []*chunk.Chunk { + filterVector := make([]bool, 0) + var err error + returnChks := make([]*chunk.Chunk, 0, 1) + resultChk := chunk.New(resultTypes, sessCtx.GetSessionVars().MaxChunkSize, sessCtx.GetSessionVars().MaxChunkSize) + shallowRowTypes := make([]*types.FieldType, 0, len(leftTypes)+len(rightTypes)) + shallowRowTypes = append(shallowRowTypes, leftTypes...) + shallowRowTypes = append(shallowRowTypes, rightTypes...) 
+ shallowRow := chunk.MutRowFromTypes(shallowRowTypes) + + // For each row in left chunks + for _, leftChunk := range leftChunks { + if leftFilter != nil { + filterVector, err = expression.VectorizedFilter(sessCtx.GetExprCtx().GetEvalCtx(), sessCtx.GetSessionVars().EnableVectorizedExpression, leftFilter, chunk.NewIterator4Chunk(leftChunk), filterVector) + require.NoError(t, err) + } + for leftIndex := 0; leftIndex < leftChunk.NumRows(); leftIndex++ { + filterIndex := leftIndex + if leftChunk.Sel() != nil { + filterIndex = leftChunk.Sel()[leftIndex] + } + if leftFilter != nil && !filterVector[filterIndex] { + // Filtered by left filter, append 0 for matched flag + appendToResultChk(leftChunk.GetRow(leftIndex), chunk.Row{}, leftUsedColumns, nil, resultChk) + resultChk.AppendInt64(len(leftUsedColumns), 0) + if resultChk.IsFull() { + returnChks = append(returnChks, resultChk) + resultChk = chunk.New(resultTypes, sessCtx.GetSessionVars().MaxChunkSize, sessCtx.GetSessionVars().MaxChunkSize) + } + continue + } + + leftRow := leftChunk.GetRow(leftIndex) + hasMatch := false + hasNull := false + + // For each row in right chunks + for _, rightChunk := range rightChunks { + for rightIndex := 0; rightIndex < rightChunk.NumRows(); rightIndex++ { + rightRow := rightChunk.GetRow(rightIndex) + valid := !containsNullKey(leftRow, leftKeyIndex) && !containsNullKey(rightRow, rightKeyIndex) + if valid { + ok, err := codec.EqualChunkRow(sessCtx.GetSessionVars().StmtCtx.TypeCtx(), leftRow, leftKeyTypes, leftKeyIndex, + rightRow, rightKeyTypes, rightKeyIndex) + require.NoError(t, err) + valid = ok + } + + if valid && otherConditions != nil { + shallowRow.ShallowCopyPartialRow(0, leftRow) + shallowRow.ShallowCopyPartialRow(len(leftTypes), rightRow) + matched, null, err := expression.EvalBool(sessCtx.GetExprCtx().GetEvalCtx(), otherConditions, shallowRow.ToRow()) + require.NoError(t, err) + valid = matched + hasNull = hasNull || null + } + + if valid { + hasMatch = true + break + } + } + if 
hasMatch { + break + } + } + + // Append result with matched flag + appendToResultChk(leftRow, chunk.Row{}, leftUsedColumns, nil, resultChk) + if hasMatch { + resultChk.AppendInt64(len(leftUsedColumns), 1) + } else { + if hasNull { + resultChk.AppendNull(len(leftUsedColumns)) + } else { + resultChk.AppendInt64(len(leftUsedColumns), 0) + } + } + + if resultChk.IsFull() { + returnChks = append(returnChks, resultChk) + resultChk = chunk.New(resultTypes, sessCtx.GetSessionVars().MaxChunkSize, sessCtx.GetSessionVars().MaxChunkSize) + } + } + } + if resultChk.NumRows() > 0 { + returnChks = append(returnChks, resultChk) + } + return returnChks +} + +func TestLeftOuterSemiJoinProbeBasic(t *testing.T) { + // todo test nullable type after builder support nullable type + tinyTp := types.NewFieldType(mysql.TypeTiny) + tinyTp.AddFlag(mysql.NotNullFlag) + intTp := types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + uintTp := types.NewFieldType(mysql.TypeLonglong) + uintTp.AddFlag(mysql.NotNullFlag) + uintTp.AddFlag(mysql.UnsignedFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + + lTypes := []*types.FieldType{intTp, stringTp, uintTp, stringTp, tinyTp} + rTypes := []*types.FieldType{intTp, stringTp, uintTp, stringTp, tinyTp} + rTypes = append(rTypes, retTypes...) + rTypes1 := []*types.FieldType{uintTp, stringTp, intTp, stringTp, tinyTp} + rTypes1 = append(rTypes1, rTypes1...) 
+ + rightAsBuildSide := []bool{true} + partitionNumber := 4 + simpleFilter := createSimpleFilter(t) + hasFilter := []bool{false, true} + + testCases := []testCase{ + // normal case + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, nil, []int{}, nil, nil, nil}, + // rightUsed is empty + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, []int{0, 1, 2, 3}, []int{}, nil, nil, nil}, + // leftUsed is empty + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, []int{}, []int{}, nil, nil, nil}, + // both left/right Used are empty + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, []int{}, []int{}, nil, nil, nil}, + // both left/right used is part of all columns + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, []int{0, 2}, []int{}, nil, nil, nil}, + // int join uint + {[]int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{uintTp}, lTypes, rTypes1, []int{0, 1, 2, 3}, []int{}, nil, nil, nil}, + // multiple join keys + {[]int{0, 1}, []int{0, 1}, []*types.FieldType{intTp, stringTp}, []*types.FieldType{intTp, stringTp}, lTypes, rTypes, []int{0, 1, 2, 3}, []int{}, nil, nil, nil}, + } + + for _, tc := range testCases { + for _, value := range rightAsBuildSide { + for _, testFilter := range hasFilter { + leftFilter := simpleFilter + if !testFilter { + leftFilter = nil + } + testJoinProbe(t, false, tc.leftKeyIndex, tc.rightKeyIndex, tc.leftKeyTypes, tc.rightKeyTypes, tc.leftTypes, tc.rightTypes, value, tc.leftUsed, + tc.rightUsed, tc.leftUsedByOtherCondition, tc.rightUsedByOtherCondition, leftFilter, nil, tc.otherCondition, partitionNumber, logicalop.LeftOuterSemiJoin, 200) + testJoinProbe(t, false, tc.leftKeyIndex, tc.rightKeyIndex, toNullableTypes(tc.leftKeyTypes), toNullableTypes(tc.rightKeyTypes), + toNullableTypes(tc.leftTypes), toNullableTypes(tc.rightTypes), 
value, tc.leftUsed, tc.rightUsed, tc.leftUsedByOtherCondition, tc.rightUsedByOtherCondition, + leftFilter, nil, tc.otherCondition, partitionNumber, logicalop.LeftOuterSemiJoin, 200) + } + } + } +} + +func TestLeftOuterSemiJoinProbeAllJoinKeys(t *testing.T) { + tinyTp := types.NewFieldType(mysql.TypeTiny) + tinyTp.AddFlag(mysql.NotNullFlag) + intTp := types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + uintTp := types.NewFieldType(mysql.TypeLonglong) + uintTp.AddFlag(mysql.UnsignedFlag) + uintTp.AddFlag(mysql.NotNullFlag) + yearTp := types.NewFieldType(mysql.TypeYear) + yearTp.AddFlag(mysql.NotNullFlag) + durationTp := types.NewFieldType(mysql.TypeDuration) + durationTp.AddFlag(mysql.NotNullFlag) + enumTp := types.NewFieldType(mysql.TypeEnum) + enumTp.AddFlag(mysql.NotNullFlag) + enumWithIntFlag := types.NewFieldType(mysql.TypeEnum) + enumWithIntFlag.AddFlag(mysql.EnumSetAsIntFlag) + enumWithIntFlag.AddFlag(mysql.NotNullFlag) + setTp := types.NewFieldType(mysql.TypeSet) + setTp.AddFlag(mysql.NotNullFlag) + bitTp := types.NewFieldType(mysql.TypeBit) + bitTp.AddFlag(mysql.NotNullFlag) + jsonTp := types.NewFieldType(mysql.TypeJSON) + jsonTp.AddFlag(mysql.NotNullFlag) + floatTp := types.NewFieldType(mysql.TypeFloat) + floatTp.AddFlag(mysql.NotNullFlag) + doubleTp := types.NewFieldType(mysql.TypeDouble) + doubleTp.AddFlag(mysql.NotNullFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + datetimeTp := types.NewFieldType(mysql.TypeDatetime) + datetimeTp.AddFlag(mysql.NotNullFlag) + decimalTp := types.NewFieldType(mysql.TypeNewDecimal) + decimalTp.AddFlag(mysql.NotNullFlag) + timestampTp := types.NewFieldType(mysql.TypeTimestamp) + timestampTp.AddFlag(mysql.NotNullFlag) + dateTp := types.NewFieldType(mysql.TypeDate) + dateTp.AddFlag(mysql.NotNullFlag) + binaryStringTp := types.NewFieldType(mysql.TypeBlob) + binaryStringTp.AddFlag(mysql.NotNullFlag) + + lTypes := []*types.FieldType{tinyTp, intTp, uintTp, 
yearTp, durationTp, enumTp, enumWithIntFlag, setTp, bitTp, jsonTp, floatTp, doubleTp, stringTp, datetimeTp, decimalTp, timestampTp, dateTp, binaryStringTp} + rTypes := lTypes + lUsed := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17} + rUsed := []int{} + joinType := logicalop.LeftOuterSemiJoin + partitionNumber := 4 + + rightAsBuildSide := []bool{true} + + // single key + for i := 0; i < len(lTypes); i++ { + lKeyTypes := []*types.FieldType{lTypes[i]} + rKeyTypes := []*types.FieldType{rTypes[i]} + for _, rightAsBuild := range rightAsBuildSide { + testJoinProbe(t, false, []int{i}, []int{i}, lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + testJoinProbe(t, false, []int{i}, []int{i}, toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + } + } + // composed key + // fixed size, inlined + for _, rightAsBuild := range rightAsBuildSide { + lKeyTypes := []*types.FieldType{intTp, uintTp} + rKeyTypes := []*types.FieldType{intTp, uintTp} + testJoinProbe(t, false, []int{1, 2}, []int{1, 2}, lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + testJoinProbe(t, false, []int{1, 2}, []int{1, 2}, toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + } + // variable size, inlined + for _, rightAsBuild := range rightAsBuildSide { + lKeyTypes := []*types.FieldType{intTp, binaryStringTp} + rKeyTypes := []*types.FieldType{intTp, binaryStringTp} + testJoinProbe(t, false, []int{1, 17}, []int{1, 17}, lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + testJoinProbe(t, false, 
[]int{1, 17}, []int{1, 17}, toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + } + // fixed size, not inlined + for _, rightAsBuild := range rightAsBuildSide { + lKeyTypes := []*types.FieldType{intTp, datetimeTp} + rKeyTypes := []*types.FieldType{intTp, datetimeTp} + testJoinProbe(t, false, []int{1, 13}, []int{1, 13}, lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + testJoinProbe(t, false, []int{1, 13}, []int{1, 13}, toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + } + // variable size, not inlined + for _, rightAsBuild := range rightAsBuildSide { + lKeyTypes := []*types.FieldType{intTp, decimalTp} + rKeyTypes := []*types.FieldType{intTp, decimalTp} + testJoinProbe(t, false, []int{1, 14}, []int{1, 14}, lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + testJoinProbe(t, false, []int{1, 14}, []int{1, 14}, toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, joinType, 100) + } +} + +func TestLeftOuterSemiJoinProbeOtherCondition(t *testing.T) { + intTp := types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + nullableIntTp := types.NewFieldType(mysql.TypeLonglong) + uintTp := types.NewFieldType(mysql.TypeLonglong) + uintTp.AddFlag(mysql.NotNullFlag) + uintTp.AddFlag(mysql.UnsignedFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + + lTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp} + rTypes := []*types.FieldType{intTp, intTp, 
stringTp, uintTp, stringTp} + rTypes = append(rTypes, rTypes...) + + tinyTp := types.NewFieldType(mysql.TypeTiny) + a := &expression.Column{Index: 1, RetType: nullableIntTp} + b := &expression.Column{Index: 8, RetType: nullableIntTp} + sf, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b) + require.NoError(t, err, "error when create other condition") + // test condition `a = b` from `a in (select b from t2)` + a2 := &expression.Column{Index: 1, RetType: nullableIntTp, InOperand: true} + b2 := &expression.Column{Index: 8, RetType: nullableIntTp, InOperand: true} + sf2, err := expression.NewFunction(mock.NewContext(), ast.EQ, tinyTp, a2, b2) + require.NoError(t, err, "error when create other condition") + otherCondition := make(expression.CNFExprs, 0) + otherCondition = append(otherCondition, sf) + otherCondition2 := make(expression.CNFExprs, 0) + otherCondition2 = append(otherCondition2, sf2) + joinType := logicalop.LeftOuterSemiJoin + simpleFilter := createSimpleFilter(t) + hasFilter := []bool{false, true} + rightAsBuildSide := []bool{true} + partitionNumber := 4 + rightUsed := []int{} + + for _, rightBuild := range rightAsBuildSide { + for _, testFilter := range hasFilter { + leftFilter := simpleFilter + if !testFilter { + leftFilter = nil + } + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, 
nil, otherCondition, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 200) + + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 200) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 200) + } + } +} + +func TestLeftOuterSemiJoinProbeWithSel(t *testing.T) { + intTp := types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + nullableIntTp := types.NewFieldType(mysql.TypeLonglong) + uintTp := types.NewFieldType(mysql.TypeLonglong) + uintTp.AddFlag(mysql.NotNullFlag) + uintTp.AddFlag(mysql.UnsignedFlag) + nullableUIntTp := types.NewFieldType(mysql.TypeLonglong) + nullableUIntTp.AddFlag(mysql.UnsignedFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + + lTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp} + rTypes := 
[]*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp} + rTypes = append(rTypes, rTypes...) + + tinyTp := types.NewFieldType(mysql.TypeTiny) + a := &expression.Column{Index: 1, RetType: nullableIntTp} + b := &expression.Column{Index: 8, RetType: nullableUIntTp} + sf, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b) + require.NoError(t, err, "error when create other condition") + otherCondition := make(expression.CNFExprs, 0) + otherCondition = append(otherCondition, sf) + joinType := logicalop.LeftOuterSemiJoin + rightAsBuildSide := []bool{true} + simpleFilter := createSimpleFilter(t) + hasFilter := []bool{false, true} + partitionNumber := 4 + rightUsed := []int{} + + for _, rightBuild := range rightAsBuildSide { + for _, useFilter := range hasFilter { + leftFilter := simpleFilter + if !useFilter { + leftFilter = nil + } + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 500) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 500) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 500) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 500) + } + } +} + +func TestLeftOuterSemiJoinBuildResultFastPath(t *testing.T) { + intTp := 
types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + nullableIntTp := types.NewFieldType(mysql.TypeLonglong) + uintTp := types.NewFieldType(mysql.TypeLonglong) + uintTp.AddFlag(mysql.NotNullFlag) + uintTp.AddFlag(mysql.UnsignedFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + + lTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp} + rTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp} + rTypes = append(rTypes, rTypes...) + + tinyTp := types.NewFieldType(mysql.TypeTiny) + a := &expression.Column{Index: 1, RetType: nullableIntTp} + b := &expression.Column{Index: 8, RetType: nullableIntTp} + sf, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b) + require.NoError(t, err, "error when create other condition") + // test condition `a = b` from `a in (select b from t2)` + a2 := &expression.Column{Index: 1, RetType: nullableIntTp, InOperand: true} + b2 := &expression.Column{Index: 8, RetType: nullableIntTp, InOperand: true} + sf2, err := expression.NewFunction(mock.NewContext(), ast.EQ, tinyTp, a2, b2) + require.NoError(t, err, "error when create other condition") + otherCondition := make(expression.CNFExprs, 0) + otherCondition = append(otherCondition, sf) + otherCondition2 := make(expression.CNFExprs, 0) + otherCondition2 = append(otherCondition2, sf2) + joinType := logicalop.LeftOuterSemiJoin + simpleFilter := createSimpleFilter(t) + hasFilter := []bool{false, true} + rightAsBuildSide := []bool{true} + partitionNumber := 4 + rightUsed := []int{} + + for _, rightBuild := range rightAsBuildSide { + for _, testFilter := range hasFilter { + leftFilter := simpleFilter + if !testFilter { + leftFilter = nil + } + // MockContext set MaxChunkSize to 32, input chunk size should be less than 32 to test fast path + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, 
rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 30) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition, partitionNumber, joinType, 30) + + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 30) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, otherCondition2, partitionNumber, joinType, 30) + + testJoinProbe(t, false, []int{0}, []int{0}, 
[]*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, nil, partitionNumber, joinType, 30) + testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightBuild, []int{}, rightUsed, []int{1}, []int{3}, leftFilter, nil, nil, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, []int{1, 2, 4}, rightUsed, []int{1}, []int{3}, leftFilter, nil, nil, partitionNumber, joinType, 30) + testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightBuild, nil, rightUsed, []int{1}, []int{3}, leftFilter, nil, nil, partitionNumber, joinType, 30) + } + } +} + +func TestLeftOuterSemiJoinSpill(t *testing.T) { + ctx := mock.NewContext() + ctx.GetSessionVars().InitChunkSize = 32 + ctx.GetSessionVars().MaxChunkSize = 32 + leftDataSource, rightDataSource := buildLeftAndRightDataSource(ctx, leftCols, rightCols, false) + leftDataSourceWithSel, rightDataSourceWithSel := buildLeftAndRightDataSource(ctx, leftCols, rightCols, true) + + intTp := types.NewFieldType(mysql.TypeLonglong) + intTp.AddFlag(mysql.NotNullFlag) + stringTp := types.NewFieldType(mysql.TypeVarString) + stringTp.AddFlag(mysql.NotNullFlag) + + leftTypes := []*types.FieldType{intTp, intTp, intTp, stringTp, intTp} + rightTypes := []*types.FieldType{intTp, intTp, stringTp, intTp, intTp} + + leftKeys := []*expression.Column{ + {Index: 1, RetType: intTp}, + {Index: 3, RetType: stringTp}, + } + rightKeys := []*expression.Column{ + {Index: 0, RetType: intTp}, + {Index: 2, RetType: stringTp}, + } + + tinyTp := types.NewFieldType(mysql.TypeTiny) + a := &expression.Column{Index: 1, RetType: intTp} + b := &expression.Column{Index: 8, 
RetType: intTp} + sf, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b) + require.NoError(t, err, "error when create other condition") + otherCondition := make(expression.CNFExprs, 0) + otherCondition = append(otherCondition, sf) + + maxRowTableSegmentSize = 100 + spillChunkSize = 100 + + joinType := logicalop.LeftOuterSemiJoin + params := []spillTestParam{ + // basic case + {true, leftKeys, rightKeys, leftTypes, rightTypes, []int{0, 1, 3, 4}, []int{}, nil, nil, nil, []int64{3000000, 1700000, 3500000, 750000, 10000}}, + // with other condition + {true, leftKeys, rightKeys, leftTypes, rightTypes, []int{0, 1, 3, 4}, []int{}, otherCondition, []int{1}, []int{3}, []int64{3000000, 1700000, 4000000, 750000, 10000}}, + } + + for _, param := range params { + testSpill(t, ctx, joinType, leftDataSource, rightDataSource, param) + } + + params2 := []spillTestParam{ + // basic case with sel + {true, leftKeys, rightKeys, leftTypes, rightTypes, []int{0, 1, 3, 4}, []int{}, nil, nil, nil, []int64{1000000, 900000, 1700000, 100000, 10000}}, + // with other condition with sel + {true, leftKeys, rightKeys, leftTypes, rightTypes, []int{0, 1, 3, 4}, []int{}, otherCondition, []int{1}, []int{3}, []int64{1000000, 900000, 2000000, 100000, 10000}}, + } + + for _, param := range params2 { + testSpill(t, ctx, joinType, leftDataSourceWithSel, rightDataSourceWithSel, param) + } +} diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 4d745c1f0e5e1..b55567c12bb32 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -445,7 +445,7 @@ func VecEvalBool(ctx EvalContext, vecEnabled bool, exprList CNFExprs, input *chu isEQCondFromIn := IsEQCondFromIn(expr) for i := range sel { if isZero[i] == -1 { - if eType != types.ETInt && !isEQCondFromIn { + if eType != types.ETInt || !isEQCondFromIn { continue } // In this case, we set this row to null and let it pass this filter. 
@@ -457,6 +457,7 @@ func VecEvalBool(ctx EvalContext, vecEnabled bool, exprList CNFExprs, input *chu } if isZero[i] == 0 { + nulls[sel[i]] = false continue } sel[j] = sel[i] // this row passes this filter diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go index 7d4569f47d7f7..5a680bf2db449 100644 --- a/pkg/planner/core/physical_plans.go +++ b/pkg/planner/core/physical_plans.go @@ -1506,7 +1506,7 @@ func (p *PhysicalHashJoin) CanUseHashJoinV2() bool { return false } switch p.JoinType { - case logicalop.LeftOuterJoin, logicalop.RightOuterJoin, logicalop.InnerJoin: + case logicalop.LeftOuterJoin, logicalop.RightOuterJoin, logicalop.InnerJoin, logicalop.LeftOuterSemiJoin: // null aware join is not supported yet if len(p.LeftNAJoinKeys) > 0 { return false diff --git a/pkg/util/queue/BUILD.bazel b/pkg/util/queue/BUILD.bazel new file mode 100644 index 0000000000000..db8f01ebf757a --- /dev/null +++ b/pkg/util/queue/BUILD.bazel @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "queue", + srcs = ["queue.go"], + importpath = "github.com/pingcap/tidb/pkg/util/queue", + visibility = ["//visibility:public"], +) + +go_test( + name = "queue_test", + timeout = "short", + srcs = ["queue_test.go"], + embed = [":queue"], + flaky = True, + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/pkg/util/queue/queue.go b/pkg/util/queue/queue.go new file mode 100644 index 0000000000000..e4ee939ca01f3 --- /dev/null +++ b/pkg/util/queue/queue.go @@ -0,0 +1,85 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +// Queue is a circular buffer implementation of queue. +type Queue[T any] struct { + elements []T + head int + tail int + size int +} + +// NewQueue creates a new queue with the given capacity. +func NewQueue[T any](capacity int) *Queue[T] { + return &Queue[T]{ + elements: make([]T, capacity), + } +} + +// Push pushes an element to the queue. +func (r *Queue[T]) Push(element T) { + if r.elements == nil { + r.elements = make([]T, 1) + } + + if r.size == len(r.elements) { + // Double capacity when full + newElements := make([]T, len(r.elements)*2) + for i := range r.size { + newElements[i] = r.elements[(r.head+i)%len(r.elements)] + } + r.elements = newElements + r.head = 0 + r.tail = r.size + } + + r.elements[r.tail] = element + r.tail = (r.tail + 1) % len(r.elements) + r.size++ +} + +// Pop pops an element from the queue. +func (r *Queue[T]) Pop() T { + if r.size == 0 { + panic("Queue is empty") + } + element := r.elements[r.head] + r.head = (r.head + 1) % len(r.elements) + r.size-- + return element +} + +// Len returns the number of elements in the queue. +func (r *Queue[T]) Len() int { + return r.size +} + +// IsEmpty returns true if the queue is empty. +func (r *Queue[T]) IsEmpty() bool { + return r.size == 0 +} + +// Clear clears the queue. +func (r *Queue[T]) Clear() { + r.head = 0 + r.tail = 0 + r.size = 0 +} + +// Cap returns the capacity of the queue. 
+func (r *Queue[T]) Cap() int { + return len(r.elements) +} diff --git a/pkg/util/queue/queue_test.go b/pkg/util/queue/queue_test.go new file mode 100644 index 0000000000000..678247705ff64 --- /dev/null +++ b/pkg/util/queue/queue_test.go @@ -0,0 +1,87 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package queue + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueue(t *testing.T) { + t.Run("basic operations", func(t *testing.T) { + q := NewQueue[int](2) + + // Test initial state + require.True(t, q.IsEmpty(), "new queue should be empty") + require.Equal(t, 0, q.Len(), "new queue should have length 0") + require.Equal(t, 2, q.Cap(), "new queue should have capacity 2") + + // Test Push + q.Push(1) + q.Push(2) + require.Equal(t, 2, q.Len(), "queue length should be 2 after pushing 2 elements") + require.False(t, q.IsEmpty(), "queue should not be empty after pushing elements") + + // Test automatic capacity increase + q.Push(3) + require.Equal(t, 4, q.Cap(), "queue capacity should double when full") + + // Test Pop + require.Equal(t, 1, q.Pop(), "first pop should return 1") + require.Equal(t, 2, q.Pop(), "second pop should return 2") + require.Equal(t, 3, q.Pop(), "third pop should return 3") + + // Test empty queue + require.True(t, q.IsEmpty(), "queue should be empty after popping all elements") + }) + + t.Run("clear operation", func(t *testing.T) { + q := NewQueue[string](4) + q.Push("a") + 
q.Push("b") + q.Push("c") + + q.Clear() + require.True(t, q.IsEmpty(), "queue should be empty after clear") + require.Equal(t, 0, q.Len(), "queue length should be 0 after clear") + }) + + t.Run("panic on empty pop", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r, "pop on empty queue should panic") + }() + + q := NewQueue[int](1) + q.Pop() // Should panic + }) + + t.Run("circular buffer behavior", func(t *testing.T) { + q := NewQueue[int](3) + q.Push(1) + q.Push(2) + q.Pop() // Remove 1 + q.Push(3) + q.Push(4) // This should wrap around + + // check queue.head and queue.tail + require.Equal(t, 1, q.head, "queue.head should be 1") + require.Equal(t, 1, q.tail, "queue.tail should be 1") + + require.Equal(t, 2, q.Pop(), "expected 2") + require.Equal(t, 3, q.Pop(), "expected 3") + require.Equal(t, 4, q.Pop(), "expected 4") + }) +}