Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exec: handle NULL group by keys in hash aggregator #38900

Merged
merged 1 commit into from
Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pkg/sql/exec/aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,14 @@ func TestAggregatorAllFunctions(t *testing.T) {
aggCols: [][]uint32{{0}, {1}, {}, {1}, {1}, {2}, {2}, {2}, {1}},
colTypes: []types.T{types.Int64, types.Decimal, types.Int64},
input: tuples{
{nil, 1.1, 4},
{0, nil, nil},
{0, 3.1, 5},
{1, nil, nil},
{1, nil, nil},
},
expected: tuples{
{nil, 1.1, 1, 1, 1.1, 4, 4, 4, 1.1},
{0, 3.1, 2, 1, 3.1, 5, 5, 5, 3.1},
{1, nil, 2, 0, nil, nil, nil, nil, nil},
},
Expand Down Expand Up @@ -489,13 +491,14 @@ func TestAggregatorRandom(t *testing.T) {
expNulls = append(expNulls, true)
curGroup++
}
// Keep the inputs small so they are a realistic size. Using a
// large range is not realistic and makes decimal operations
// slower.
aggCol[i] = 2048 * (rng.Float64() - 0.5)

if hasNulls && rng.Float64() < nullProbability {
aggColNulls.SetNull(uint16(i))
} else {
// Keep the inputs small so they are a realistic size. Using a
// large range is not realistic and makes decimal operations
// slower.
aggCol[i] = 2048 * (rng.Float64() - 0.5)
expNulls[curGroup] = false
expCounts[curGroup]++
expSums[curGroup] += aggCol[i]
Expand Down
12 changes: 7 additions & 5 deletions pkg/sql/exec/execgen/cmd/execgen/hashjoiner_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func genHashJoiner(wr io.Writer) error {
assignHash := makeFunctionRegex("_ASSIGN_HASH", 2)
s = assignHash.ReplaceAllString(s, `{{.Global.UnaryAssign "$1" "$2"}}`)

rehash := makeFunctionRegex("_REHASH_BODY", 6)
s = rehash.ReplaceAllString(s, `{{template "rehashBody" buildDict "Global" . "SelInd" $6}}`)
rehash := makeFunctionRegex("_REHASH_BODY", 8)
s = rehash.ReplaceAllString(s, `{{template "rehashBody" buildDict "Global" . "SelInd" $7 "HasNulls" $8}}`)

checkCol := makeFunctionRegex("_CHECK_COL_WITH_NULLS", 7)
s = checkCol.ReplaceAllString(s, `{{template "checkColWithNulls" buildDict "Global" . "SelInd" $7}}`)
Expand All @@ -56,11 +56,13 @@ func genHashJoiner(wr io.Writer) error {
collectNoOuter := makeFunctionRegex("_COLLECT_NO_OUTER", 5)
s = collectNoOuter.ReplaceAllString(s, `{{template "collectNoOuter" buildDict "Global" . "SelInd" $5}}`)

checkColMain := makeFunctionRegex("_CHECK_COL_MAIN", 1)
checkColMain := makeFunctionRegex("_CHECK_COL_MAIN", 5)
s = checkColMain.ReplaceAllString(s, `{{template "checkColMain" .}}`)

checkColBody := makeFunctionRegex("_CHECK_COL_BODY", 8)
s = checkColBody.ReplaceAllString(s, `{{template "checkColBody" buildDict "Global" .Global "SelInd" .SelInd "ProbeHasNulls" $7 "BuildHasNulls" $8}}`)
checkColBody := makeFunctionRegex("_CHECK_COL_BODY", 9)
s = checkColBody.ReplaceAllString(
s,
`{{template "checkColBody" buildDict "Global" .Global "SelInd" .SelInd "ProbeHasNulls" $7 "BuildHasNulls" $8 "AllowNullEquality" $9}}`)

tmpl, err := template.New("hashjoiner_op").Funcs(template.FuncMap{"buildDict": buildDict}).Parse(s)

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/exec/hash_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func NewHashAggregator(
colTypes,
groupCols,
outCols,
true, /* allowNullEquality */
)

builder := makeHashJoinBuilder(
Expand Down
16 changes: 14 additions & 2 deletions pkg/sql/exec/hashjoiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ func (hj *hashJoinEqOp) Init() {
build.sourceTypes,
build.eqCols,
build.outCols,
false, /* allowNullEquality */
)

hj.builder = makeHashJoinBuilder(
Expand Down Expand Up @@ -392,11 +393,19 @@ type hashTable struct {
// key.
differs []bool

// allowNullEquality determines if NULL keys should be treated as equal to
// each other.
allowNullEquality bool

cancelChecker CancelChecker
}

func makeHashTable(
bucketSize uint64, sourceTypes []types.T, eqCols []uint32, outCols []uint32,
bucketSize uint64,
sourceTypes []types.T,
eqCols []uint32,
outCols []uint32,
allowNullEquality bool,
) *hashTable {
// Compute the union of eqCols and outCols and compress vals to only keep the
// important columns.
Expand Down Expand Up @@ -467,6 +476,8 @@ func makeHashTable(

keys: make([]coldata.Vec, len(eqCols)),
buckets: make([]uint64, coldata.BatchSize),

allowNullEquality: allowNullEquality,
}
}

Expand Down Expand Up @@ -577,7 +588,8 @@ func makeHashJoinBuilder(ht *hashTable, spec hashJoinerSourceSpec) *hashJoinBuil
}

// exec executes distinctExec, and then eagerly populates the hashTable's same
// array by probing the hashTable with every single input key.
// array by probing the hashTable with every single input key. This is intended
// for use by the hash aggregator.
func (builder *hashJoinBuilder) exec(ctx context.Context) {
builder.distinctExec(ctx)

Expand Down
63 changes: 50 additions & 13 deletions pkg/sql/exec/hashjoiner_tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ func _CHECK_COL_BODY(
nToCheck uint16,
_PROBE_HAS_NULLS bool,
_BUILD_HAS_NULLS bool,
_ALLOW_NULL_EQUALITY bool,
) { // */}}
// {{define "checkColBody"}}
probeIsNull := false
buildIsNull := false
for i := uint16(0); i < nToCheck; i++ {
// keyID of 0 is reserved to represent the end of the next chain.

Expand All @@ -98,11 +101,23 @@ func _CHECK_COL_BODY(
// found.

/* {{if .ProbeHasNulls }} */
if probeVec.Nulls().NullAt(_SEL_IND) {
probeIsNull = probeVec.Nulls().NullAt(_SEL_IND)
/* {{end}} */

/* {{if .BuildHasNulls }} */
buildIsNull = buildVec.Nulls().NullAt64(keyID - 1)
/* {{end}} */

/* {{if .AllowNullEquality}} */
if probeIsNull && buildIsNull {
continue
}
/* {{end}} */
if probeIsNull {
ht.groupID[ht.toCheck[i]] = 0
} else /*{{end}} {{if .BuildHasNulls}} */ if buildVec.Nulls().NullAt64(keyID - 1) {
} else if buildIsNull {
ht.differs[ht.toCheck[i]] = true
} else /*{{end}} */ {
} else {
_CHECK_COL_MAIN(ht, buildKeys, probeKeys, keyID, i)
}
}
Expand All @@ -121,15 +136,21 @@ func _CHECK_COL_WITH_NULLS(
// {{define "checkColWithNulls"}}
if probeVec.MaybeHasNulls() {
if buildVec.MaybeHasNulls() {
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, true, true)
if ht.allowNullEquality {
// The allowNullEquality flag only matters if both vectors have nulls.
// This lets us avoid writing all 2^3 conditional branches.
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, true, true, true)
} else {
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, true, true, false)
}
} else {
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, true, false)
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, true, false, false)
}
} else {
if buildVec.MaybeHasNulls() {
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, false, true)
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, false, true, false)
} else {
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, false, false)
_CHECK_COL_BODY(ht, probeVec, buildVec, buildKeys, probeKeys, nToCheck, false, false, false)
}
}
// {{end}}
Expand All @@ -141,12 +162,19 @@ func _REHASH_BODY(
ht *hashTable,
buckets []uint64,
keys []interface{},
nulls *coldata.Nulls,
nKeys uint64,
_SEL_STRING string,
_HAS_NULLS bool,
) { // */}}
// {{define "rehashBody"}}
for i := uint64(0); i < nKeys; i++ {
ht.cancelChecker.check(ctx)
// {{ if .HasNulls }}
if nulls.NullAt(uint16(_SEL_IND)) {
continue
}
// {{ end }}
v := keys[_SEL_IND]
p := uintptr(buckets[i])
_ASSIGN_HASH(p, v)
Expand Down Expand Up @@ -270,11 +298,19 @@ func (ht *hashTable) rehash(
switch t {
// {{range $hashType := .HashTemplate}}
case _TYPES_T:
keys := col._TemplateType()
if sel != nil {
_REHASH_BODY(ctx, ht, buckets, keys, nKeys, "sel[i]")
keys, nulls := col._TemplateType(), col.Nulls()
if col.MaybeHasNulls() {
if sel != nil {
_REHASH_BODY(ctx, ht, buckets, keys, nulls, nKeys, "sel[i]", true)
} else {
_REHASH_BODY(ctx, ht, buckets, keys, nulls, nKeys, "i", true)
}
} else {
_REHASH_BODY(ctx, ht, buckets, keys, nKeys, "i")
if sel != nil {
_REHASH_BODY(ctx, ht, buckets, keys, nulls, nKeys, "sel[i]", false)
} else {
_REHASH_BODY(ctx, ht, buckets, keys, nulls, nKeys, "i", false)
}
}

// {{end}}
Expand All @@ -285,8 +321,9 @@ func (ht *hashTable) rehash(

// checkCol determines if the current key column in the groupID buckets matches
// the specified equality column key. If there is no match, then the key is added
// to differs. If the bucket has reached the end, the key is rejected. If any
// element in the key is null, then there is no match.
// to differs. If the bucket has reached the end, the key is rejected. If the
// hashTable disallows null equality, then if any element in the key is null,
// there is no match.
func (ht *hashTable) checkCol(t types.T, keyColIdx int, nToCheck uint16, sel []uint16) {
switch t {
// {{range $neType := .NETemplate}}
Expand Down
25 changes: 23 additions & 2 deletions pkg/sql/exec/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (
"reflect"
"sort"
"testing"
"testing/quick"

"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/exec/coldata"
"github.com/cockroachdb/cockroach/pkg/sql/exec/types"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
Expand All @@ -41,6 +43,9 @@ var orderedVerifier verifier = (*opTestOutput).Verify
// error if they aren't equal by set comparison (irrespective of order).
var unorderedVerifier verifier = (*opTestOutput).VerifyAnyOrder

// decimalType is the reflection type for apd.Decimal.
var decimalType = reflect.TypeOf(apd.Decimal{})

// runTests is a helper that automatically runs your tests with varied batch
// sizes and with and without a random selection vector.
// tups is the set of input tuples.
Expand Down Expand Up @@ -271,18 +276,34 @@ func (s *opTestInput) Next(context.Context) coldata.Batch {
s.batch.ColVec(i).Nulls().UnsetNulls()
}

rng := rand.New(rand.NewSource(123))

for i := range s.typs {
vec := s.batch.ColVec(i)
typ := reflect.TypeOf(vec.Col()).Elem()
// Automatically convert the Go values into exec.Type slice elements using
// reflection. This is slow, but acceptable for tests.
col := reflect.ValueOf(vec.Col())
for j := uint16(0); j < batchSize; j++ {
outputIdx := s.selection[j]
if tups[j][i] == nil {
// Set garbage data in the value to make sure NULL gets handled
// correctly.
vec.Nulls().SetNull(outputIdx)
if typ.AssignableTo(decimalType) {
d := apd.Decimal{}
_, err := d.SetFloat64(rng.Float64())
if err != nil {
panic(fmt.Sprintf("%v", err))
}
col.Index(int(outputIdx)).Set(reflect.ValueOf(d))
} else if val, ok := quick.Value(typ, rng); ok {
col.Index(int(outputIdx)).Set(val)
} else {
panic(fmt.Sprintf("could not generate a random value of type %s\n.", typ.Name()))
}
} else {
col.Index(int(outputIdx)).Set(
reflect.ValueOf(tups[j][i]).Convert(reflect.TypeOf(vec.Col()).Elem()))
col.Index(int(outputIdx)).Set(reflect.ValueOf(tups[j][i]).Convert(typ))
}
}
}
Expand Down