diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index a26f7d08eba..b3d96970b43 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -535,9 +535,11 @@ func TestDistinctAggregation(t *testing.T) { mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)") for _, query := range []string{ - `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`, + // `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`, - fails as different distinct expression. `SELECT /*vt+ PLANNER=gen4 */ a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`, `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`, + // `SELECT /*vt+ PLANNER=gen4 */ count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`, - fails as different distinct expression. 
+ `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.t1_id) FROM t1 a, t1 b group by a.value`, } { mcmp.Run(query, func(mcmp *utils.MySQLCompare) { mcmp.Exec(query) diff --git a/go/vt/vtgate/engine/scalar_aggregation_test.go b/go/vt/vtgate/engine/scalar_aggregation_test.go index 810d67f2b53..2dfc5b10763 100644 --- a/go/vt/vtgate/engine/scalar_aggregation_test.go +++ b/go/vt/vtgate/engine/scalar_aggregation_test.go @@ -24,8 +24,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/sqltypes" @@ -258,33 +256,76 @@ func TestScalarGroupConcatWithAggrOnEngine(t *testing.T) { } // TestScalarDistinctAggr tests distinct aggregation on engine. -func TestScalarDistinctAggr(t *testing.T) { +func TestScalarDistinctAggrOnEngine(t *testing.T) { + fields := sqltypes.MakeTestFields( + "value|value", + "int64|int64", + ) + + fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "100|100", + "200|200", + "200|200", + "400|400", + "400|400", + "600|600", + )}} + + oa := &ScalarAggregate{ + Aggregates: []*AggregateParams{ + NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)"), + NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)"), + }, + Input: fp, + } + qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) + require.NoError(t, err) + require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", qr.Rows)) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) 
+ return nil + }) + require.NoError(t, err) + require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", results.Rows)) +} + +func TestScalarDistinctPushedDown(t *testing.T) { fields := sqltypes.MakeTestFields( - "value|sum(distinct shardkey)", - "varchar|decimal", + "count(distinct value)|sum(distinct value)", + "int64|decimal", ) fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult( fields, - "foo|600", - "tata|893", - "tete|12833", - "titi|2380", - "toto|200", - "yoyo|783493", + "2|200", + "6|400", + "3|700", + "1|10", + "7|30", + "8|90", )}} - param := NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)") - param.CollationID = collations.CollationUtf8mb4ID - param2 := NewAggregateParam(AggregateSum, 1, "sum(distinct sharkey)") - param2.OrigOpcode = AggregateSumDistinct + countAggr := NewAggregateParam(AggregateSum, 0, "count(distinct value)") + countAggr.OrigOpcode = AggregateCountDistinct + sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)") + sumAggr.OrigOpcode = AggregateSumDistinct oa := &ScalarAggregate{ - Aggregates: []*AggregateParams{param, param2}, - Input: fp, + Aggregates: []*AggregateParams{ + countAggr, + sumAggr, + }, + Input: fp, } qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) require.NoError(t, err) - require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", qr.Rows)) + require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", qr.Rows)) fp.rewind() results := &sqltypes.Result{} @@ -296,7 +337,7 @@ func TestScalarDistinctAggr(t *testing.T) { return nil }) require.NoError(t, err) - require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", results.Rows)) + require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", results.Rows)) } // TestScalarGroupConcat tests group_concat with partial aggregation on engine. 
diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 80aba9ca1f3..3a67a4fc2a2 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -115,35 +115,68 @@ func pushDownAggregationThroughRoute( // pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error { - for i, aggregation := range aggregator.Aggregations { - if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) { - aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation) - aggregateTheAggregate(aggregator, i) + canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator) + if err != nil { + return err + } + + if !canPushDownDistinct { + aggregator.DistinctExpr = distinctExpr + } + + aeDistinctExpr := aeWrap(aggregator.DistinctExpr) + offset := -1 + for i, aggr := range aggregator.Aggregations { + if aggr.Distinct && !canPushDownDistinct { + offset = aggr.ColOffset + aggrBelowRoute.Columns[offset] = aeDistinctExpr continue } - innerExpr := aggregation.Func.GetArg() + aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggr) + aggregateTheAggregate(aggregator, i) + } - if aggregator.DistinctExpr != nil { - if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) { - // we can handle multiple distinct aggregations, as long as they are aggregating on the same expression - aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr) - continue - } - return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original))) - } + // everything is pushed below the route. 
+ if canPushDownDistinct { + return nil + } - // We handle a distinct aggregation by turning it into a group by and - // doing the aggregating on the vtgate level instead - aggregator.DistinctExpr = innerExpr - aeDistinctExpr := aeWrap(aggregator.DistinctExpr) + // We handle a distinct aggregation by turning it into a group by and + // doing the aggregating on the vtgate level instead + // The group by is added only once, even when there are multiple distinct aggregations on the same expression. + groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr) + groupBy.ColOffset = offset + aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy) - aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr + return nil +} + +func checkIfWeCanPushDown(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr, error) { + canPushDown := true + var distinctExpr sqlparser.Expr + var differentExpr *sqlparser.AliasedExpr - groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr) - groupBy.ColOffset = aggregation.ColOffset - aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy) + for _, aggr := range aggregator.Aggregations { + if !aggr.Distinct { + continue + } + innerExpr := aggr.Func.GetArg() + if !exprHasUniqueVindex(ctx, innerExpr) { + canPushDown = false + } + if distinctExpr == nil { + distinctExpr = innerExpr + } + if !ctx.SemTable.EqualsExpr(distinctExpr, innerExpr) { + differentExpr = aggr.Original + } } - return nil + + if !canPushDown && differentExpr != nil { + return false, nil, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(differentExpr))) + } + + return canPushDown, distinctExpr, nil } func pushDownAggregationThroughFilter( @@ -411,6 +444,15 @@ func splitAggrColumnsToLeftAndRight( outerJoin: join.LeftJoin, } + canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator) 
+ if err != nil { + return nil, nil, err + } + if !canPushDownDistinct { + aggregator.DistinctExpr = distinctExpr + return nil, nil, errAbortAggrPushing + } + outer: // we prefer adding the aggregations in the same order as the columns are declared for colIdx, col := range aggregator.Columns { @@ -509,9 +551,6 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er // this is only used for SHOW GTID queries that will never contain joins return vterrors.VT13001("cannot do join with vgtid") case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: - if !exprHasUniqueVindex(ctx, aggr.Func.GetArg()) { - return errAbortAggrPushing - } return ab.handlePushThroughAggregation(ctx, aggr) default: return errHorizonNotPlanned() diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 36602e024d4..ce6b9cc1912 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -42,7 +42,9 @@ type ( Grouping []GroupBy Aggregations []Aggr - // We support a single distinct aggregation per aggregator. It is stored here + // We support a single distinct aggregation per aggregator. It is stored here. 
+ // When planning the ordering that the OrderedAggregate will require, + // this needs to be the last ORDER BY expression DistinctExpr sqlparser.Expr // Pushed will be set to true once this aggregation has been pushed deeper in the tree diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index c7462522ce0..0dbbe645080 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3854,8 +3854,8 @@ "Sharded": true }, "FieldQuery": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u where 1 != 1", - "OrderBy": "0 ASC COLLATE latin1_swedish_ci", - "Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc", + "OrderBy": "0 ASC COLLATE latin1_swedish_ci, (1|2) ASC", + "Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc, u.val2 asc", "Table": "`user`" }, {