diff --git a/engine_test.go b/engine_test.go index 3b94eceaf..18d1fe7d5 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1064,6 +1064,10 @@ var queries = []struct { `SELECT i AS foo FROM mytable ORDER BY mytable.i`, []sql.Row{{int64(1)}, {int64(2)}, {int64(3)}}, }, + { + `SELECT JSON_EXTRACT('[1, 2, 3]', '$.[0]')`, + []sql.Row{{float64(1)}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/expression/function/json_extract.go b/sql/expression/function/json_extract.go index 5f5876b0f..1b35c90cd 100644 --- a/sql/expression/function/json_extract.go +++ b/sql/expression/function/json_extract.go @@ -47,16 +47,11 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - js, err = sql.JSON.Convert(js) + doc, err := unmarshalVal(js) if err != nil { return nil, err } - var doc interface{} - if err := json.Unmarshal(js.([]byte), &doc); err != nil { - return nil, err - } - var result = make([]interface{}, len(j.Paths)) for i, p := range j.Paths { path, err := p.Eval(ctx, row) @@ -84,6 +79,20 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result, nil } +func unmarshalVal(v interface{}) (interface{}, error) { + v, err := sql.JSON.Convert(v) + if err != nil { + return nil, err + } + + var doc interface{} + if err := json.Unmarshal(v.([]byte), &doc); err != nil { + return nil, err + } + + return doc, nil +} + // IsNullable implements the sql.Expression interface. func (j *JSONExtract) IsNullable() bool { for _, p := range j.Paths { diff --git a/sql/type.go b/sql/type.go index 9401ca0d7..bd07721f0 100644 --- a/sql/type.go +++ b/sql/type.go @@ -671,7 +671,16 @@ func (t jsonT) SQL(v interface{}) sqltypes.Value { // Convert implements Type interface. func (t jsonT) Convert(v interface{}) (interface{}, error) { - return json.Marshal(v) + switch v := v.(type) { + case string: + var doc interface{} + if err := json.Unmarshal([]byte(v), &doc); err != nil { + return json.Marshal(v) + } + return json.Marshal(doc) + default: + return json.Marshal(v) + } } // Compare implements Type interface. diff --git a/sql/type_test.go b/sql/type_test.go index 0d4a01fd5..0b3eaa629 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -200,6 +200,7 @@ func TestBlob(t *testing.T) { func TestJSON(t *testing.T) { convert(t, JSON, "", []byte(`""`)) convert(t, JSON, []int{1, 2}, []byte("[1,2]")) + convert(t, JSON, `{"a": true, "b": 3}`, []byte(`{"a":true,"b":3}`)) lt(t, JSON, []byte("A"), []byte("B")) eq(t, JSON, []byte("A"), []byte("A"))