Skip to content

Commit

Permalink
executor: fix naaj panic caused by wrong field types check (#42482) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot committed Apr 10, 2023
1 parent c08796f commit 3b3cc98
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
45 changes: 30 additions & 15 deletions executor/hash_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h
}

func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
// for NAAJ probe row with null, we should match them with all build rows.
var (
ok bool
Expand All @@ -180,16 +180,20 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
// else like
// (null, 1, 2), we should use the not-null probe bit to filter rows. Only fetch rows like
// ( ? , 1, 2), that exactly with value as 1 and 2 in the second and third join key column.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckProbeColPos = needCheckProbeColPos[:0]
needCheckBuildColPos = needCheckBuildColPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
for i := 0; i < keyColLen; i++ {
// since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated.
if probeKeyNullBits.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
}
var mayMatchedRow chunk.Row
Expand All @@ -200,7 +204,7 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
}
if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -287,7 +291,7 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
}

func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
var (
ok bool
err error
Expand All @@ -306,8 +310,10 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
// case2: left side (probe side) don't have null
// left side key <1, 2>, actually we should fetch <1,null>, <null, 2>, <null, null> from the null bucket because
// case like <3,null> is obviously not matched with the probe key.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckProbeColPos = needCheckProbeColPos[:0]
needCheckBuildColPos = needCheckBuildColPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
if probeKeyNullBits != nil {
// when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish
Expand All @@ -325,11 +331,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
if probeKeyNullBits.UnsafeIsSet(i) || nullEntry.nullBitMap.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand All @@ -346,11 +354,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
if nullEntry.nullBitMap.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand All @@ -366,6 +376,11 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo

// matchJoinKey checks if join keys of buildRow and probeRow are logically equal.
func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) {
if len(c.hCtx.naKeyColIdx) > 0 {
return codec.EqualChunkRow(c.sc,
buildRow, c.hCtx.allTypes, c.hCtx.naKeyColIdx,
probeRow, probeHCtx.allTypes, probeHCtx.naKeyColIdx)
}
return codec.EqualChunkRow(c.sc,
buildRow, c.hCtx.allTypes, c.hCtx.keyColIdx,
probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx)
Expand Down
24 changes: 14 additions & 10 deletions executor/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ type probeWorker struct {
rowIters *chunk.Iterator4Slice
rowContainerForProbe *hashRowContainer
// for every naaj probe worker, pre-allocate the int slice for store the join column index to check.
needCheckBuildRowPos []int
needCheckProbeRowPos []int
needCheckBuildColPos []int
needCheckProbeColPos []int
needCheckBuildTypes []*types.FieldType
needCheckProbeTypes []*types.FieldType
probeChkResourceCh chan *probeChkResource
joinChkResourceCh chan *chunk.Chunk
probeResultCh chan *chunk.Chunk
Expand Down Expand Up @@ -176,8 +178,10 @@ func (e *HashJoinExec) Close() error {
for _, w := range e.probeWorkers {
w.buildSideRows = nil
w.buildSideRowPtrs = nil
w.needCheckBuildRowPos = nil
w.needCheckProbeRowPos = nil
w.needCheckBuildColPos = nil
w.needCheckProbeColPos = nil
w.needCheckBuildTypes = nil
w.needCheckProbeTypes = nil
w.joinChkResourceCh = nil
}

Expand Down Expand Up @@ -604,7 +608,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// step2: match the null bucket secondly.
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -649,7 +653,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
// case1: <?, null> NOT IN (empty set): ----------------------> result is <rhs, 1>.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> result is <rhs, null>.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -679,7 +683,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any).
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -728,7 +732,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
if probeKeyNullBits == nil {
// step1: match null bucket first.
// need fetch the "valid" rows every time. (nullBits map check is necessary)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -803,7 +807,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
// case1: <?, null> NOT IN (empty set): ----------------------> accept rhs row.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -833,7 +837,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
}
}
// Step2: match all hash table bucket build rows.
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down
15 changes: 15 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2909,3 +2909,18 @@ func TestCartesianJoinPanic(t *testing.T) {
require.NotNil(t, err)
require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!"))
}

func TestTiDBNAAJ(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=0;")
tk.MustExec("create table t(a decimal(40,0), b bigint(20) not null);")
tk.MustExec("insert into t values(7,8),(7,8),(3,4),(3,4),(9,2),(9,2),(2,0),(2,0),(0,4),(0,4),(8,8),(8,8),(6,1),(6,1),(NULL, 0),(NULL,0);")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=1;")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
}

0 comments on commit 3b3cc98

Please sign in to comment.