Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: prepare grouping function for roll up #42464

Merged
merged 24 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions DEPS.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3402,8 +3402,9 @@ def go_deps():
name = "com_github_pingcap_tipb",
build_file_proto_mode = "disable_global",
importpath = "github.com/pingcap/tipb",
sum = "h1:CeeMOq1aHPAhXrw4eYXtQRyWOFlbfqK1+3f9Iop4IfU=",
version = "v0.0.0-20230310043643-5362260ee6f7",
replace = "github.com/pingcap/tipb",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the replace = xxxx needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the replace = xxxx needed?

I will delete it before merge.

sum = "h1:6KmfIc9QO7ixhD0pFDtqrJZ5m/gdWKmKDzYInXBZA94=",
version = "v0.0.0-20230322022145-dc802b917d4e",
)
go_repository(
name = "com_github_pkg_browser",
Expand Down
2 changes: 1 addition & 1 deletion executor/showtest/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ func TestShowBuiltin(t *testing.T) {
res := tk.MustQuery("show builtins;")
require.NotNil(t, res)
rows := res.Rows()
const builtinFuncNum = 287
const builtinFuncNum = 288
require.Equal(t, builtinFuncNum, len(rows))
require.Equal(t, rows[0][0].(string), "abs")
require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek")
Expand Down
2 changes: 2 additions & 0 deletions expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ go_library(
"builtin_encryption.go",
"builtin_encryption_vec.go",
"builtin_func_param.go",
"builtin_grouping.go",
"builtin_ilike.go",
"builtin_ilike_vec.go",
"builtin_info.go",
Expand Down Expand Up @@ -142,6 +143,7 @@ go_test(
"builtin_control_vec_generated_test.go",
"builtin_encryption_test.go",
"builtin_encryption_vec_test.go",
"builtin_grouping_test.go",
"builtin_ilike_test.go",
"builtin_info_test.go",
"builtin_info_vec_test.go",
Expand Down
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ var funcs = map[string]functionClass{
ast.UUIDToBin: &uuidToBinFunctionClass{baseFunctionClass{ast.UUIDToBin, 1, 2}},
ast.BinToUUID: &binToUUIDFunctionClass{baseFunctionClass{ast.BinToUUID, 1, 2}},
ast.TiDBShard: &tidbShardFunctionClass{baseFunctionClass{ast.TiDBShard, 1, 1}},
ast.Grouping: &groupingFunctionClass{baseFunctionClass{ast.Grouping, 1, 1}},

ast.GetLock: &lockFunctionClass{baseFunctionClass{ast.GetLock, 2, 2}},
ast.ReleaseLock: &releaseLockFunctionClass{baseFunctionClass{ast.ReleaseLock, 1, 1}},
Expand Down
186 changes: 186 additions & 0 deletions expression/builtin_grouping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"github.com/gogo/protobuf/proto"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tipb/go-tipb"
)

// Compile-time check that *groupingFunctionClass satisfies functionClass.
var (
_ functionClass = &groupingFunctionClass{}
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
)

// Compile-time check that *builtinGroupingSig satisfies builtinFunc.
var _ builtinFunc = (*builtinGroupingSig)(nil)

// groupingFunctionClass is the functionClass of the grouping builtin
// function. It carries no state of its own beyond the embedded
// baseFunctionClass.
type groupingFunctionClass struct {
baseFunctionClass
}

// getFunction builds the builtinFunc for one call to the grouping function.
//
// The arguments are checked with verifyArgs, then a base function taking one
// int argument and returning an int is created and its result flen is set to
// 1. The returned builtinGroupingSig starts with version 0 and an empty
// grouping ID set, and carries the GroupingSig pb code.
func (c *groupingFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	err := c.verifyArgs(args)
	if err != nil {
		return nil, err
	}

	base, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt)
	if err != nil {
		return nil, err
	}
	base.tp.SetFlen(1)

	groupingSig := &builtinGroupingSig{
		baseBuiltinFunc: base,
		version:         0,
		groupingIDs:     make(map[int64]struct{}),
	}
	groupingSig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
	return groupingSig, nil
}

// builtinGroupingSig is the builtinFunc implementation of the grouping
// function. Besides the embedded baseBuiltinFunc it holds the metadata
// (version and grouping IDs) that decides how a grouping ID is interpreted.
type builtinGroupingSig struct {
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
baseBuiltinFunc

// TODO: these are two temporary fields for tests.
windtalker marked this conversation as resolved.
Show resolved Hide resolved
// version selects the implementation used by grouping (1, 2 or 3).
version uint32
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
// groupingIDs is the set of grouping IDs the argument is checked against.
groupingIDs map[int64]struct{}
}

// SetMetaVersion sets the metadata version, which selects the grouping
// implementation used during evaluation.
func (b *builtinGroupingSig) SetMetaVersion(version uint32) {
b.version = version
}

// SetMetaGroupingIDs sets the set of metadata grouping IDs. The map is stored
// as given, not copied, so later changes made by the caller are visible here.
func (b *builtinGroupingSig) SetMetaGroupingIDs(groupingIDs map[int64]struct{}) {
b.groupingIDs = groupingIDs
}

// getMetaVersion returns the metadata version currently set on the function.
func (b *builtinGroupingSig) getMetaVersion() uint32 {
return b.version
}

// metadata returns the metadata of grouping functions as a protobuf message.
// The message is currently returned without any field set; filling it in is
// still a TODO.
func (b *builtinGroupingSig) metadata() proto.Message {
args := &tipb.GroupingFunctionMetadata{
// TODO
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
}
return args
}

// Clone returns a new builtinGroupingSig whose base function is cloned from b
// and whose version equals b's. The groupingIDs map is assigned, not copied,
// so the clone and the original share the same map.
func (b *builtinGroupingSig) Clone() builtinFunc {
newSig := &builtinGroupingSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.version = b.version
newSig.groupingIDs = b.groupingIDs
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
return newSig
}

// getMetaGroupingIDs returns the set of metadata grouping IDs. The returned
// map is the one stored on the function, not a copy.
func (b *builtinGroupingSig) getMetaGroupingIDs() map[int64]struct{} {
return b.groupingIDs
}

// getMetaGroupingID returns a single grouping ID taken from the metadata set.
//
// It is meant for versions 1 and 2, where checkMetadata requires the set to
// hold exactly one ID. With an empty set it returns 0; with several IDs the
// result is whichever key the map iteration visits last.
func (b *builtinGroupingSig) getMetaGroupingID() int64 {
	var picked int64
	for id := range b.getMetaGroupingIDs() {
		picked = id
	}
	return picked
}

// checkMetadata validates the metadata before evaluation. It returns an error
// when the version is outside 1..3, or when the version is 1 or 2 and the
// grouping ID set does not contain exactly one ID; otherwise it returns nil.
func (b *builtinGroupingSig) checkMetadata() error {
version := b.getMetaVersion()
grouping_ids := b.getMetaGroupingIDs()
if version < 1 || version > 3 {
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
return errors.Errorf("Version of meta data in grouping function is invalid. input version: %d", version)
} else if (version == 1 || version == 2) && len(grouping_ids) != 1 {
// Versions 1 and 2 compare against a single metadata grouping ID.
return errors.Errorf("Invalid number of groupingID. version: %d, number of groupingID: %d", version, len(b.groupingIDs))
}
return nil
}

// groupingImplV1 is the version 1 implementation: both IDs are treated as bit
// masks, and the result is 1 when they share at least one set bit, 0
// otherwise.
//
// The overlap is tested with != 0 rather than > 0. The masks are carried in
// int64 values, so an overlap on bit 63 makes the AND negative, and a signed
// "> 0" test would wrongly report no overlap.
func (b *builtinGroupingSig) groupingImplV1(groupingID int64, metaGroupingID int64) int64 {
	if groupingID&metaGroupingID != 0 {
		return 1
	}
	return 0
}

// groupingImplV2 is the version 2 implementation: the result is 1 when
// groupingID is strictly greater than metaGroupingID (signed comparison) and
// 0 otherwise.
func (b *builtinGroupingSig) groupingImplV2(groupingID int64, metaGroupingID int64) int64 {
	if groupingID <= metaGroupingID {
		return 0
	}
	return 1
}

// groupingImplV3 is the version 3 implementation: the result is 0 when
// groupingID is a member of the metadata grouping ID set and 1 when it is not.
func (b *builtinGroupingSig) groupingImplV3(groupingID int64) int64 {
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
grouping_ids := b.getMetaGroupingIDs()
_, ok := grouping_ids[groupingID]
if ok {
return 0
}
return 1
}

// grouping computes the grouping result for one grouping ID by dispatching on
// the metadata version: versions 1 and 2 compare against the single metadata
// grouping ID, version 3 tests membership in the ID set. Any other version
// yields 0.
func (b *builtinGroupingSig) grouping(groupingID int64) int64 {
	switch version := b.version; version {
	case 3:
		return b.groupingImplV3(groupingID)
	case 1, 2:
		metaID := b.getMetaGroupingID()
		if version == 1 {
			return b.groupingImplV1(groupingID, metaID)
		}
		return b.groupingImplV2(groupingID, metaID)
	default:
		return 0
	}
}

// evalInt evaluates the grouping function for a single row.
// It first validates the metadata and returns that error if it is invalid.
// It then evaluates the first argument as an int; a NULL argument or an
// evaluation error is returned as is with a zero value. Otherwise it returns
// the result of grouping for the argument, which is never NULL.
func (b *builtinGroupingSig) evalInt(row chunk.Row) (int64, bool, error) {
err := b.checkMetadata()
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return 0, false, err
}

groupingID, isNull, err := b.args[0].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

return b.grouping(groupingID), false, nil
}

// groupingVec resizes result to rowNum int64 values and fills row i with the
// grouping result of groupingIds' i-th value. Every row gets a computed value;
// whether a row of groupingIds is NULL is not inspected here.
func (b *builtinGroupingSig) groupingVec(groupingIds *chunk.Column, rowNum int, result *chunk.Column) {
result.ResizeInt64(rowNum, false)
resContainer := result.Int64s()
for i := 0; i < rowNum; i++ {
resContainer[i] = b.grouping(groupingIds.GetInt64(i))
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
}
}

// vecEvalInt evaluates the grouping function for every row of input and
// writes the results into result.
//
// It mirrors evalInt in two ways. The metadata is validated up front, so an
// invalid version or grouping ID set is reported as an error instead of the
// vectorized path silently computing a result. Rows whose argument is NULL
// are marked NULL in result.
func (b *builtinGroupingSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
	if err := b.checkMetadata(); err != nil {
		return err
	}

	rowNum := input.NumRows()

	bufVal, err := b.bufAllocator.get()
	if err != nil {
		return err
	}
	defer b.bufAllocator.put(bufVal)
	if err = b.args[0].VecEvalInt(b.ctx, input, bufVal); err != nil {
		return err
	}

	b.groupingVec(bufVal, rowNum, result)
	// groupingVec fills a value for every row; carry the argument's NULLs
	// over so NULL inputs give NULL outputs, as in evalInt.
	result.MergeNulls(bufVal)

	return nil
}
89 changes: 89 additions & 0 deletions expression/builtin_grouping_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"fmt"
"testing"

"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/testkit/testutil"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
)

// createGroupingFunc builds a builtinGroupingSig for tests without going
// through the function class. The signature takes one int argument, returns
// an int with flen 1, starts with version 0 and an empty grouping ID set, and
// carries the GroupingSig pb code.
func createGroupingFunc(ctx sessionctx.Context, args []Expression) (*builtinGroupingSig, error) {
	base, err := newBaseBuiltinFuncWithTp(ctx, ast.Grouping, args, types.ETInt, types.ETInt)
	if err != nil {
		return nil, err
	}
	base.tp.SetFlen(1)

	groupingSig := &builtinGroupingSig{
		baseBuiltinFunc: base,
		version:         0,
		groupingIDs:     make(map[int64]struct{}),
	}
	groupingSig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
	return groupingSig, nil
}

// TestGrouping checks the row-based evaluation of the grouping function for
// each metadata version against a table of (grouping ID, version, metadata
// grouping IDs, expected result) cases.
func TestGrouping(t *testing.T) {
	ctx := createContext(t)
	tests := []struct {
		groupingID   uint64
		version      uint32
		groupingIDs  map[int64]struct{}
		expectResult uint64
	}{
		// version 1
		{1, 1, map[int64]struct{}{1: {}}, 1},
		{1, 1, map[int64]struct{}{3: {}}, 1},
		{1, 1, map[int64]struct{}{6: {}}, 0},
		{2, 1, map[int64]struct{}{1: {}}, 0},
		{2, 1, map[int64]struct{}{3: {}}, 1},
		{2, 1, map[int64]struct{}{6: {}}, 1},
		{4, 1, map[int64]struct{}{2: {}}, 0},
		{4, 1, map[int64]struct{}{4: {}}, 1},
		{4, 1, map[int64]struct{}{6: {}}, 1},

		// version 2
		{0, 2, map[int64]struct{}{0: {}}, 0},
		{0, 2, map[int64]struct{}{2: {}}, 0},
		{2, 2, map[int64]struct{}{0: {}}, 1},
		{2, 2, map[int64]struct{}{1: {}}, 1},
		{2, 2, map[int64]struct{}{2: {}}, 0},
		{2, 2, map[int64]struct{}{3: {}}, 0},

		// version 3
		{1, 3, map[int64]struct{}{1: {}, 2: {}}, 0},
		{1, 3, map[int64]struct{}{2: {}}, 1},
		{2, 3, map[int64]struct{}{1: {}, 3: {}}, 1},
		{2, 3, map[int64]struct{}{2: {}, 3: {}}, 0},
	}

	for _, testCase := range tests {
		comment := fmt.Sprintf(`for grouping = "%d", version = "%d", groupingIDs = "%v", expectRes = "%d""`, testCase.groupingID, testCase.version, testCase.groupingIDs, testCase.expectResult)
		args := datumsToConstants(types.MakeDatums(testCase.groupingID))

		groupingFunc, err := createGroupingFunc(ctx, args)
		// Check the error before touching groupingFunc: it is nil when
		// construction fails, and calling a setter on it would panic instead
		// of reporting the real error.
		require.NoError(t, err, comment)
		groupingFunc.SetMetaVersion(testCase.version)
		groupingFunc.SetMetaGroupingIDs(testCase.groupingIDs)

		actualResult, err := evalBuiltinFunc(groupingFunc, chunk.Row{})
		require.NoError(t, err, comment)
		testutil.DatumEqual(t, types.NewDatum(testCase.expectResult), actualResult, comment)
	}
}
7 changes: 7 additions & 0 deletions expression/builtin_regexp_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ func fillNullStringIntoResult(result *chunk.Column, num int) {
}
}

// fillNullBytesIntoResult calls ReserveBytes(num) on result and then appends
// num NULL entries to it.
func fillNullBytesIntoResult(result *chunk.Column, num int) {
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
result.ReserveBytes(num)
for i := 0; i < num; i++ {
result.AppendNull()
}
}

// check if this is a valid position argument when position is out of range
func checkOutRangePos(strLen int, pos int64) bool {
// false condition:
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,5 @@ replace (
github.com/pingcap/tidb/parser => ./parser
go.opencensus.io => go.opencensus.io v0.23.1-0.20220331163232-052120675fac
)

replace github.com/pingcap/tipb => github.com/pingcap/tipb v0.0.0-20230322022145-dc802b917d4e
1 change: 1 addition & 0 deletions parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ const (
TiDBShard = "tidb_shard"
GetLock = "get_lock"
ReleaseLock = "release_lock"
Grouping = "grouping"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not add this, since we do not actually support it yet.

Copy link
Contributor Author

@xzhangxian1008 xzhangxian1008 Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not add this, since we do not actually support it yet.

The unit test needs this constant, so we may have to keep it.


// encryption and compression functions
AesDecrypt = "aes_decrypt"
Expand Down