Skip to content

Commit

Permalink
fix naaj panic caused by field types miss-use
Browse files Browse the repository at this point in the history
Signed-off-by: AilinKid <314806019@qq.com>
  • Loading branch information
AilinKid committed Mar 23, 2023
1 parent 273763b commit d1f37cc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
20 changes: 15 additions & 5 deletions executor/hash_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h
}

func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
// for NAAJ probe row with null, we should match them with all build rows.
var (
ok bool
Expand Down Expand Up @@ -182,14 +182,18 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
// ( ? , 1, 2), that exactly with value as 1 and 2 in the second and third join key column.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
for i := 0; i < keyColLen; i++ {
// since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated.
if probeKeyNullBits.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
}
var mayMatchedRow chunk.Row
Expand All @@ -200,7 +204,7 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
}
if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildRowPos, probeSideRow, needCheckProbeTypes, needCheckProbeRowPos)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -287,7 +291,7 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
}

func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
var (
ok bool
err error
Expand All @@ -308,6 +312,8 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
// case like <3,null> is obviously not matched with the probe key.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
if probeKeyNullBits != nil {
// when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish
Expand All @@ -326,10 +332,12 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildRowPos, probeSideRow, needCheckProbeTypes, needCheckProbeRowPos)
if err != nil {
return nil, err
}
Expand All @@ -347,10 +355,12 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildRowPos, probeSideRow, needCheckProbeTypes, needCheckProbeRowPos)
if err != nil {
return nil, err
}
Expand Down
16 changes: 10 additions & 6 deletions executor/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type probeWorker struct {
// for every naaj probe worker, pre-allocate the int slice for store the join column index to check.
needCheckBuildRowPos []int
needCheckProbeRowPos []int
needCheckBuildTypes []*types.FieldType
needCheckProbeTypes []*types.FieldType
probeChkResourceCh chan *probeChkResource
joinChkResourceCh chan *chunk.Chunk
probeResultCh chan *chunk.Chunk
Expand Down Expand Up @@ -179,6 +181,8 @@ func (e *HashJoinExec) Close() error {
w.buildSideRowPtrs = nil
w.needCheckBuildRowPos = nil
w.needCheckProbeRowPos = nil
w.needCheckBuildTypes = nil
w.needCheckProbeTypes = nil
w.joinChkResourceCh = nil
}

Expand Down Expand Up @@ -605,7 +609,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// step2: match the null bucket secondly.
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -650,7 +654,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
// case1: <?, null> NOT IN (empty set): ----------------------> result is <rhs, 1>.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> result is <rhs, null>.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -680,7 +684,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any).
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -729,7 +733,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
if probeKeyNullBits == nil {
// step1: match null bucket first.
// need fetch the "valid" rows every time. (nullBits map check is necessary)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -804,7 +808,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
// case1: <?, null> NOT IN (empty set): ----------------------> accept rhs row.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -834,7 +838,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
}
}
// Step2: match all hash table bucket build rows.
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down
15 changes: 15 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,18 @@ func TestUsingAndNaturalJoinSchema(t *testing.T) {
tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...))
}
}

func TestTiDBNAAJ(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=0;")
tk.MustExec("create table t(a decimal(40,0), b bigint(20) not null);")
tk.MustExec("insert into t values(7,8),(7,8),(3,4),(3,4),(9,2),(9,2),(2,0),(2,0),(0,4),(0,4),(8,8),(8,8),(6,1),(6,1),(NULL, 0),(NULL,0);")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=1;")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
}

0 comments on commit d1f37cc

Please sign in to comment.