Skip to content

Commit

Permalink
expression: make the return column type always be nullable for some w…
Browse files Browse the repository at this point in the history
…indow function (pingcap#45965)

close pingcap#45964
  • Loading branch information
xzhangxian1008 authored Aug 11, 2023
1 parent 66033d5 commit db7a89b
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 80 deletions.
37 changes: 37 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package executor_test

import (
"context"
"fmt"
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/testkit"
)

Expand Down Expand Up @@ -498,3 +500,38 @@ func TestIssue29947(t *testing.T) {
result.Check(testkit.Rows("2", "3"))
tk.MustExec("commit")
}

func testReturnColumnNullableAttribute(tk *testkit.TestKit, funcName string, isNullable bool) {
rs, err := tk.ExecWithContext(context.Background(), fmt.Sprintf("select %s over (partition by p order by o rows between 1 preceding and 1 following) as a from agg;", funcName))
tk.RequireNoError(err, "testReturnColumnNullableAttribute get error")
retField := rs.Fields()[0]
if isNullable {
tk.RequireNotEqual(mysql.NotNullFlag, (retField.Column.FieldType.GetFlag() & mysql.NotNullFlag), fmt.Sprintf("%s window function's return column should have nullable attribute", funcName))
} else {
tk.RequireEqual(mysql.NotNullFlag, (retField.Column.FieldType.GetFlag() & mysql.NotNullFlag), fmt.Sprintf("%s window function's return column should not have nullable attribute", funcName))
}
rs.Close()
}

func TestIssue45964(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec(`drop table if exists agg;`)
tk.MustExec("create table agg(p int not null, o int not null, v int not null);")
tk.MustExec(`INSERT INTO agg VALUES (0, 0, 1), (1, 1, 2), (1, 2, 3), (1, 3, 4), (1, 4, 5), (2, 5, 6), (2, 6, 7);`)

testReturnColumnNullableAttribute(tk, "first_value(v)", true)
testReturnColumnNullableAttribute(tk, "last_value(v)", true)
testReturnColumnNullableAttribute(tk, "nth_value(v, 2)", true)
testReturnColumnNullableAttribute(tk, "lead(v)", true)
testReturnColumnNullableAttribute(tk, "lag(v)", true)
testReturnColumnNullableAttribute(tk, "ntile(2)", true)
testReturnColumnNullableAttribute(tk, "sum(v)", true)
testReturnColumnNullableAttribute(tk, "count(v)", false)
testReturnColumnNullableAttribute(tk, "row_number()", false)
testReturnColumnNullableAttribute(tk, "rank()", false)
testReturnColumnNullableAttribute(tk, "dense_rank()", false)
testReturnColumnNullableAttribute(tk, "cume_dist()", false)
testReturnColumnNullableAttribute(tk, "percent_rank()", false)
}
17 changes: 17 additions & 0 deletions expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tipb/go-tipb"
)
Expand Down Expand Up @@ -58,6 +59,22 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
}

base, err := newBaseFuncDesc(ctx, name, args)

// Some window functions' return column type must be nullable or not nullable
switch name {
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank, ast.WindowFuncCumeDist, ast.WindowFuncPercentRank,
ast.AggFuncCount, ast.AggFuncApproxCountDistinct, ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
base.RetTp.SetFlag(mysql.NotNullFlag)
case ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) == 3 && ((args[0].GetType().GetFlag() & mysql.NotNullFlag) != 0) {
base.RetTp.SetFlag(mysql.NotNullFlag)
break
}
base.RetTp.DelFlag(mysql.NotNullFlag)
default:
base.RetTp.DelFlag(mysql.NotNullFlag)
}

