diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index e348cfb7a41..7fc342d4513 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -395,9 +395,9 @@ func (cached *OrderedAggregate) CachedSize(alloc bool) int64 { size += elem.CachedSize(false) } } - // field Keys []int + // field GroupByKeys []vitess.io/vitess/go/vt/vtgate/engine.GroupbyParams { - size += int64(cap(cached.Keys)) * int64(8) + size += int64(cap(cached.GroupByKeys)) * int64(24) } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Input.(cachedObject); ok { diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index ca59f16a981..c24968dfb49 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -45,9 +45,9 @@ type OrderedAggregate struct { // aggregation function: function opcode and input column number. Aggregates []AggregateParams - // Keys specifies the input values that must be used for + // GroupByKeys specifies the input values that must be used for // the aggregation key. - Keys []int + GroupByKeys []GroupbyParams // TruncateColumnCount specifies the number of columns to return // in the final result. Rest of the columns are truncated @@ -58,6 +58,17 @@ type OrderedAggregate struct { Input Primitive } +// GroupbyParams specify the grouping key to be used. +type GroupbyParams struct { + Col int + WeightStringCol int + KeyCol int +} + +func (gbp GroupbyParams) String() string { + return strconv.Itoa(gbp.KeyCol) +} + // AggregateParams specify the parameters for each aggregation. // It contains the opcode and input column number. type AggregateParams struct { @@ -201,7 +212,7 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp current, curDistinct = oa.convertRow(row) } - if len(result.Rows) == 0 && len(oa.Keys) == 0 { + if len(result.Rows) == 0 && len(oa.GroupByKeys) == 0 { // When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the // different aggregation functions row, err := oa.createEmptyRow() @@ -350,10 +361,18 @@ func (oa *OrderedAggregate) NeedsTransaction() bool { } func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) { - for _, key := range oa.Keys { - cmp, err := evalengine.NullsafeCompare(row1[key], row2[key]) + for _, key := range oa.GroupByKeys { + cmp, err := evalengine.NullsafeCompare(row1[key.KeyCol], row2[key.KeyCol]) if err != nil { - return false, err + _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) + if !(isComparisonErr && key.WeightStringCol != -1) { + return false, err + } + key.KeyCol = key.WeightStringCol + cmp, err = evalengine.NullsafeCompare(row1[key.WeightStringCol], row2[key.WeightStringCol]) + if err != nil { + return false, err + } } if cmp != 0 { return false, nil @@ -450,13 +469,13 @@ func aggregateParamsToString(in interface{}) string { return in.(AggregateParams).String() } -func intToString(i interface{}) string { - return strconv.Itoa(i.(int)) +func groupByParamsToString(i interface{}) string { + return i.(GroupbyParams).String() } func (oa *OrderedAggregate) description() PrimitiveDescription { aggregates := GenericJoin(oa.Aggregates, aggregateParamsToString) - groupBy := GenericJoin(oa.Keys, intToString) + groupBy := GenericJoin(oa.GroupByKeys, groupByParamsToString) other := map[string]interface{}{ "Aggregates": aggregates, "GroupBy": groupBy, diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 045291e0e43..45290eceeeb 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -53,8 +53,8 @@ func TestOrderedAggregateExecute(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } result, err := oa.Execute(nil, nil, false) @@ -90,7 +90,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{2}, + GroupByKeys: []GroupbyParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -132,8 +132,8 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } var results []*sqltypes.Result @@ -175,7 +175,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{2}, + GroupByKeys: []GroupbyParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -316,8 +316,8 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } result, err := oa.Execute(nil, nil, false) @@ -392,8 +392,8 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } var results []*sqltypes.Result @@ -480,8 +480,8 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { Opcode: AggregateSum, Col: 2, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } result, err := oa.Execute(nil, nil, false) @@ -525,8 +525,8 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { Col: 1, Alias: "sum(distinct col2)", }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } result, err := oa.Execute(nil, nil, false) @@ -560,8 +560,8 @@ func TestOrderedAggregateKeysFail(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } want := "types are not comparable: VARCHAR vs VARCHAR" @@ -593,8 +593,8 @@ func TestOrderedAggregateMergeFail(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - Keys: []int{0}, - Input: fp, + GroupByKeys: []GroupbyParams{{KeyCol: 0}}, + Input: fp, } result := &sqltypes.Result{ @@ -721,8 +721,8 @@ func TestNoInputAndNoGroupingKeys(outer *testing.T) { Col: 0, Alias: test.name, }}, - Keys: []int{}, - Input: fp, + GroupByKeys: []GroupbyParams{}, + Input: fp, } result, err := oa.Execute(nil, nil, false) diff --git a/go/vt/vtgate/planbuilder/grouping.go b/go/vt/vtgate/planbuilder/grouping.go index fec8ad38411..d81f3e52d25 100644 --- a/go/vt/vtgate/planbuilder/grouping.go +++ b/go/vt/vtgate/planbuilder/grouping.go @@ -20,6 +20,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" ) func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.GroupBy) (logicalPlan, error) { @@ -76,7 +77,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou default: return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: only simple references allowed") } - node.eaggr.Keys = append(node.eaggr.Keys, colNumber) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{Col: colNumber, KeyCol: colNumber, WeightStringCol: -1}) } // Append the distinct aggregate if any. if node.extraDistinct != nil { @@ -109,7 +110,7 @@ func planDistinct(input logicalPlan) (logicalPlan, error) { if rc.column.Origin() == node { return newDistinct(node), nil } - node.eaggr.Keys = append(node.eaggr.Keys, i) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{Col: i, KeyCol: i, WeightStringCol: -1}) } newInput, err := planDistinct(node.input) if err != nil { diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index e3a82a08f47..60e5eea65a4 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -331,14 +331,15 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr // compare those instead. This is because we currently don't have the // ability to mimic mysql's collation behavior. func (oa *orderedAggregate) Wireup(plan logicalPlan, jt *jointab) error { - for i, colNumber := range oa.eaggr.Keys { - rc := oa.resultColumns[colNumber] + for i, gbk := range oa.eaggr.GroupByKeys { + rc := oa.resultColumns[gbk.Col] if sqltypes.IsText(rc.column.typ) { if weightcolNumber, ok := oa.weightStrings[rc]; ok { - oa.eaggr.Keys[i] = weightcolNumber + oa.eaggr.GroupByKeys[i].WeightStringCol = weightcolNumber + oa.eaggr.GroupByKeys[i].KeyCol = weightcolNumber continue } - weightcolNumber, err := oa.input.SupplyWeightString(colNumber) + weightcolNumber, err := oa.input.SupplyWeightString(gbk.Col) if err != nil { _, isUnsupportedErr := err.(UnsupportedSupplyWeightString) if isUnsupportedErr { @@ -347,7 +348,8 @@ func (oa *orderedAggregate) Wireup(plan logicalPlan, jt *jointab) error { return err } oa.weightStrings[rc] = weightcolNumber - oa.eaggr.Keys[i] = weightcolNumber + oa.eaggr.GroupByKeys[i].WeightStringCol = weightcolNumber + oa.eaggr.GroupByKeys[i].KeyCol = weightcolNumber oa.eaggr.TruncateColumnCount = len(oa.resultColumns) } } diff --git a/go/vt/vtgate/planbuilder/ordering.go b/go/vt/vtgate/planbuilder/ordering.go index d3d53f4d720..1a80d00b8b0 100644 --- a/go/vt/vtgate/planbuilder/ordering.go +++ b/go/vt/vtgate/planbuilder/ordering.go @@ -76,7 +76,7 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered } // referenced tracks the keys referenced by the order by clause. - referenced := make([]bool, len(oa.eaggr.Keys)) + referenced := make([]bool, len(oa.eaggr.GroupByKeys)) postSort := false selOrderBy := make(sqlparser.OrderBy, 0, len(orderBy)) for _, order := range orderBy { @@ -103,8 +103,8 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered // Match orderByCol against the group by columns. found := false - for j, key := range oa.eaggr.Keys { - if oa.resultColumns[key].column != orderByCol { + for j, key := range oa.eaggr.GroupByKeys { + if oa.resultColumns[key.Col].column != orderByCol { continue } @@ -119,12 +119,12 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered } // Append any unreferenced keys at the end of the order by. - for i, key := range oa.eaggr.Keys { + for i, key := range oa.eaggr.GroupByKeys { if referenced[i] { continue } // Build a brand new reference for the key. - col, err := BuildColName(oa.input.ResultColumns(), key) + col, err := BuildColName(oa.input.ResultColumns(), key.Col) if err != nil { return nil, vterrors.Wrapf(err, "generating order by clause") } diff --git a/go/vt/vtgate/planbuilder/selectGen4.go b/go/vt/vtgate/planbuilder/selectGen4.go index 4edd6cf6d1c..6a519acf9c6 100644 --- a/go/vt/vtgate/planbuilder/selectGen4.go +++ b/go/vt/vtgate/planbuilder/selectGen4.go @@ -164,15 +164,11 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem sel.GroupBy = append(sel.GroupBy, groupExpr.Inner) return false, nil case *orderedAggregate: - offset, weightStringOffset, colAdded, err := funcName(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) + offset, weightStringOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) if err != nil { return false, err } - if weightStringOffset == -1 { - node.eaggr.Keys = append(node.eaggr.Keys, offset) - } else { - node.eaggr.Keys = append(node.eaggr.Keys, weightStringOffset) - } + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{KeyCol: offset, Col: offset, WeightStringCol: weightStringOffset}) colAddedRecursively, err := planGroupByGen4(groupExpr, node.input, semTable) if err != nil { return false, err @@ -228,7 +224,7 @@ func planOrderBy(qp *abstract.QueryProjection, orderExprs []abstract.OrderBy, pl func planOrderByForRoute(orderExprs []abstract.OrderBy, plan *route, semTable *semantics.SemTable) (logicalPlan, bool, error) { origColCount := plan.Select.GetColumnCount() for _, order := range orderExprs { - offset, weightStringOffset, _, err := funcName(order.Inner.Expr, order.WeightStrExpr, plan, semTable) + offset, weightStringOffset, _, err := wrapAndPushExpr(order.Inner.Expr, order.WeightStrExpr, plan, semTable) if err != nil { return nil, false, err } @@ -243,7 +239,7 @@ func planOrderByForRoute(orderExprs []abstract.OrderBy, plan *route, semTable *s return plan, origColCount != plan.Select.GetColumnCount(), nil } -func funcName(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, int, bool, error) { +func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, int, bool, error) { offset, added, err := pushProjection(&sqlparser.AliasedExpr{Expr: expr}, plan, semTable, true, true) if err != nil { return 0, 0, false, err @@ -314,7 +310,7 @@ func planOrderByForJoin(qp *abstract.QueryProjection, orderExprs []abstract.Orde var colAdded bool for _, order := range orderExprs { - offset, weightStringOffset, added, err := funcName(order.Inner.Expr, order.WeightStrExpr, plan, semTable) + offset, weightStringOffset, added, err := wrapAndPushExpr(order.Inner.Expr, order.WeightStrExpr, plan, semTable) if err != nil { return nil, false, err } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 98dc89c3aeb..c4071eb6fee 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -529,15 +529,23 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer // the results, which engine.OrderedAggregate can do. if len(aggregates) != 0 { td.sourcePrimitive = &engine.OrderedAggregate{ - Aggregates: aggregates, - Keys: td.pkCols, - Input: td.sourcePrimitive, + Aggregates: aggregates, + GroupByKeys: pkColsToGroupByParams(td.pkCols), + Input: td.sourcePrimitive, } } return td, nil } +func pkColsToGroupByParams(pkCols []int) []engine.GroupbyParams { + var res []engine.GroupbyParams + for _, col := range pkCols { + res = append(res, engine.GroupbyParams{Col: col, KeyCol: col, WeightStringCol: -1}) + } + return res +} + // newMergeSorter creates an engine.MergeSort based on the shard streamers and pk columns. func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compareColInfo) *engine.MergeSort { prims := make([]engine.StreamExecutor, 0, len(participants)) diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 5d6e2b3ef46..c995d4b06e1 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -394,8 +394,8 @@ func TestVDiffPlanSuccess(t *testing.T) { Opcode: engine.AggregateSum, Col: 3, }}, - Keys: []int{0}, - Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), + GroupByKeys: []engine.GroupbyParams{{Col: 0, KeyCol: 0, WeightStringCol: -1}}, + Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), }, targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), },