diff --git a/executor/hash_table.go b/executor/hash_table.go index a70b80954c30d..d6d7d12112352 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -153,7 +153,7 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h } func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row, - probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) { + probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) { // for NAAJ probe row with null, we should match them with all build rows. var ( ok bool @@ -180,16 +180,20 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo // else like // (null, 1, 2), we should use the not-null probe bit to filter rows. Only fetch rows like // ( ? , 1, 2), that exactly with value as 1 and 2 in the second and third join key column. - needCheckProbeRowPos = needCheckProbeRowPos[:0] - needCheckBuildRowPos = needCheckBuildRowPos[:0] + needCheckProbeColPos = needCheckProbeColPos[:0] + needCheckBuildColPos = needCheckBuildColPos[:0] + needCheckBuildTypes = needCheckBuildTypes[:0] + needCheckProbeTypes = needCheckProbeTypes[:0] keyColLen := len(c.hCtx.naKeyColIdx) for i := 0; i < keyColLen; i++ { // since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated. if probeKeyNullBits.UnsafeIsSet(i) { continue } - needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) - needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i]) + needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i]) + needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i]) + needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i]) } } var mayMatchedRow chunk.Row @@ -200,7 +204,7 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo } if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 { // check the idxs-th value of the join columns. - ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos) if err != nil { return nil, err } @@ -287,7 +291,7 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk } func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row, - probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) { + probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) { var ( ok bool err error @@ -306,8 +310,10 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo // case2: left side (probe side) don't have null // left side key <1, 2>, actually we should fetch <1,null>, , from the null bucket because // case like <3,null> is obviously not matched with the probe key. - needCheckProbeRowPos = needCheckProbeRowPos[:0] - needCheckBuildRowPos = needCheckBuildRowPos[:0] + needCheckProbeColPos = needCheckProbeColPos[:0] + needCheckBuildColPos = needCheckBuildColPos[:0] + needCheckBuildTypes = needCheckBuildTypes[:0] + needCheckProbeTypes = needCheckProbeTypes[:0] keyColLen := len(c.hCtx.naKeyColIdx) if probeKeyNullBits != nil { // when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish @@ -325,11 +331,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo if probeKeyNullBits.UnsafeIsSet(i) || nullEntry.nullBitMap.UnsafeIsSet(i) { continue } - needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) - needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i]) + needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i]) + needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i]) + needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i]) } // check the idxs-th value of the join columns. - ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos) if err != nil { return nil, err } @@ -346,11 +354,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo if nullEntry.nullBitMap.UnsafeIsSet(i) { continue } - needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) - needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i]) + needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i]) + needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i]) + needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i]) } // check the idxs-th value of the join columns. - ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos) if err != nil { return nil, err } @@ -366,6 +376,11 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo // matchJoinKey checks if join keys of buildRow and probeRow are logically equal. func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) { + if len(c.hCtx.naKeyColIdx) > 0 { + return codec.EqualChunkRow(c.sc, + buildRow, c.hCtx.allTypes, c.hCtx.naKeyColIdx, + probeRow, probeHCtx.allTypes, probeHCtx.naKeyColIdx) + } return codec.EqualChunkRow(c.sc, buildRow, c.hCtx.allTypes, c.hCtx.keyColIdx, probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx) diff --git a/executor/join.go b/executor/join.go index 0f18f523b11d5..6ac59cc15d3f3 100644 --- a/executor/join.go +++ b/executor/join.go @@ -98,8 +98,10 @@ type probeWorker struct { rowIters *chunk.Iterator4Slice rowContainerForProbe *hashRowContainer // for every naaj probe worker, pre-allocate the int slice for store the join column index to check. - needCheckBuildRowPos []int - needCheckProbeRowPos []int + needCheckBuildColPos []int + needCheckProbeColPos []int + needCheckBuildTypes []*types.FieldType + needCheckProbeTypes []*types.FieldType probeChkResourceCh chan *probeChkResource joinChkResourceCh chan *chunk.Chunk probeResultCh chan *chunk.Chunk @@ -177,8 +179,10 @@ func (e *HashJoinExec) Close() error { for _, w := range e.probeWorkers { w.buildSideRows = nil w.buildSideRowPtrs = nil - w.needCheckBuildRowPos = nil - w.needCheckProbeRowPos = nil + w.needCheckBuildColPos = nil + w.needCheckProbeColPos = nil + w.needCheckBuildTypes = nil + w.needCheckProbeTypes = nil w.joinChkResourceCh = nil } @@ -605,7 +609,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK } } // step2: match the null bucket secondly. - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows = w.buildSideRows if err != nil { joinResult.err = err @@ -650,7 +654,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK // case1: NOT IN (empty set): ----------------------> result is . // case2: NOT IN (at least a valid inner row) ------------------> result is . // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -680,7 +684,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK } } // Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any). - w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows = w.buildSideRows if err != nil { joinResult.err = err @@ -729,7 +733,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey if probeKeyNullBits == nil { // step1: match null bucket first. // need fetch the "valid" rows every time. (nullBits map check is necessary) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -804,7 +808,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey // case1: NOT IN (empty set): ----------------------> accept rhs row. // case2: NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row. // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -834,7 +838,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey } } // Step2: match all hash table bucket build rows. - w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows = w.buildSideRows if err != nil { joinResult.err = err diff --git a/executor/join_test.go b/executor/join_test.go index 976897b086efe..35f4315439a75 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -681,6 +681,7 @@ func TestUsingAndNaturalJoinSchema(t *testing.T) { } } +<<<<<<< HEAD func TestNaturalJoin(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) @@ -804,10 +805,14 @@ AND b44=a42`) } func TestSubquerySameTable(t *testing.T) { +======= +func TestTiDBNAAJ(t *testing.T) { +>>>>>>> 7c05f82212a (executor: fix naaj panic caused by wrong field types check (#42482)) store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") tk.MustExec("drop table if exists t") +<<<<<<< HEAD tk.MustExec("create table t (a int)") tk.MustExec("insert t values (1), (2)") result := tk.MustQuery("select a from t where exists(select 1 from t as x where x.a < t.a)") @@ -2891,4 +2896,14 @@ func TestOuterJoin(t *testing.T) { "3 2 3 4", ), ) +======= + tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=0;") + tk.MustExec("create table t(a decimal(40,0), b bigint(20) not null);") + tk.MustExec("insert into t values(7,8),(7,8),(3,4),(3,4),(9,2),(9,2),(2,0),(2,0),(0,4),(0,4),(8,8),(8,8),(6,1),(6,1),(NULL, 0),(NULL,0);") + tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows( + "0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1")) + tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=1;") + tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows( + "0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1")) +>>>>>>> 7c05f82212a (executor: fix naaj panic caused by wrong field types check (#42482)) }