Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
function: make array_length not fail with literal null (#767)
Browse files Browse the repository at this point in the history
function: make array_length not fail with literal null
  • Loading branch information
ajnavarro authored Jun 24, 2019
2 parents e1d8da3 + 2cc9356 commit 8702d43
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,14 @@ var queries = []struct {
[]sql.Row{{nil}},
},
{
`SELECT ARRAY_LENGTH(null)`,
[]sql.Row{{nil}},
},
{
`SELECT ARRAY_LENGTH("foo")`,
[]sql.Row{{nil}},
},
{
`SELECT * FROM mytable WHERE NULL AND i = 3`,
[]sql.Row{},
},
Expand Down
5 changes: 2 additions & 3 deletions sql/expression/function/arraylength.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package function // import "github.com/src-d/go-mysql-server/sql/expression/func

import (
"fmt"
"reflect"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
Expand Down Expand Up @@ -39,7 +38,7 @@ func (f *ArrayLength) TransformUp(fn sql.TransformExprFunc) (sql.Expression, err
// Eval implements the Expression interface.
func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if t := f.Child.Type(); !sql.IsArray(t) && t != sql.JSON {
return nil, sql.ErrInvalidType.New(f.Child.Type().Type().String())
return nil, nil
}

child, err := f.Child.Eval(ctx, row)
Expand All @@ -53,7 +52,7 @@ func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {

array, ok := child.([]interface{})
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(child))
return nil, nil
}

return int32(len(array)), nil
Expand Down
12 changes: 6 additions & 6 deletions sql/expression/function/arraylength_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package function
import (
"testing"

"github.com/stretchr/testify/require"
errors "gopkg.in/src-d/go-errors.v1"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/stretchr/testify/require"
errors "gopkg.in/src-d/go-errors.v1"
)

func TestArrayLength(t *testing.T) {
Expand All @@ -19,7 +19,7 @@ func TestArrayLength(t *testing.T) {
err *errors.Kind
}{
{"array is nil", sql.NewRow(nil), nil, nil},
{"array is not of right type", sql.NewRow(5), nil, sql.ErrInvalidType},
{"array is not of right type", sql.NewRow(5), nil, nil},
{"array is ok", sql.NewRow([]interface{}{1, 2, 3}), int32(3), nil},
}

Expand All @@ -40,7 +40,7 @@ func TestArrayLength(t *testing.T) {

f = NewArrayLength(expression.NewGetField(0, sql.Tuple(sql.Int64, sql.Int64), "", false))
require := require.New(t)
_, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)})
require.Error(err)
require.True(sql.ErrInvalidType.Is(err))
v, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)})
require.NoError(err)
require.Nil(v)
}

0 comments on commit 8702d43

Please sign in to comment.