diff --git a/pkg/runtime/function/if.go b/pkg/runtime/function/if.go new file mode 100644 index 00000000..4df09091 --- /dev/null +++ b/pkg/runtime/function/if.go @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncIf is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_if +const FuncIf = "IF" + +var _ proto.Func = (*ifFunc)(nil) + +func init() { + proto.RegisterFunc(FuncIf, ifFunc{}) +} + +type ifFunc struct{} + +func (i ifFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + val1, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + val2, err := inputs[1].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + val3, err := inputs[2].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + if val1 == nil { + return val3, nil + } + if v, _ := val1.Int64(); v != 0 { + if val2 != nil { + return val2, nil + } + } + return val3, nil +} + +func (i ifFunc) NumInput() int { + return 3 +} diff --git a/pkg/runtime/function/if_null.go b/pkg/runtime/function/if_null.go new file mode 100644 index 00000000..a5a7e1d2 --- /dev/null +++ b/pkg/runtime/function/if_null.go @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncIfNull is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_ifnull +const FuncIfNull = "IFNULL" + +var _ proto.Func = (*ifNullFunc)(nil) + +func init() { + proto.RegisterFunc(FuncIfNull, ifNullFunc{}) +} + +type ifNullFunc struct{} + +func (i ifNullFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + val1, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + val2, err := inputs[1].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + if val1 != nil { + return val1, nil + } + return val2, nil + +} + +func (i ifNullFunc) NumInput() int { + return 2 +} diff --git a/pkg/runtime/function/if_null_test.go b/pkg/runtime/function/if_null_test.go new file mode 100644 index 00000000..e4010437 --- /dev/null +++ b/pkg/runtime/function/if_null_test.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncIFNULL is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_ifnull +func TestIfNull(t *testing.T) { + fn := proto.MustGetFunc(FuncIfNull) + assert.Equal(t, 2, fn.NumInput()) + type tt struct { + inFirst proto.Value + inSecond proto.Value + want proto.Value + } + for _, v := range []tt{ + {nil, proto.NewValueInt64(2), proto.NewValueInt64(2)}, + {nil, proto.NewValueString("yes"), proto.NewValueString("yes")}, + {nil, proto.NewValueBool(true), proto.NewValueBool(true)}, + {nil, proto.NewValueFloat64(9.009123), proto.NewValueFloat64(9.009123)}, + {nil, proto.NewValueDecimal(decimal.NewFromInt(8080)), proto.NewValueDecimal(decimal.NewFromInt(8080))}, + {proto.NewValueInt64(2), proto.NewValueFloat64(9.009123), proto.NewValueInt64(2)}, + {proto.NewValueString("yes"), proto.NewValueInt64(2), proto.NewValueString("yes")}, + {proto.NewValueBool(true), proto.NewValueFloat64(9.009123), proto.NewValueBool(true)}, + {proto.NewValueFloat64(9.009123), proto.NewValueBool(true), proto.NewValueFloat64(9.009123)}, + {proto.NewValueDecimal(decimal.NewFromInt(8080)), proto.NewValueFloat64(9.009123), proto.NewValueDecimal(decimal.NewFromInt(8080))}, + } { + t.Run(fmt.Sprint(v.inFirst), func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(v.inFirst), proto.ToValuer(v.inSecond)) + assert.NoError(t, err) + assert.Equal(t, v.want, out) + }) + } +} diff --git a/pkg/runtime/function/if_test.go b/pkg/runtime/function/if_test.go new file mode 100644 index 00000000..bc1b5b78 --- /dev/null +++ b/pkg/runtime/function/if_test.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncIf is https://dev.mysql.com/doc/refman/5.6/en/flow-control-functions.html#function_if +func TestIf(t *testing.T) { + fn := proto.MustGetFunc(FuncIf) + assert.Equal(t, 3, fn.NumInput()) + type tt struct { + inFirst proto.Value + inSecond proto.Value + inThird proto.Value + want proto.Value + } + for _, v := range []tt{ + {proto.NewValueBool(true), proto.NewValueInt64(1), proto.NewValueInt64(2), proto.NewValueInt64(1)}, + {proto.NewValueBool(true), proto.NewValueString("yes"), proto.NewValueString("no"), proto.NewValueString("yes")}, + {proto.NewValueBool(true), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000002), proto.NewValueFloat64(0.00000000001)}, + {proto.NewValueBool(true), proto.NewValueDecimal(decimal.NewFromInt(0)), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueDecimal(decimal.NewFromInt(0))}, + {proto.NewValueBool(true), proto.NewValueString("yes"), proto.NewValueFloat64(0.00000000001), proto.NewValueString("yes")}, + {proto.NewValueBool(true), proto.NewValueFloat64(0.00000000001), proto.NewValueDecimal(decimal.NewFromInt(200)), proto.NewValueFloat64(0.00000000001)}, + {proto.NewValueBool(true), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueDecimal(decimal.NewFromInt(1))}, + {proto.NewValueBool(true), nil, proto.NewValueInt64(2), proto.NewValueInt64(2)}, + {proto.NewValueBool(true), nil, proto.NewValueString("no"), proto.NewValueString("no")}, + {proto.NewValueBool(false), proto.NewValueInt64(1), proto.NewValueInt64(2), proto.NewValueInt64(2)}, + {proto.NewValueBool(false), proto.NewValueString("yes"), proto.NewValueString("no"), proto.NewValueString("no")}, + {proto.NewValueBool(false), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000002), proto.NewValueFloat64(0.00000000002)}, + {proto.NewValueBool(false), proto.NewValueDecimal(decimal.NewFromInt(0)), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueDecimal(decimal.NewFromInt(1))}, + {proto.NewValueBool(false), proto.NewValueString("yes"), proto.NewValueFloat64(0.00000000001), proto.NewValueFloat64(0.00000000001)}, + {proto.NewValueBool(false), proto.NewValueFloat64(0.00000000001), proto.NewValueDecimal(decimal.NewFromInt(200)), proto.NewValueDecimal(decimal.NewFromInt(200))}, + {proto.NewValueBool(false), proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueInt64(2)}, + {nil, proto.NewValueDecimal(decimal.NewFromInt(1)), proto.NewValueInt64(2), proto.NewValueInt64(2)}, + {nil, proto.NewValueDecimal(decimal.NewFromInt(1)), nil, nil}, + } { + t.Run(fmt.Sprint(v.inFirst), func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(v.inFirst), proto.ToValuer(v.inSecond), proto.ToValuer(v.inThird)) + assert.NoError(t, err) + assert.Equal(t, v.want, out) + }) + } +}