Skip to content

Commit

Permalink
fix aggregation to not push partial distinct - all or none
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Jul 10, 2023
1 parent 254e401 commit 83a2394
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,11 @@ func TestDistinctAggregation(t *testing.T) {
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")

for _, query := range []string{
`SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
// `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`, - disabled: fails because the two DISTINCT aggregations use different expressions.
`SELECT /*vt+ PLANNER=gen4 */ a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`,
`SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`,
// `SELECT /*vt+ PLANNER=gen4 */ count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`, - disabled: fails because the two DISTINCT aggregations use different expressions.
`SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.t1_id) FROM t1 a, t1 b group by a.value`,
} {
mcmp.Run(query, func(mcmp *utils.MySQLCompare) {
mcmp.Exec(query)
Expand Down
79 changes: 60 additions & 19 deletions go/vt/vtgate/engine/scalar_aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"

"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -258,33 +256,76 @@ func TestScalarGroupConcatWithAggrOnEngine(t *testing.T) {
}

// TestScalarDistinctAggrOnEngine tests distinct aggregation on engine.
func TestScalarDistinctAggr(t *testing.T) {
func TestScalarDistinctAggrOnEngine(t *testing.T) {
// Both input columns carry the same name ("value") and type (int64).
fields := sqltypes.MakeTestFields(
"value|value",
"int64|int64",
)

// The input repeats 200 and 400, so each column holds four distinct
// values: 100, 200, 400 and 600.
fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"100|100",
"200|200",
"200|200",
"400|400",
"400|400",
"600|600",
)}}

// The aggregates use the *Distinct opcodes directly, with no OrigOpcode set.
// NOTE(review): the 0 and 1 arguments look like input column indices —
// confirm against NewAggregateParam.
oa := &ScalarAggregate{
Aggregates: []*AggregateParams{
NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)"),
NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)"),
},
Input: fp,
}
qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
require.NoError(t, err)
// 4 distinct values; 100 + 200 + 400 + 600 = 1300.
require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", qr.Rows))

// NOTE(review): rewind presumably resets the fake primitive so the same
// results can be served again for the streaming path — confirm.
fp.rewind()
results := &sqltypes.Result{}
err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
// Collect everything the stream delivers into a single result: keep the
// fields from whichever callback carries them and append all rows.
if qr.Fields != nil {
results.Fields = qr.Fields
}
results.Rows = append(results.Rows, qr.Rows...)
return nil
})
require.NoError(t, err)
// The streaming path must produce the same row as TryExecute.
require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", results.Rows))
}

func TestScalarDistinctPushedDown(t *testing.T) {
fields := sqltypes.MakeTestFields(
"value|sum(distinct shardkey)",
"varchar|decimal",
"count(distinct value)|sum(distinct value)",
"int64|decimal",
)

fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"foo|600",
"tata|893",
"tete|12833",
"titi|2380",
"toto|200",
"yoyo|783493",
"2|200",
"6|400",
"3|700",
"1|10",
"7|30",
"8|90",
)}}
param := NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)")
param.CollationID = collations.CollationUtf8mb4ID

param2 := NewAggregateParam(AggregateSum, 1, "sum(distinct sharkey)")
param2.OrigOpcode = AggregateSumDistinct
countAggr := NewAggregateParam(AggregateSum, 0, "count(distinct value)")
countAggr.OrigOpcode = AggregateCountDistinct
sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)")
sumAggr.OrigOpcode = AggregateSumDistinct
oa := &ScalarAggregate{
Aggregates: []*AggregateParams{param, param2},
Input: fp,
Aggregates: []*AggregateParams{
countAggr,
sumAggr,
},
Input: fp,
}
qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
require.NoError(t, err)
require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", qr.Rows))
require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", qr.Rows))

fp.rewind()
results := &sqltypes.Result{}
Expand All @@ -296,7 +337,7 @@ func TestScalarDistinctAggr(t *testing.T) {
return nil
})
require.NoError(t, err)
require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", results.Rows))
require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", results.Rows))
}

// TestScalarGroupConcat tests group_concat with partial aggregation on engine.
Expand Down
89 changes: 64 additions & 25 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,35 +115,68 @@ func pushDownAggregationThroughRoute(

// pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down.
// It returns the error from checkIfWeCanPushDown, if any.
//
// NOTE(review): this span is taken from a rendered diff, so statements from
// before and after the change appear interleaved. The comments below describe
// individual statements as written; they do not claim the span compiles as shown.
func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error {
// NOTE(review): this per-aggregation loop header and condition look like the
// pre-change version (decision taken per aggregation) — confirm.
for i, aggregation := range aggregator.Aggregations {
if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) {
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation)
aggregateTheAggregate(aggregator, i)
// Decide once, for all DISTINCT aggregations together, whether they can be pushed down.
canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return err
}

// When the DISTINCT aggregations stay on this aggregator, record the expression they use.
if !canPushDownDistinct {
aggregator.DistinctExpr = distinctExpr
}

// aggregator.DistinctExpr is only assigned above when push down is not possible.
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
// offset ends up holding the ColOffset of the last DISTINCT aggregation that was not pushed down.
offset := -1
for i, aggr := range aggregator.Aggregations {
if aggr.Distinct && !canPushDownDistinct {
// The column below the route carries the distinct expression itself
// instead of the aggregation.
offset = aggr.ColOffset
aggrBelowRoute.Columns[offset] = aeDistinctExpr
continue
}
innerExpr := aggregation.Func.GetArg()
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggr)
aggregateTheAggregate(aggregator, i)
}

if aggregator.DistinctExpr != nil {
if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) {
// we can handle multiple distinct aggregations, as long as they are aggregating on the same expression
aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr)
continue
}
return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original)))
}
// everything is pushed below the route.
if canPushDownDistinct {
return nil
}

// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
aggregator.DistinctExpr = innerExpr
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead.
// The group by is added only once, even when there are multiple distinct
// aggregations on the same expression.
groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
groupBy.ColOffset = offset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)

aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr
return nil
}

// checkIfWeCanPushDown inspects the DISTINCT aggregations of the aggregator and
// decides whether they can be pushed down as a group ("all or none").
//
// It returns:
//   - true when every DISTINCT aggregation's argument satisfies
//     exprHasUniqueVindex (also true when there are no DISTINCT aggregations);
//   - the argument of the first DISTINCT aggregation encountered, or nil when
//     there is none;
//   - a VT12001 error, with false and a nil expression, when push down is not
//     possible and the DISTINCT aggregations do not all use the same expression.
func checkIfWeCanPushDown(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr, error) {
canPushDown := true
var distinctExpr sqlparser.Expr
var differentExpr *sqlparser.AliasedExpr

// NOTE(review): the next three statements reference names (aeDistinctExpr,
// aggregation, aggrBelowRoute) that are not declared in this function; they
// look like pre-change lines carried over by the diff rendering — confirm.
groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
groupBy.ColOffset = aggregation.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
for _, aggr := range aggregator.Aggregations {
// Only DISTINCT aggregations influence the decision.
if !aggr.Distinct {
continue
}
innerExpr := aggr.Func.GetArg()
// A single argument without a unique vindex blocks push down for all of them.
if !exprHasUniqueVindex(ctx, innerExpr) {
canPushDown = false
}
// Remember the first distinct expression; later ones are compared against it.
if distinctExpr == nil {
distinctExpr = innerExpr
}
// Keep the last aggregation whose argument differs, for the error message below.
if !ctx.SemTable.EqualsExpr(distinctExpr, innerExpr) {
differentExpr = aggr.Original
}
}
// NOTE(review): this bare return does not match the three-value signature;
// it looks like a pre-change line carried over by the diff rendering — confirm.
return nil

// Differing expressions are only a problem when the aggregations are not pushed down.
if !canPushDown && differentExpr != nil {
return false, nil, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(differentExpr)))
}

return canPushDown, distinctExpr, nil
}

func pushDownAggregationThroughFilter(
Expand Down Expand Up @@ -411,6 +444,15 @@ func splitAggrColumnsToLeftAndRight(
outerJoin: join.LeftJoin,
}

canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return nil, nil, err
}
if !canPushDownDistinct {
aggregator.DistinctExpr = distinctExpr
return nil, nil, errAbortAggrPushing
}

outer:
// we prefer adding the aggregations in the same order as the columns are declared
for colIdx, col := range aggregator.Columns {
Expand Down Expand Up @@ -509,9 +551,6 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er
// this is only used for SHOW GTID queries that will never contain joins
return vterrors.VT13001("cannot do join with vgtid")
case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
if !exprHasUniqueVindex(ctx, aggr.Func.GetArg()) {
return errAbortAggrPushing
}
return ab.handlePushThroughAggregation(ctx, aggr)
default:
return errHorizonNotPlanned()
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ type (
Grouping []GroupBy
Aggregations []Aggr

// We support a single distinct aggregation per aggregator. It is stored here
// We support a single distinct aggregation per aggregator. It is stored here.
// When planning the ordering that the OrderedAggregate will require,
// this needs to be the last ORDER BY expression
DistinctExpr sqlparser.Expr

// Pushed will be set to true once this aggregation has been pushed deeper in the tree
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -3854,8 +3854,8 @@
"Sharded": true
},
"FieldQuery": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u where 1 != 1",
"OrderBy": "0 ASC COLLATE latin1_swedish_ci",
"Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc",
"OrderBy": "0 ASC COLLATE latin1_swedish_ci, (1|2) ASC",
"Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc, u.val2 asc",
"Table": "`user`"
},
{
Expand Down

0 comments on commit 83a2394

Please sign in to comment.