diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go new file mode 100644 index 000000000..8b38aeb3c --- /dev/null +++ b/sql/expression/function/logarithm.go @@ -0,0 +1,203 @@ +package function + +import ( + "math" + "reflect" + "fmt" + + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-errors.v1" +) + +// ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a +// logarithm function +var ErrInvalidArgumentForLogarithm = errors.NewKind("invalid argument value for logarithm: %v") + +// LogBaseMaker returns LogBase creator functions with a specific base. +func LogBaseMaker(base float64) func(e sql.Expression) sql.Expression { + return func(e sql.Expression) sql.Expression { + return NewLogBase(base, e) + } +} + +// LogBase is a function that returns the logarithm of a value with a specific base. +type LogBase struct { + expression.UnaryExpression + base float64 +} + +// NewLogBase creates a new LogBase expression. +func NewLogBase(base float64, e sql.Expression) sql.Expression { + return &LogBase{UnaryExpression: expression.UnaryExpression{Child: e}, base: base} +} + +func (l *LogBase) String() string { + switch l.base { + case float64(math.E): + return fmt.Sprintf("ln(%s)", l.Child) + case float64(10): + return fmt.Sprintf("log10(%s)", l.Child) + case float64(2): + return fmt.Sprintf("log2(%s)", l.Child) + default: + return fmt.Sprintf("log(%v, %s)", l.base, l.Child) + } +} + +// TransformUp implements the Expression interface. +func (l *LogBase) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + child, err := l.Child.TransformUp(f) + if err != nil { + return nil, err + } + return f(NewLogBase(l.base, child)) +} + +// Type returns the resultant type of the function. +func (l *LogBase) Type() sql.Type { + return sql.Float64 +} + +// IsNullable implements the sql.Expression interface. +func (l *LogBase) IsNullable() bool { + return l.base == float64(1) || l.base <= float64(0) || l.Child.IsNullable() +} + +// Eval implements the Expression interface. +func (l *LogBase) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + v, err := l.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if v == nil { + return nil, nil + } + + val, err := sql.Float64.Convert(v) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(v)) + } + return computeLog(val.(float64), l.base) +} + +// Log is a function that returns the natural logarithm of a value. +type Log struct { + expression.BinaryExpression +} + +// NewLn creates a new Log expression. +func NewLog(args ...sql.Expression) (sql.Expression, error) { + argLen := len(args) + if argLen == 0 || argLen > 2 { + return nil, sql.ErrInvalidArgumentNumber.New("1 or 2", argLen) + } + + if argLen == 1 { + return &Log{expression.BinaryExpression{Left: expression.NewLiteral(math.E, sql.Float64), Right: args[0]}}, nil + } else { + return &Log{expression.BinaryExpression{Left: args[0], Right: args[1]}}, nil + } +} + +func (l *Log) String() string { + return fmt.Sprintf("log(%s, %s)", l.Left, l.Right) +} + +// TransformUp implements the Expression interface. +func (l *Log) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + var args = make([]sql.Expression, 2) + arg, err := l.Left.TransformUp(f) + if err != nil { + return nil, err + } + args[0] = arg + + arg, err = l.Right.TransformUp(f) + if err != nil { + return nil, err + } + args[1] = arg + expr, err := NewLog(args...) + if err != nil { + return nil, err + } + + return f(expr) +} + +// Children implements the Expression interface. +func (l *Log) Children() []sql.Expression { + return []sql.Expression{l.Left, l.Right} +} + +// Type returns the resultant type of the function. +func (l *Log) Type() sql.Type { + return sql.Float64 +} + +// IsNullable implements the Expression interface. +func (l *Log) IsNullable() bool { + return l.Left.IsNullable() || l.Right.IsNullable() +} + +// Eval implements the Expression interface. +func (l *Log) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + left, err := l.Left.Eval(ctx, row) + if err != nil { + return nil, err + } + + if left == nil { + return nil, nil + } + + lhs, err := sql.Float64.Convert(left) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(left)) + } + + right, err := l.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + + if right == nil { + return nil, nil + } + + rhs, err := sql.Float64.Convert(right) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(right)) + } + + // rhs becomes value, lhs becomes base + return computeLog(rhs.(float64), lhs.(float64)) +} + +func computeLog(v float64, base float64) (float64, error) { + if v <= 0 { + return float64(0), ErrInvalidArgumentForLogarithm.New(v) + } + if base == float64(1) || base <= float64(0) { + return float64(0), ErrInvalidArgumentForLogarithm.New(base) + } + switch base { + case float64(2): + return math.Log2(v), nil + case float64(10): + return math.Log10(v), nil + case math.E: + return math.Log(v), nil + default: + // LOG(BASE,V) is equivalent to LOG(V) / LOG(BASE). + return float64(math.Log(v) / math.Log(base)), nil + } +} diff --git a/sql/expression/function/logarithm_test.go b/sql/expression/function/logarithm_test.go new file mode 100644 index 000000000..b20bc2797 --- /dev/null +++ b/sql/expression/function/logarithm_test.go @@ -0,0 +1,209 @@ +package function + +import ( + "testing" + "math" + "fmt" + + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-errors.v1" + "github.com/stretchr/testify/require" +) + +var epsilon = math.Nextafter(1, 2) - 1 + +func TestLn(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(0.6931471805599453), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(1.0986122886681096), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(1.791759469228055), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(2.0794415416798357), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(2.302585092994046), nil}, + } + + for _, tt := range testCases { + f := LogBaseMaker(math.E)(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result), ) + } + }) + } + + // Test Nil + f := LogBaseMaker(math.E)(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLog2(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(1), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(1.5849625007211563), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(2.584962500721156), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(3), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(3.321928094887362), nil}, + } + + for _, tt := range testCases { + f := LogBaseMaker(float64(2))(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result), ) + } + }) + } + + // Test Nil + f := LogBaseMaker(float64(2))(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLog10(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(0.3010299956639812), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(0.4771212547196624), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(0.7781512503836436), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(0.9030899869919435), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(1), nil}, + } + + for _, tt := range testCases { + f := LogBaseMaker(float64(10))(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result), ) + } + }) + } + + // Test Nil + f := LogBaseMaker(float64(10))(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLogInvalidArguments(t *testing.T) { + _, err := NewLog() + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) + + _, err = NewLog( + expression.NewLiteral(1, sql.Float64), + expression.NewLiteral(1, sql.Float64), + expression.NewLiteral(1, sql.Float64), + ) + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) +} + +func TestLog(t *testing.T) { + var testCases = []struct { + name string + input []sql.Expression + expected interface{} + err *errors.Kind + }{ + {"Input base is 1", []sql.Expression{expression.NewLiteral(float64(1), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is zero", []sql.Expression{expression.NewLiteral(float64(0), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is negative", []sql.Expression{expression.NewLiteral(float64(-5), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is valid string", []sql.Expression{expression.NewLiteral("4", sql.Text), expression.NewLiteral(float64(10), sql.Float64)}, float64(1.6609640474436813), nil}, + {"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", sql.Text), expression.NewLiteral(float64(10), sql.Float64)}, nil, sql.ErrInvalidType}, + + {"Input value is zero", []sql.Expression{expression.NewLiteral(float64(0), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", []sql.Expression{expression.NewLiteral(float64(-9), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", []sql.Expression{expression.NewLiteral("7", sql.Text)}, float64(1.9459101490553132), nil}, + {"Input value is invalid string", []sql.Expression{expression.NewLiteral("766j", sql.Text)}, nil, sql.ErrInvalidType}, + + {"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), sql.Float64), expression.NewLiteral(float64(99), sql.Float64)}, float64(2.855108491376949), nil}, + {"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), sql.Float32), expression.NewLiteral(float64(80), sql.Float64)}, float64(2.4456556306420936), nil}, + {"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), sql.Int64), expression.NewLiteral(float64(64), sql.Float64)}, float64(2), nil}, + {"Input base is valid int32", []sql.Expression{expression.NewLiteral(int32(10), sql.Int32), expression.NewLiteral(float64(100), sql.Float64)}, float64(2), nil}, + + {"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), sql.Float64), expression.NewLiteral(float64(66), sql.Float64)}, float64(2.6031788549643564), nil}, + {"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), sql.Float32), expression.NewLiteral(float64(50), sql.Float64)}, float64(3.560876795007312), nil}, + {"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), sql.Int64), expression.NewLiteral(float64(77), sql.Float64)}, float64(2.698958057527146), nil}, + {"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), sql.Int32), expression.NewLiteral(float64(40), sql.Float64)}, float64(2.6609640474436813), nil}, + } + + for _, tt := range testCases { + f, _ := NewLog(tt.input...) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), nil) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result), ) + } + }) + } + + // Test Nil + f, _ := NewLog(expression.NewLiteral(nil, sql.Float64)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 818e3ae9a..d19324124 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -1,6 +1,8 @@ package function import ( + "math" + "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" ) @@ -43,4 +45,8 @@ var Defaults = sql.Functions{ "coalesce": sql.FunctionN(NewCoalesce), "json_extract": sql.FunctionN(NewJSONExtract), "connection_id": sql.Function0(NewConnectionID), + "ln": sql.Function1(LogBaseMaker(float64(math.E))), + "log2": sql.Function1(LogBaseMaker(float64(2))), + "log10": sql.Function1(LogBaseMaker(float64(10))), + "log": sql.FunctionN(NewLog), }