Skip to content

Commit

Permalink
reslove the correlated agg func's correlated col when it in the sub-q…
Browse files Browse the repository at this point in the history
…uery (#11)
  • Loading branch information
AilinKid authored Mar 15, 2022
1 parent f2d0b3d commit 89b7b99
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
89 changes: 77 additions & 12 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2102,27 +2102,56 @@ func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, s
return n, false
}

func dfsResolveFromInsideJoin(v *ast.ColumnNameExpr, p LogicalPlan) ([]*expression.Column, types.NameSlice, int, error) {
// For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0
// order by t2.a`, the query plan will be `join->selection->sort`. The
// schema of selection will be `[t1.a]`, thus we need to recursively
// retrieve the `t2.a` from the underlying join.
idx, err := expression.FindFieldName(p.OutputNames(), v.Name)
if err != nil {
return nil, nil, -1, err
}
if idx >= 0 {
return p.Schema().Columns, p.OutputNames(), idx, err
}
switch x := p.(type) {
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return dfsResolveFromInsideJoin(v, p.Children()[0])
case *LogicalJoin:
if len(x.fullNames) != 0 {
idx, err = expression.FindFieldName(x.fullNames, v.Name)
schemaCols, outputNames := x.fullSchema.Columns, x.fullNames
if err == nil && idx >= 0 {
return schemaCols, outputNames, idx, err
}
}
}
// -1, nil
return nil, nil, idx, err
}

func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameExpr, p LogicalPlan) (int, error) {
idx, err := expression.FindFieldName(p.OutputNames(), v.Name)
if err != nil {
return -1, err
}
schemaCols, outputNames := p.Schema().Columns, p.OutputNames()
if idx < 0 {
// For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0
// order by t2.a`, the query plan will be `join->selection->sort`. The
// schema of selection will be `[t1.a]`, thus we need to recursively
// retrieve the `t2.a` from the underlying join.
switch x := p.(type) {
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return a.resolveFromPlan(v, p.Children()[0])
case *LogicalJoin:
if len(x.fullNames) != 0 {
idx, err = expression.FindFieldName(x.fullNames, v.Name)
schemaCols, outputNames = x.fullSchema.Columns, x.fullNames
// maybe the referred column is in the outer schema stack.
for i := len(a.outerSchemas) - 1; i >= 0; i-- {
outerSchema, outerName := a.outerSchemas[i], a.outerNames[i]
idx, err = expression.FindFieldName(outerName, v.Name)
if err == nil && idx >= 0 {
schemaCols, outputNames = outerSchema.Columns, outerName
break
}
}
if err != nil || idx < 0 {
}
if idx < 0 {
// maybe the referred column is in the inside join's full schema.
schemaCols, outputNames, idx, err = dfsResolveFromInsideJoin(v, p)
if idx < 0 {
// nowhere to be found.
return -1, err
}
}
Expand Down Expand Up @@ -2645,6 +2674,20 @@ func (b *PlanBuilder) resolveCorrelatedAggregates(ctx context.Context, sel *ast.
}
correlatedAggMap := make(map[*ast.AggregateFuncExpr]int)
for _, aggFunc := range correlatedAggList {
colMap := make(map[*types.FieldName]struct{}, len(p.Schema().Columns))
allColFromAggExprNode(p, aggFunc, colMap)
for k := range colMap {
colName := &ast.ColumnName{
Schema: k.DBName,
Table: k.TblName,
Name: k.ColName,
}
sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{
Auxiliary: true,
AuxiliaryColInAgg: true,
Expr: &ast.ColumnNameExpr{Name: colName},
})
}
correlatedAggMap[aggFunc] = len(sel.Fields.Fields)
sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{
Auxiliary: true,
Expand Down Expand Up @@ -3213,6 +3256,28 @@ func (c *colResolverForOnlyFullGroupBy) Leave(node ast.Node) (ast.Node, bool) {
return node, true
}

type aggColNameResolver struct {
colNameResolver
}

func (c *aggColNameResolver) Enter(inNode ast.Node) (ast.Node, bool) {
switch inNode.(type) {
case *ast.ColumnNameExpr:
return inNode, true
}
return inNode, false
}

func allColFromAggExprNode(p LogicalPlan, n ast.Node, names map[*types.FieldName]struct{}) {
extractor := &aggColNameResolver{
colNameResolver: colNameResolver{
p: p,
names: names,
},
}
n.Accept(extractor)
}

type colNameResolver struct {
p LogicalPlan
names map[*types.FieldName]struct{}
Expand Down
8 changes: 8 additions & 0 deletions planner/funcdep/only_full_group_by_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,12 @@ func TestOnlyFullGroupByOldCases(t *testing.T) {
tk.MustQuery("select t1.a from t1 join t2 on t2.a=t1.a group by t2.a having min(t2.b) > 0;")
tk.MustQuery("select t2.a, count(t2.b) from t1 join t2 using (a) where t1.a = 1;")
tk.MustQuery("select count(t2.b) from t1 join t2 using (a) order by t2.a;")

// test issue #30024
tk.MustExec("drop table if exists t1,t2;")
tk.MustExec("CREATE TABLE t1 (a INT, b INT, c INT DEFAULT 0);")
tk.MustExec("INSERT INTO t1 (a, b) VALUES (3,3), (2,2), (3,3), (2,2), (3,3), (4,4);")
tk.MustExec("CREATE TABLE t2 (a INT, b INT, c INT DEFAULT 0);")
tk.MustExec("INSERT INTO t2 (a, b) VALUES (3,3), (2,2), (3,3), (2,2), (3,3), (4,4);")
tk.MustQuery("SELECT t1.a FROM t1 GROUP BY t1.a HAVING t1.a IN (SELECT t2.a FROM t2 ORDER BY SUM(t1.b));")
}

0 comments on commit 89b7b99

Please sign in to comment.