if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,8 @@ const (
type WindowFuncExpr struct {
funcNode

// F is the function name.
F string
// Name is the function name.
Name string
// Args is the function args.
Args []ExprNode
// Distinct cannot be true for most window functions, except `max` and `min`.
Expand All @@ -914,7 +914,7 @@ type WindowFuncExpr struct {

// Restore implements Node interface.
func (n *WindowFuncExpr) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord(n.F)
ctx.WriteKeyWord(n.Name)
ctx.WritePlain("(")
for i, v := range n.Args {
if i != 0 {
Expand Down
68 changes: 34 additions & 34 deletions parser/parser.go

Large diffs are not rendered by default.

68 changes: 34 additions & 34 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -8159,7 +8159,7 @@ SumExpr:
"AVG" '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
Expand All @@ -8175,47 +8175,47 @@ SumExpr:
| builtinBitAnd '(' Expression ')' OptWindowingClause
{
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3}}
}
}
| builtinBitAnd '(' "ALL" Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}}
}
}
| builtinBitOr '(' Expression ')' OptWindowingClause
{
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3}}
}
}
| builtinBitOr '(' "ALL" Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}}
}
}
| builtinBitXor '(' Expression ')' OptWindowingClause
{
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3}}
}
}
| builtinBitXor '(' "ALL" Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}}
}
Expand All @@ -8227,15 +8227,15 @@ SumExpr:
| builtinCount '(' "ALL" Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}}
}
}
| builtinCount '(' Expression ')' OptWindowingClause
{
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3}}
}
Expand All @@ -8244,7 +8244,7 @@ SumExpr:
{
args := []ast.ExprNode{ast.NewValueExpr(1, parser.charset, parser.collation)}
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: args, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: args, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: args}
}
Expand All @@ -8254,7 +8254,7 @@ SumExpr:
args := $4.([]ast.ExprNode)
args = append(args, $6.(ast.ExprNode))
if $8 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: args, Distinct: $3.(bool), Spec: *($8.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: args, Distinct: $3.(bool), Spec: *($8.(*ast.WindowSpec))}
} else {
agg := &ast.AggregateFuncExpr{F: $1, Args: args, Distinct: $3.(bool)}
if $5 != nil {
Expand All @@ -8266,47 +8266,47 @@ SumExpr:
| builtinMax '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
}
| builtinMin '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
}
| builtinSum '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
}
| builtinStddevPop '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: ast.AggFuncStddevPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: ast.AggFuncStddevPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: ast.AggFuncStddevPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
}
| builtinStddevSamp '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
}
| builtinVarPop '(' BuggyDefaultFalseDistinctOpt Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: ast.AggFuncVarPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: ast.AggFuncVarPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool), Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: ast.AggFuncVarPop, Args: []ast.ExprNode{$4}, Distinct: $3.(bool)}
}
Expand All @@ -8318,47 +8318,47 @@ SumExpr:
| "JSON_ARRAYAGG" '(' Expression ')' OptWindowingClause
{
if $5 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: *($5.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3}}
}
}
| "JSON_ARRAYAGG" '(' "ALL" Expression ')' OptWindowingClause
{
if $6 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4}, Spec: *($6.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4}}
}
}
| "JSON_OBJECTAGG" '(' Expression ',' Expression ')' OptWindowingClause
{
if $7 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3, $5}, Spec: *($7.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3, $5}, Spec: *($7.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3, $5}}
}
}
| "JSON_OBJECTAGG" '(' "ALL" Expression ',' Expression ')' OptWindowingClause
{
if $8 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4, $6}, Spec: *($8.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4, $6}, Spec: *($8.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4, $6}}
}
}
| "JSON_OBJECTAGG" '(' Expression ',' "ALL" Expression ')' OptWindowingClause
{
if $8 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3, $6}, Spec: *($8.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3, $6}, Spec: *($8.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$3, $6}}
}
}
| "JSON_OBJECTAGG" '(' "ALL" Expression ',' "ALL" Expression ')' OptWindowingClause
{
if $9 != nil {
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$4, $7}, Spec: *($9.(*ast.WindowSpec))}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$4, $7}, Spec: *($9.(*ast.WindowSpec))}
} else {
$$ = &ast.AggregateFuncExpr{F: $1, Args: []ast.ExprNode{$4, $7}}
}
Expand Down Expand Up @@ -9432,55 +9432,55 @@ WindowNameOrSpec:
WindowFuncCall:
"ROW_NUMBER" '(' ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Spec: $4.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Spec: $4.(ast.WindowSpec)}
}
| "RANK" '(' ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Spec: $4.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Spec: $4.(ast.WindowSpec)}
}
| "DENSE_RANK" '(' ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Spec: $4.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Spec: $4.(ast.WindowSpec)}
}
| "CUME_DIST" '(' ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Spec: $4.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Spec: $4.(ast.WindowSpec)}
}
| "PERCENT_RANK" '(' ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Spec: $4.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Spec: $4.(ast.WindowSpec)}
}
| "NTILE" '(' SimpleExpr ')' WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, Spec: $5.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, Spec: $5.(ast.WindowSpec)}
}
| "LEAD" '(' Expression OptLeadLagInfo ')' OptNullTreatment WindowingClause
{
args := []ast.ExprNode{$3}
if $4 != nil {
args = append(args, $4.([]ast.ExprNode)...)
}
$$ = &ast.WindowFuncExpr{F: $1, Args: args, IgnoreNull: $6.(bool), Spec: $7.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: args, IgnoreNull: $6.(bool), Spec: $7.(ast.WindowSpec)}
}
| "LAG" '(' Expression OptLeadLagInfo ')' OptNullTreatment WindowingClause
{
args := []ast.ExprNode{$3}
if $4 != nil {
args = append(args, $4.([]ast.ExprNode)...)
}
$$ = &ast.WindowFuncExpr{F: $1, Args: args, IgnoreNull: $6.(bool), Spec: $7.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: args, IgnoreNull: $6.(bool), Spec: $7.(ast.WindowSpec)}
}
| "FIRST_VALUE" '(' Expression ')' OptNullTreatment WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, IgnoreNull: $5.(bool), Spec: $6.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, IgnoreNull: $5.(bool), Spec: $6.(ast.WindowSpec)}
}
| "LAST_VALUE" '(' Expression ')' OptNullTreatment WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3}, IgnoreNull: $5.(bool), Spec: $6.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3}, IgnoreNull: $5.(bool), Spec: $6.(ast.WindowSpec)}
}
| "NTH_VALUE" '(' Expression ',' SimpleExpr ')' OptFromFirstLast OptNullTreatment WindowingClause
{
$$ = &ast.WindowFuncExpr{F: $1, Args: []ast.ExprNode{$3, $5}, FromLast: $7.(bool), IgnoreNull: $8.(bool), Spec: $9.(ast.WindowSpec)}
$$ = &ast.WindowFuncExpr{Name: $1, Args: []ast.ExprNode{$3, $5}, FromLast: $7.(bool), IgnoreNull: $8.(bool), Spec: $9.(ast.WindowSpec)}
}

OptLeadLagInfo:
Expand Down
2 changes: 1 addition & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
index, ok = er.windowMap[v]
}
if !ok {
er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F))
er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.Name))
return inNode, true
}
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
Expand Down
Loading

0 comments on commit db7a89b

Please sign in to comment.