Skip to content

Commit

Permalink
store column and weight string column information in group by params
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Jul 13, 2021
1 parent 10870f9 commit 7809688
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 57 deletions.
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 28 additions & 9 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ type OrderedAggregate struct {
// aggregation function: function opcode and input column number.
Aggregates []AggregateParams

// Keys specifies the input values that must be used for
// GroupByKeys specifies the input values that must be used for
// the aggregation key.
Keys []int
GroupByKeys []GroupbyParams

// TruncateColumnCount specifies the number of columns to return
// in the final result. Rest of the columns are truncated
Expand All @@ -58,6 +58,17 @@ type OrderedAggregate struct {
Input Primitive
}

// GroupbyParams describes a single grouping key of an OrderedAggregate.
type GroupbyParams struct {
	// Col is the offset of the grouping column itself.
	Col int
	// WeightStringCol is the offset of the weight-string column that
	// accompanies Col, or -1 when there is none.
	WeightStringCol int
	// KeyCol is the offset whose values are compared to decide whether
	// two rows fall into the same group.
	KeyCol int
}

// String returns the decimal representation of KeyCol.
func (gbp GroupbyParams) String() string {
	keyOffset := int64(gbp.KeyCol)
	return strconv.FormatInt(keyOffset, 10)
}

// AggregateParams specify the parameters for each aggregation.
// It contains the opcode and input column number.
type AggregateParams struct {
Expand Down Expand Up @@ -201,7 +212,7 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
current, curDistinct = oa.convertRow(row)
}

if len(result.Rows) == 0 && len(oa.Keys) == 0 {
if len(result.Rows) == 0 && len(oa.GroupByKeys) == 0 {
// When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the
// different aggregation functions
row, err := oa.createEmptyRow()
Expand Down Expand Up @@ -350,10 +361,18 @@ func (oa *OrderedAggregate) NeedsTransaction() bool {
}

func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) {
for _, key := range oa.Keys {
cmp, err := evalengine.NullsafeCompare(row1[key], row2[key])
for _, key := range oa.GroupByKeys {
cmp, err := evalengine.NullsafeCompare(row1[key.KeyCol], row2[key.KeyCol])
if err != nil {
return false, err
_, isComparisonErr := err.(evalengine.UnsupportedComparisonError)
if !(isComparisonErr && key.WeightStringCol != -1) {
return false, err
}
key.KeyCol = key.WeightStringCol
cmp, err = evalengine.NullsafeCompare(row1[key.WeightStringCol], row2[key.WeightStringCol])
if err != nil {
return false, err
}
}
if cmp != 0 {
return false, nil
Expand Down Expand Up @@ -450,13 +469,13 @@ func aggregateParamsToString(in interface{}) string {
return in.(AggregateParams).String()
}

func intToString(i interface{}) string {
return strconv.Itoa(i.(int))
// groupByParamsToString formats a GroupbyParams value held in an
// interface{}, giving it the func(interface{}) string shape passed to
// GenericJoin. It panics if i does not hold a GroupbyParams.
func groupByParamsToString(i interface{}) string {
	params := i.(GroupbyParams)
	return params.String()
}

func (oa *OrderedAggregate) description() PrimitiveDescription {
aggregates := GenericJoin(oa.Aggregates, aggregateParamsToString)
groupBy := GenericJoin(oa.Keys, intToString)
groupBy := GenericJoin(oa.GroupByKeys, groupByParamsToString)
other := map[string]interface{}{
"Aggregates": aggregates,
"GroupBy": groupBy,
Expand Down
40 changes: 20 additions & 20 deletions go/vt/vtgate/engine/ordered_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ func TestOrderedAggregateExecute(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{2},
GroupByKeys: []GroupbyParams{{KeyCol: 2}},
TruncateColumnCount: 2,
Input: fp,
}
Expand Down Expand Up @@ -132,8 +132,8 @@ func TestOrderedAggregateStreamExecute(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

var results []*sqltypes.Result
Expand Down Expand Up @@ -175,7 +175,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{2},
GroupByKeys: []GroupbyParams{{KeyCol: 2}},
TruncateColumnCount: 2,
Input: fp,
}
Expand Down Expand Up @@ -316,8 +316,8 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) {
Opcode: AggregateCount,
Col: 2,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
Expand Down Expand Up @@ -392,8 +392,8 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) {
Opcode: AggregateCount,
Col: 2,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

var results []*sqltypes.Result
Expand Down Expand Up @@ -480,8 +480,8 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) {
Opcode: AggregateSum,
Col: 2,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
Expand Down Expand Up @@ -525,8 +525,8 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) {
Col: 1,
Alias: "sum(distinct col2)",
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
Expand Down Expand Up @@ -560,8 +560,8 @@ func TestOrderedAggregateKeysFail(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

want := "types are not comparable: VARCHAR vs VARCHAR"
Expand Down Expand Up @@ -593,8 +593,8 @@ func TestOrderedAggregateMergeFail(t *testing.T) {
Opcode: AggregateCount,
Col: 1,
}},
Keys: []int{0},
Input: fp,
GroupByKeys: []GroupbyParams{{KeyCol: 0}},
Input: fp,
}

result := &sqltypes.Result{
Expand Down Expand Up @@ -721,8 +721,8 @@ func TestNoInputAndNoGroupingKeys(outer *testing.T) {
Col: 0,
Alias: test.name,
}},
Keys: []int{},
Input: fp,
GroupByKeys: []GroupbyParams{},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
Expand Down
5 changes: 3 additions & 2 deletions go/vt/vtgate/planbuilder/grouping.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/engine"
)

func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.GroupBy) (logicalPlan, error) {
Expand Down Expand Up @@ -76,7 +77,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou
default:
return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: only simple references allowed")
}
node.eaggr.Keys = append(node.eaggr.Keys, colNumber)
node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{Col: colNumber, KeyCol: colNumber, WeightStringCol: -1})
}
// Append the distinct aggregate if any.
if node.extraDistinct != nil {
Expand Down Expand Up @@ -109,7 +110,7 @@ func planDistinct(input logicalPlan) (logicalPlan, error) {
if rc.column.Origin() == node {
return newDistinct(node), nil
}
node.eaggr.Keys = append(node.eaggr.Keys, i)
node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{Col: i, KeyCol: i, WeightStringCol: -1})
}
newInput, err := planDistinct(node.input)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions go/vt/vtgate/planbuilder/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,15 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr
// compare those instead. This is because we currently don't have the
// ability to mimic mysql's collation behavior.
func (oa *orderedAggregate) Wireup(plan logicalPlan, jt *jointab) error {
for i, colNumber := range oa.eaggr.Keys {
rc := oa.resultColumns[colNumber]
for i, gbk := range oa.eaggr.GroupByKeys {
rc := oa.resultColumns[gbk.Col]
if sqltypes.IsText(rc.column.typ) {
if weightcolNumber, ok := oa.weightStrings[rc]; ok {
oa.eaggr.Keys[i] = weightcolNumber
oa.eaggr.GroupByKeys[i].WeightStringCol = weightcolNumber
oa.eaggr.GroupByKeys[i].KeyCol = weightcolNumber
continue
}
weightcolNumber, err := oa.input.SupplyWeightString(colNumber)
weightcolNumber, err := oa.input.SupplyWeightString(gbk.Col)
if err != nil {
_, isUnsupportedErr := err.(UnsupportedSupplyWeightString)
if isUnsupportedErr {
Expand All @@ -347,7 +348,8 @@ func (oa *orderedAggregate) Wireup(plan logicalPlan, jt *jointab) error {
return err
}
oa.weightStrings[rc] = weightcolNumber
oa.eaggr.Keys[i] = weightcolNumber
oa.eaggr.GroupByKeys[i].WeightStringCol = weightcolNumber
oa.eaggr.GroupByKeys[i].KeyCol = weightcolNumber
oa.eaggr.TruncateColumnCount = len(oa.resultColumns)
}
}
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/planbuilder/ordering.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered
}

// referenced tracks the keys referenced by the order by clause.
referenced := make([]bool, len(oa.eaggr.Keys))
referenced := make([]bool, len(oa.eaggr.GroupByKeys))
postSort := false
selOrderBy := make(sqlparser.OrderBy, 0, len(orderBy))
for _, order := range orderBy {
Expand All @@ -103,8 +103,8 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered

// Match orderByCol against the group by columns.
found := false
for j, key := range oa.eaggr.Keys {
if oa.resultColumns[key].column != orderByCol {
for j, key := range oa.eaggr.GroupByKeys {
if oa.resultColumns[key.Col].column != orderByCol {
continue
}

Expand All @@ -119,12 +119,12 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered
}

// Append any unreferenced keys at the end of the order by.
for i, key := range oa.eaggr.Keys {
for i, key := range oa.eaggr.GroupByKeys {
if referenced[i] {
continue
}
// Build a brand new reference for the key.
col, err := BuildColName(oa.input.ResultColumns(), key)
col, err := BuildColName(oa.input.ResultColumns(), key.Col)
if err != nil {
return nil, vterrors.Wrapf(err, "generating order by clause")
}
Expand Down
14 changes: 5 additions & 9 deletions go/vt/vtgate/planbuilder/selectGen4.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,11 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem
sel.GroupBy = append(sel.GroupBy, groupExpr.Inner)
return false, nil
case *orderedAggregate:
offset, weightStringOffset, colAdded, err := funcName(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable)
offset, weightStringOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable)
if err != nil {
return false, err
}
if weightStringOffset == -1 {
node.eaggr.Keys = append(node.eaggr.Keys, offset)
} else {
node.eaggr.Keys = append(node.eaggr.Keys, weightStringOffset)
}
node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupbyParams{KeyCol: offset, Col: offset, WeightStringCol: weightStringOffset})
colAddedRecursively, err := planGroupByGen4(groupExpr, node.input, semTable)
if err != nil {
return false, err
Expand Down Expand Up @@ -228,7 +224,7 @@ func planOrderBy(qp *abstract.QueryProjection, orderExprs []abstract.OrderBy, pl
func planOrderByForRoute(orderExprs []abstract.OrderBy, plan *route, semTable *semantics.SemTable) (logicalPlan, bool, error) {
origColCount := plan.Select.GetColumnCount()
for _, order := range orderExprs {
offset, weightStringOffset, _, err := funcName(order.Inner.Expr, order.WeightStrExpr, plan, semTable)
offset, weightStringOffset, _, err := wrapAndPushExpr(order.Inner.Expr, order.WeightStrExpr, plan, semTable)
if err != nil {
return nil, false, err
}
Expand All @@ -243,7 +239,7 @@ func planOrderByForRoute(orderExprs []abstract.OrderBy, plan *route, semTable *s
return plan, origColCount != plan.Select.GetColumnCount(), nil
}

func funcName(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, int, bool, error) {
func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, int, bool, error) {
offset, added, err := pushProjection(&sqlparser.AliasedExpr{Expr: expr}, plan, semTable, true, true)
if err != nil {
return 0, 0, false, err
Expand Down Expand Up @@ -314,7 +310,7 @@ func planOrderByForJoin(qp *abstract.QueryProjection, orderExprs []abstract.Orde

var colAdded bool
for _, order := range orderExprs {
offset, weightStringOffset, added, err := funcName(order.Inner.Expr, order.WeightStrExpr, plan, semTable)
offset, weightStringOffset, added, err := wrapAndPushExpr(order.Inner.Expr, order.WeightStrExpr, plan, semTable)
if err != nil {
return nil, false, err
}
Expand Down
14 changes: 11 additions & 3 deletions go/vt/wrangler/vdiff.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,23 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer
// the results, which engine.OrderedAggregate can do.
if len(aggregates) != 0 {
td.sourcePrimitive = &engine.OrderedAggregate{
Aggregates: aggregates,
Keys: td.pkCols,
Input: td.sourcePrimitive,
Aggregates: aggregates,
GroupByKeys: pkColsToGroupByParams(td.pkCols),
Input: td.sourcePrimitive,
}
}

return td, nil
}

// pkColsToGroupByParams converts primary-key column offsets into the
// grouping keys used by engine.OrderedAggregate. Every key groups and
// compares on the column itself (Col == KeyCol) and carries no
// weight-string column (WeightStringCol == -1).
//
// A nil slice is returned when pkCols is empty.
func pkColsToGroupByParams(pkCols []int) []engine.GroupbyParams {
	// Keep the nil result for empty input so callers and tests that
	// distinguish nil from an empty slice see the same value as before.
	if len(pkCols) == 0 {
		return nil
	}
	// The final length is known, so allocate once instead of letting
	// append grow the backing array repeatedly.
	res := make([]engine.GroupbyParams, 0, len(pkCols))
	for _, col := range pkCols {
		res = append(res, engine.GroupbyParams{Col: col, KeyCol: col, WeightStringCol: -1})
	}
	return res
}

// newMergeSorter creates an engine.MergeSort based on the shard streamers and pk columns.
func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compareColInfo) *engine.MergeSort {
prims := make([]engine.StreamExecutor, 0, len(participants))
Expand Down
4 changes: 2 additions & 2 deletions go/vt/wrangler/vdiff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ func TestVDiffPlanSuccess(t *testing.T) {
Opcode: engine.AggregateSum,
Col: 3,
}},
Keys: []int{0},
Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}),
GroupByKeys: []engine.GroupbyParams{{Col: 0, KeyCol: 0, WeightStringCol: -1}},
Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}),
},
targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, 0, true}}),
},
Expand Down

0 comments on commit 7809688

Please sign in to comment.