diff --git a/DEPS.bzl b/DEPS.bzl index 2d36d1c6422ae..2b87fcdc3aa7e 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3410,8 +3410,8 @@ def go_deps(): name = "com_github_pingcap_tipb", build_file_proto_mode = "disable_global", importpath = "github.com/pingcap/tipb", - sum = "h1:CeeMOq1aHPAhXrw4eYXtQRyWOFlbfqK1+3f9Iop4IfU=", - version = "v0.0.0-20230310043643-5362260ee6f7", + sum = "h1:ltplM2dLXcIAwlleA5v4gke6m6ZeHpvUA3qYX9dCC18=", + version = "v0.0.0-20230427024529-aed92caf20b9", ) go_repository( name = "com_github_pkg_browser", diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 5a38ebdd1a982..25526be4c1736 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "builtin_encryption.go", "builtin_encryption_vec.go", "builtin_func_param.go", + "builtin_grouping.go", "builtin_ilike.go", "builtin_ilike_vec.go", "builtin_info.go", @@ -142,6 +143,7 @@ go_test( "builtin_control_vec_generated_test.go", "builtin_encryption_test.go", "builtin_encryption_vec_test.go", + "builtin_grouping_test.go", "builtin_ilike_test.go", "builtin_info_test.go", "builtin_info_vec_test.go", diff --git a/expression/builtin_grouping.go b/expression/builtin_grouping.go new file mode 100644 index 0000000000000..60b1fad83f539 --- /dev/null +++ b/expression/builtin_grouping.go @@ -0,0 +1,228 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package expression
+
+import (
+	"github.com/gogo/protobuf/proto"
+	"github.com/pingcap/errors"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/chunk"
+	"github.com/pingcap/tipb/go-tipb"
+)
+
+// Compile-time assertions that the grouping types implement the expression interfaces.
+var (
+	_ functionClass = &groupingImplFunctionClass{}
+	_ builtinFunc   = &builtinGroupingImplSig{}
+)
+
+// groupingImplFunctionClass is the function class whose getFunction builds
+// builtinGroupingImplSig.
+type groupingImplFunctionClass struct {
+	baseFunctionClass
+}
+
+// getFunction verifies args and builds a builtinGroupingImplSig declared with
+// a single ETInt argument and an ETInt result whose flen is set to 1. The
+// returned signature carries no metadata yet: zero mode, empty grouping marks
+// and isMetaInited == false.
+func (c *groupingImplFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
+	if verifyErr := c.verifyArgs(args); verifyErr != nil {
+		return nil, verifyErr
+	}
+	base, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt)
+	if err != nil {
+		return nil, err
+	}
+	base.tp.SetFlen(1)
+	sig := &builtinGroupingImplSig{
+		baseBuiltinFunc: base,
+		groupingMarks:   map[int64]struct{}{},
+	}
+	sig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
+	return sig, nil
+}
+
+// builtinGroupingImplSig is the signature that actually executes the grouping
+// function called by users. Users designate a column as the parameter of the
+// grouping function, and tidb rewrites that parameter into the meta info. Tidb
+// then generates grouping ids, which are indicators to be calculated against
+// the meta info; these grouping ids are what builtinGroupingImplSig receives.
+type builtinGroupingImplSig struct {
+	baseBuiltinFunc
+
+	// TODO these are temporary fields for tests
+
+	// mode selects how grouping compares an incoming grouping ID with groupingMarks.
+	mode tipb.GroupingMode
+	// groupingMarks holds the meta grouping IDs. checkMetadata requires exactly
+	// one entry for ModeBitAnd and ModeNumericCmp; ModeNumericSet may hold any number.
+	groupingMarks map[int64]struct{}
+	// isMetaInited is raised by SetMetadata once mode and groupingMarks pass
+	// checkMetadata; evaluation is refused while it is false.
+	isMetaInited bool
+}
+
+// SetMetadata stores the grouping mode and the meta grouping marks and then
+// validates them with checkMetadata. When validation fails the error is
+// returned and the signature is left marked as uninitialized (the rejected
+// mode and marks stay stored).
+func (b *builtinGroupingImplSig) SetMetadata(mode tipb.GroupingMode, groupingMarks map[int64]struct{}) error {
+	b.setGroupingMode(mode)
+	b.setMetaGroupingMarks(groupingMarks)
+	// checkMetadata rejects an uninitialized signature, so the flag is raised
+	// before validating and lowered again on failure.
+	b.isMetaInited = true
+	if err := b.checkMetadata(); err != nil {
+		b.isMetaInited = false
+		return err
+	}
+	return nil
+}
+
+// setGroupingMode stores mode without validating it.
+func (b *builtinGroupingImplSig) setGroupingMode(mode tipb.GroupingMode) {
+	b.mode = mode
+}
+
+// setMetaGroupingMarks stores groupingMarks without copying or validating it.
+func (b *builtinGroupingImplSig) setMetaGroupingMarks(groupingMarks map[int64]struct{}) {
+	b.groupingMarks = groupingMarks
+}
+
+// getGroupingMode returns the stored grouping mode.
+func (b *builtinGroupingImplSig) getGroupingMode() tipb.GroupingMode {
+	return b.mode
+}
+
+// metadata returns the metadata of grouping functions as a
+// tipb.GroupingFunctionMetadata. An empty message is returned when the
+// metadata has not been initialized or is invalid.
+func (b *builtinGroupingImplSig) metadata() proto.Message {
+	if err := b.checkMetadata(); err != nil {
+		return &tipb.GroupingFunctionMetadata{}
+	}
+	// Mode is a pointer field and is nil in a freshly built message, so it has
+	// to be pointed at a value; assigning through it would panic.
+	mode := b.mode
+	args := &tipb.GroupingFunctionMetadata{Mode: &mode}
+	for groupingMark := range b.groupingMarks {
+		args.GroupingMarks = append(args.GroupingMarks, uint64(groupingMark))
+	}
+	return args
+}
+
+// Clone returns a copy of the signature together with its grouping metadata.
+// The groupingMarks map is shared with the receiver, not deep-copied.
+func (b *builtinGroupingImplSig) Clone() builtinFunc {
+	newSig := &builtinGroupingImplSig{}
+	newSig.cloneFrom(&b.baseBuiltinFunc)
+	newSig.mode = b.mode
+	newSig.groupingMarks = b.groupingMarks
+	// Without this flag evalInt/vecEvalInt on the clone would always fail as
+	// uninitialized even though mode and groupingMarks were copied.
+	newSig.isMetaInited = b.isMetaInited
+	return newSig
+}
+
+// getMetaGroupingMarks returns the stored meta grouping marks.
+func (b *builtinGroupingImplSig) getMetaGroupingMarks() map[int64]struct{} {
+	return b.groupingMarks
+}
+
+// getMetaGroupingID returns the meta grouping ID used by ModeBitAnd and
+// ModeNumericCmp, for which checkMetadata requires groupingMarks to hold
+// exactly one key. It returns 0 for an empty map and an arbitrary key when
+// the map holds several.
+func (b *builtinGroupingImplSig) getMetaGroupingID() int64 {
+	for key := range b.getMetaGroupingMarks() {
+		return key
+	}
+	return 0
+}
+
+// checkMetadata reports an error when the metadata has not been initialized,
+// when the mode is not one of ModeBitAnd, ModeNumericCmp or ModeNumericSet, or
+// when ModeBitAnd/ModeNumericCmp is not paired with exactly one grouping mark.
+func (b *builtinGroupingImplSig) checkMetadata() error {
+	if !b.isMetaInited {
+		return errors.Errorf("Meta data hasn't been initialized")
+	}
+	mode := b.getGroupingMode()
+	groupingIDs := b.getMetaGroupingMarks()
+	switch mode {
+	case tipb.GroupingMode_ModeBitAnd, tipb.GroupingMode_ModeNumericCmp:
+		if len(groupingIDs) != 1 {
+			return errors.Errorf("Invalid number of groupingID. mode: %d, number of groupingID: %d", mode, len(groupingIDs))
+		}
+	case tipb.GroupingMode_ModeNumericSet:
+		// Any number of grouping marks is accepted in this mode.
+	default:
+		return errors.Errorf("Mode of meta data in grouping function is invalid. input mode: %d", mode)
+	}
+	return nil
+}
+
+// groupingImplBitAnd returns 1 when groupingID and metaGroupingID share at
+// least one set bit, and 0 otherwise.
+func (b *builtinGroupingImplSig) groupingImplBitAnd(groupingID int64, metaGroupingID int64) int64 {
+	// Compare against zero rather than "> 0" so that a match on the sign bit
+	// of the int64 still counts as a common bit.
+	if groupingID&metaGroupingID != 0 {
+		return 1
+	}
+	return 0
+}
+
+// groupingImplNumericCmp returns 1 when groupingID is greater than
+// metaGroupingID, and 0 otherwise.
+func (b *builtinGroupingImplSig) groupingImplNumericCmp(groupingID int64, metaGroupingID int64) int64 {
+	if groupingID > metaGroupingID {
+		return 1
+	}
+	return 0
+}
+
+// groupingImplNumericSet returns 0 when groupingID is one of the meta
+// grouping marks, and 1 otherwise.
+func (b *builtinGroupingImplSig) groupingImplNumericSet(groupingID int64) int64 {
+	if _, ok := b.getMetaGroupingMarks()[groupingID]; ok {
+		return 0
+	}
+	return 1
+}
+
+// grouping dispatches groupingID to the implementation selected by b.mode and
+// returns its 0/1 result. An unrecognized mode yields 0.
+func (b *builtinGroupingImplSig) grouping(groupingID int64) int64 {
+	switch b.mode {
+	case tipb.GroupingMode_ModeBitAnd:
+		return b.groupingImplBitAnd(groupingID, b.getMetaGroupingID())
+	case tipb.GroupingMode_ModeNumericCmp:
+		return b.groupingImplNumericCmp(groupingID, b.getMetaGroupingID())
+	case tipb.GroupingMode_ModeNumericSet:
+		return b.groupingImplNumericSet(groupingID)
+	}
+	return 0
+}
+
+// evalInt evals a builtinGroupingImplSig.
+func (b *builtinGroupingImplSig) evalInt(row chunk.Row) (int64, bool, error) {
+	// Evaluation is refused until SetMetadata has succeeded.
+	if !b.isMetaInited {
+		return 0, false, errors.Errorf("Meta data is not initialized")
+	}
+
+	// A NULL grouping ID or an evaluation error is passed straight through.
+	groupingID, isNull, err := b.args[0].EvalInt(b.ctx, row)
+	if isNull || err != nil {
+		return 0, isNull, err
+	}
+
+	return b.grouping(groupingID), false, nil
+}
+
+// groupingVec resizes result to rowNum int64 values and fills it with the
+// grouping result of the first rowNum values of groupingIds, using the
+// implementation selected by b.mode. It does not look at the NULL flags of
+// groupingIds; for an unrecognized mode no value is written.
+func (b *builtinGroupingImplSig) groupingVec(groupingIds *chunk.Column, rowNum int, result *chunk.Column) {
+	result.ResizeInt64(rowNum, false)
+	resContainer := result.Int64s()
+	switch b.mode {
+	case tipb.GroupingMode_ModeBitAnd:
+		// The meta grouping ID is loop-invariant, so fetch it once.
+		metaGroupingID := b.getMetaGroupingID()
+		for i := 0; i < rowNum; i++ {
+			resContainer[i] = b.groupingImplBitAnd(groupingIds.GetInt64(i), metaGroupingID)
+		}
+	case tipb.GroupingMode_ModeNumericCmp:
+		metaGroupingID := b.getMetaGroupingID()
+		for i := 0; i < rowNum; i++ {
+			resContainer[i] = b.groupingImplNumericCmp(groupingIds.GetInt64(i), metaGroupingID)
+		}
+	case tipb.GroupingMode_ModeNumericSet:
+		for i := 0; i < rowNum; i++ {
+			resContainer[i] = b.groupingImplNumericSet(groupingIds.GetInt64(i))
+		}
+	}
+}
+
+// vecEvalInt is the vectorized counterpart of evalInt: it evaluates the
+// grouping-ID argument for every row of input into a scratch column and
+// writes the per-row grouping result into result. It returns an error when
+// the metadata has not been initialized, when no scratch column can be
+// obtained, or when evaluating the argument fails.
+func (b *builtinGroupingImplSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
+	if !b.isMetaInited {
+		return errors.Errorf("Meta data is not initialized")
+	}
+	rowNum := input.NumRows()
+
+	bufVal, err := b.bufAllocator.get()
+	if err != nil {
+		return err
+	}
+	defer b.bufAllocator.put(bufVal)
+	if err = b.args[0].VecEvalInt(b.ctx, input, bufVal); err != nil {
+		return err
+	}
+
+	b.groupingVec(bufVal, rowNum, result)
+	// groupingVec resizes result with the NULL flag cleared for every row;
+	// carry over the NULLs of the argument so the result agrees with evalInt,
+	// which reports NULL for a NULL grouping ID.
+	result.MergeNulls(bufVal)
+
+	return nil
+}
diff --git a/expression/builtin_grouping_test.go b/expression/builtin_grouping_test.go
new file mode 100644
index 0000000000000..96ff46bb2c166
--- /dev/null
+++ b/expression/builtin_grouping_test.go
@@ -0,0 +1,105 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expression
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/testkit/testutil"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/chunk"
+	"github.com/pingcap/tipb/go-tipb"
+	"github.com/stretchr/testify/require"
+)
+
+// constructFieldType is a temporary helper and should not exist for long.
+// After the completion of the full rollup feature, the construction of 'bf' in
+// 'createGroupingFunc' should be implemented by 'newBaseBuiltinFuncWithTp'.
+func constructFieldType() types.FieldType {
+	var tp types.FieldType
+	tp.Init(8)      // NOTE(review): presumably mysql.TypeLonglong — confirm
+	tp.AddFlag(128) // NOTE(review): presumably mysql.BinaryFlag — confirm
+	tp.SetFlen(20)
+	tp.SetCharset("binary")
+	tp.SetCollate("binary")
+	return tp
+}
+
+func createGroupingFunc(ctx sessionctx.Context, args []Expression) (*builtinGroupingImplSig, error) {
+	// TODO We should use the commented codes after the completion of rollup
+	// argTp := []types.EvalType{types.ETInt}
+	tp := constructFieldType()
+	// bf, err := newBaseBuiltinFuncWithTp(ctx, groupingImplName, args, types.ETInt, argTp...)
+	bf, err := newBaseBuiltinFuncWithFieldType(ctx, &tp, args)
+	if err != nil {
+		return nil, err
+	}
+	bf.tp.SetFlen(1)
+	sig := &builtinGroupingImplSig{bf, 0, map[int64]struct{}{}, false}
+	sig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
+	return sig, nil
+}
+
+func TestGrouping(t *testing.T) {
+	ctx := createContext(t)
+	tests := []struct {
+		groupingID   uint64
+		mode         tipb.GroupingMode
+		groupingIDs  map[int64]struct{}
+		expectResult uint64
+	}{
+		// GroupingMode_ModeBitAnd
+		{1, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{1: {}}, 1},
+		{1, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{3: {}}, 1},
+		{1, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{6: {}}, 0},
+		{2, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{1: {}}, 0},
+		{2, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{3: {}}, 1},
+		{2, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{6: {}}, 1},
+		{4, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{2: {}}, 0},
+		{4, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{4: {}}, 1},
+		{4, tipb.GroupingMode_ModeBitAnd, map[int64]struct{}{6: {}}, 1},
+
+		// GroupingMode_ModeNumericCmp
+		{0, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{0: {}}, 0},
+		{0, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{2: {}}, 0},
+		{2, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{0: {}}, 1},
+		{2, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{1: {}}, 1},
+		{2, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{2: {}}, 0},
+		{2, tipb.GroupingMode_ModeNumericCmp, map[int64]struct{}{3: {}}, 0},
+
+		// GroupingMode_ModeNumericSet
+		{1, tipb.GroupingMode_ModeNumericSet, map[int64]struct{}{1: {}, 2: {}}, 0},
+		{1, tipb.GroupingMode_ModeNumericSet, map[int64]struct{}{2: {}}, 1},
+		{2, tipb.GroupingMode_ModeNumericSet, map[int64]struct{}{1: {}, 3: {}}, 1},
+		{2, tipb.GroupingMode_ModeNumericSet, map[int64]struct{}{2: {}, 3: {}}, 0},
+	}
+
+	for _, testCase := range tests {
+		comment := fmt.Sprintf(`for grouping = "%d", mode = "%d", groupingIDs = "%v", expectRes = "%d"`, testCase.groupingID, testCase.mode, testCase.groupingIDs, testCase.expectResult)
+		args := datumsToConstants(types.MakeDatums(testCase.groupingID))
+
+		groupingFunc, err := createGroupingFunc(ctx, args)
+		require.NoError(t, err, comment)
+
+		err = groupingFunc.SetMetadata(testCase.mode, testCase.groupingIDs)
+		require.NoError(t, err, comment)
+
+		actualResult, err := evalBuiltinFunc(groupingFunc, chunk.Row{})
+		require.NoError(t, err, comment)
+		testutil.DatumEqual(t, types.NewDatum(testCase.expectResult), actualResult, comment)
+	}
+}
diff --git a/go.mod b/go.mod
index d352c24259266..ed2d42af5a1e7 100644
--- a/go.mod
+++ b/go.mod
@@ -76,7 +76,7 @@ require (
 	github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22
 	github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21
 	github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e
-	github.com/pingcap/tipb v0.0.0-20230310043643-5362260ee6f7
+	github.com/pingcap/tipb v0.0.0-20230427024529-aed92caf20b9
 	github.com/pkg/errors v0.9.1
 	github.com/prometheus/client_golang v1.15.0
 	github.com/prometheus/client_model v0.3.0
diff --git a/go.sum b/go.sum
index 18d8a079949e8..0d8bcb2887b44 100644
--- a/go.sum
+++ b/go.sum
@@ -783,8 +783,8 @@ github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22 h1:2SOzvGvE8beiC1Y4g
 github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4=
 github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I=
 github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM=
-github.com/pingcap/tipb v0.0.0-20230310043643-5362260ee6f7 h1:CeeMOq1aHPAhXrw4eYXtQRyWOFlbfqK1+3f9Iop4IfU=
-github.com/pingcap/tipb v0.0.0-20230310043643-5362260ee6f7/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
+github.com/pingcap/tipb v0.0.0-20230427024529-aed92caf20b9 h1:ltplM2dLXcIAwlleA5v4gke6m6ZeHpvUA3qYX9dCC18=
+github.com/pingcap/tipb v0.0.0-20230427024529-aed92caf20b9/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98=
 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=