Skip to content

Commit

Permalink
[SQL] Support MySQL IF&IFNULL Function (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
gongna-au authored Jan 10, 2023
1 parent e1ca687 commit dc902d0
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 0 deletions.
70 changes: 70 additions & 0 deletions pkg/runtime/function/if.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncIf is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_if
const FuncIf = "IF"

var _ proto.Func = (*ifFunc)(nil)

func init() {
proto.RegisterFunc(FuncIf, ifFunc{})
}

type ifFunc struct{}

func (i ifFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) {
val1, err := inputs[0].Value(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
val2, err := inputs[1].Value(ctx)
if err != nil {
return nil, errors.WithStack(err)
}

val3, err := inputs[2].Value(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
if val1 == nil {
return val3, nil
}
if v, _ := val1.Int64(); v != 0 {
if val2 != nil {
return val2, nil
}
}
return val3, nil
}

func (i ifFunc) NumInput() int {
return 3
}
61 changes: 61 additions & 0 deletions pkg/runtime/function/if_null.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncIfNull is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_ifnull
const FuncIfNull = "IFNULL"

var _ proto.Func = (*ifNullFunc)(nil)

func init() {
proto.RegisterFunc(FuncIfNull, ifNullFunc{})
}

type ifNullFunc struct{}

func (i ifNullFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) {
val1, err := inputs[0].Value(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
val2, err := inputs[1].Value(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
if val1 != nil {
return val1, nil
}
return val2, nil

}

func (i ifNullFunc) NumInput() int {
return 2
}
62 changes: 62 additions & 0 deletions pkg/runtime/function/if_null_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
"fmt"
"testing"
)

import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncIFNULL is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_ifnull
func TestIfNull(t *testing.T) {
fn := proto.MustGetFunc(FuncIfNull)
assert.Equal(t, 2, fn.NumInput())
type tt struct {
inFirst proto.Value
inSecond proto.Value
want proto.Value
}
for _, v := range []tt{
{nil, proto.NewValueInt64(2), proto.NewValueInt64(2)},
{nil, proto.NewValueString("yes"), proto.NewValueString("yes")},
{nil, proto.NewValueBool(true), proto.NewValueBool(true)},
{nil, proto.NewValueFloat64(9.009123), proto.NewValueFloat64(9.009123)},
{nil, proto.NewValueDecimal(decimal.NewFromInt(8080)), proto.NewValueDecimal(decimal.NewFromInt(8080))},
{proto.NewValueInt64(2), proto.NewValueFloat64(9.009123), proto.NewValueInt64(2)},
{proto.NewValueString("yes"), proto.NewValueInt64(2), proto.NewValueString("yes")},
{proto.NewValueBool(true), proto.NewValueFloat64(9.009123), proto.NewValueBool(true)},
{proto.NewValueFloat64(9.009123), proto.NewValueBool(true), proto.NewValueFloat64(9.009123)},
{proto.NewValueDecimal(decimal.NewFromInt(8080)), proto.NewValueFloat64(9.009123), proto.NewValueDecimal(decimal.NewFromInt(8080))},
} {
t.Run(fmt.Sprint(v.inFirst), func(t *testing.T) {
out, err := fn.Apply(context.Background(), proto.ToValuer(v.inFirst), proto.ToValuer(v.inSecond))
assert.NoError(t, err)
assert.Equal(t, v.want, out)
})
}
}
71 changes: 71 additions & 0 deletions pkg/runtime/function/if_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package function

import (
"context"
"fmt"
"testing"
)

import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
)

// FuncIf is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_if
func TestIf(t *testing.T) {
fn := proto.MustGetFunc(FuncIf)
assert.Equal(t, 3, fn.NumInput())
type tt struct {
inFirst proto.Value
inSecond proto.Value
inThird proto.Value
want proto.Value
}
for _, v := range []tt{
{proto.NewValueBool(true), proto.NewValueInt64(1), proto.NewValueInt64(2), proto.NewValueInt64(1)},
{proto.NewValueBool(true), proto.NewValueString("yes"), proto.NewValueString("no"), proto.NewValueString("yes")},
{proto.NewValueBool(true), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000002), proto.NewValueFloat64(0.00000000001)},
{proto.NewValueBool(true), proto.NewValueDecimal(decimal.NewFromInt(0)), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueDecimal(decimal.NewFromInt(0))},
{proto.NewValueBool(true), proto.NewValueString("yes"), proto.NewValueFloat64(0.00000000001), proto.NewValueString("yes")},
{proto.NewValueBool(true), proto.NewValueFloat64(0.00000000001), proto.NewValueDecimal(decimal.NewFromInt(200)), proto.NewValueFloat64(0.00000000001)},
{proto.NewValueBool(true), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueDecimal(decimal.NewFromInt(1))},
{proto.NewValueBool(true), nil, proto.NewValueInt64(2), proto.NewValueInt64(2)},
{proto.NewValueBool(true), nil, proto.NewValueString("no"), proto.NewValueString("no")},
{proto.NewValueBool(false), proto.NewValueInt64(1), proto.NewValueInt64(2), proto.NewValueInt64(2)},
{proto.NewValueBool(false), proto.NewValueString("yes"), proto.NewValueString("no"), proto.NewValueString("no")},
{proto.NewValueBool(false), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000002), proto.NewValueFloat64(0.00000000002)},
{proto.NewValueBool(false), proto.NewValueDecimal(decimal.NewFromInt(0)), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueDecimal(decimal.NewFromInt(1))},
{proto.NewValueBool(false), proto.NewValueString("yes"), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000001)},
{proto.NewValueBool(false), proto.NewValueFloat64(0.00000000001), proto.NewValueDecimal(decimal.NewFromInt(200)), proto.NewValueDecimal(decimal.NewFromInt(200))},
{proto.NewValueBool(false), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueInt64(2)},
{nil, proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueInt64(2)},
{nil, proto.NewValueDecimal(decimal.NewFromInt(1)), nil, nil},
} {
t.Run(fmt.Sprint(v.inFirst), func(t *testing.T) {
out, err := fn.Apply(context.Background(), proto.ToValuer(v.inFirst), proto.ToValuer(v.inSecond), proto.ToValuer(v.inThird))
assert.NoError(t, err)
assert.Equal(t, v.want, out)
})
}
}

0 comments on commit dc902d0

Please sign in to comment.