Skip to content

Commit

Permalink
[SQL] Support MySQL window function. (#688)
Browse files Browse the repository at this point in the history
* add node config support (#464)

* Support MySQL CAST_CHAR function.

* format style

* Support MySQL CAST_TIME function. (#570)

* Support MySQL CAST_DATE function. (#569)

* Support MySQL CAST_DATETIME function. (#568)

* Support MySQL CAST_TIME/CAST_DATE/CAST_DATETIME function

* Resolve Conversation

* Support CREATE TABLE

* add: IfNotExists

* fix: reformat imports

* Resolve Conversation

* Support window function: CUME_DIST

* Support window function: PERCENT_RANK

* Support window function: RANK

* Support window function: DENSE_RANK

* Support window function: FIRST_VALUE/LAST_VALUE/LAG/LEAD

* Support window function: NTH_VALUE/NTILE/ROW_NUMBER

* support argument(n) in LAG/LEAD

* convert Int64 to Float64 in test case
  • Loading branch information
csynineyang authored Jun 18, 2023
1 parent b9b7e96 commit c696476
Show file tree
Hide file tree
Showing 22 changed files with 2,878 additions and 0 deletions.
72 changes: 72 additions & 0 deletions pkg/runtime/function/cume_dist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncCumeDist is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
const FuncCumeDist = "CUME_DIST"

var _ proto.Func = (*cumedistFunc)(nil)

func init() {
proto.RegisterFunc(FuncCumeDist, cumedistFunc{})
}

type cumedistFunc struct{}

func (a cumedistFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) {
first, err := inputs[0].Value(ctx)
if first == nil || err != nil {
return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist)
}
firstDec, _ := first.Float64()
firstNum := 0

for _, it := range inputs[1:] {
val, err := it.Value(ctx)
if val == nil || err != nil {
return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist)
}
valDec, _ := val.Float64()

if valDec <= firstDec {
firstNum++
}
}

r := 0.0
if len(inputs) > 1 {
r = float64(firstNum) / float64(len(inputs)-1)
}
return proto.NewValueFloat64(r), nil
}

func (a cumedistFunc) NumInput() int {
return 0
}
127 changes: 127 additions & 0 deletions pkg/runtime/function/cume_dist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
"fmt"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

func TestFuncCumeDist(t *testing.T) {
fn := proto.MustGetFunc(FuncCumeDist)
type tt struct {
inputs []proto.Value
want string
}
for _, v := range []tt{
{
[]proto.Value{
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"0.2222222222222222",
},
{
[]proto.Value{
proto.NewValueFloat64(2),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"0.3333333333333333",
},
{
[]proto.Value{
proto.NewValueFloat64(3),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"0.6666666666666666",
},
{
[]proto.Value{
proto.NewValueFloat64(4),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"0.8888888888888888",
},
{
[]proto.Value{
proto.NewValueFloat64(5),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"1",
},
} {
t.Run(v.want, func(t *testing.T) {
var inputs []proto.Valuer
for i := range v.inputs {
inputs = append(inputs, proto.ToValuer(v.inputs[i]))
}
out, err := fn.Apply(context.Background(), inputs...)
assert.NoError(t, err)
assert.Equal(t, v.want, fmt.Sprint(out))
})
}
}
70 changes: 70 additions & 0 deletions pkg/runtime/function/dense_rank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncDenseRank is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
const FuncDenseRank = "DENSE_RANK"

var _ proto.Func = (*denserankFunc)(nil)

func init() {
proto.RegisterFunc(FuncDenseRank, denserankFunc{})
}

type denserankFunc struct{}

func (a denserankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) {
first, err := inputs[0].Value(ctx)
if first == nil || err != nil {
return nil, errors.Wrapf(err, "cannot eval %s", FuncDenseRank)
}
firstDec, _ := first.Float64()
secondDec := firstDec
firstNum := 0

for _, it := range inputs[1:] {
val, err := it.Value(ctx)
if val == nil || err != nil {
return nil, errors.Wrapf(err, "cannot eval %s", FuncDenseRank)
}
valDec, _ := val.Float64()

if valDec < firstDec && valDec != secondDec {
firstNum++
secondDec = valDec
}
}

return proto.NewValueInt64(int64(firstNum) + 1), nil
}

func (a denserankFunc) NumInput() int {
return 0
}
127 changes: 127 additions & 0 deletions pkg/runtime/function/dense_rank_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
"fmt"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

func TestFuncDenseRankt(t *testing.T) {
fn := proto.MustGetFunc(FuncDenseRank)
type tt struct {
inputs []proto.Value
want string
}
for _, v := range []tt{
{
[]proto.Value{
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"1",
},
{
[]proto.Value{
proto.NewValueFloat64(2),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"2",
},
{
[]proto.Value{
proto.NewValueFloat64(3),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"3",
},
{
[]proto.Value{
proto.NewValueFloat64(4),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"4",
},
{
[]proto.Value{
proto.NewValueFloat64(5),
proto.NewValueFloat64(1),
proto.NewValueFloat64(1),
proto.NewValueFloat64(2),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(3),
proto.NewValueFloat64(4),
proto.NewValueFloat64(4),
proto.NewValueFloat64(5),
},
"5",
},
} {
t.Run(v.want, func(t *testing.T) {
var inputs []proto.Valuer
for i := range v.inputs {
inputs = append(inputs, proto.ToValuer(v.inputs[i]))
}
out, err := fn.Apply(context.Background(), inputs...)
assert.NoError(t, err)
assert.Equal(t, v.want, fmt.Sprint(out))
})
}
}
Loading

0 comments on commit c696476

Please sign in to comment.