diff --git a/sql/core.go b/sql/core.go index 5d7fe82f2..14c916ef3 100644 --- a/sql/core.go +++ b/sql/core.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "math" + "strconv" "time" "gopkg.in/src-d/go-errors.v1" @@ -232,14 +233,37 @@ func EvaluateCondition(ctx *Context, cond Expression, row Row) (bool, error) { switch b := v.(type) { case bool: return b, nil - case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: - return b != 0, nil + case int: + return b != int(0), nil + case int64: + return b != int64(0), nil + case int32: + return b != int32(0), nil + case int16: + return b != int16(0), nil + case int8: + return b != int8(0), nil + case uint: + return b != uint(0), nil + case uint64: + return b != uint64(0), nil + case uint32: + return b != uint32(0), nil + case uint16: + return b != uint16(0), nil + case uint8: + return b != uint8(0), nil case time.Duration: return int64(b) != 0, nil case time.Time: return b.UnixNano() != 0, nil - case float32, float64: + case float64: return int(math.Round(v.(float64))) != 0, nil + case float32: + return int(math.Round(float64(v.(float32)))) != 0, nil + case string: + parsed, err := strconv.ParseFloat(v.(string), 64) + return err == nil && int(parsed) != 0, nil default: return false, nil } diff --git a/sql/core_test.go b/sql/core_test.go new file mode 100644 index 000000000..cf3e23acd --- /dev/null +++ b/sql/core_test.go @@ -0,0 +1,49 @@ +package sql_test + +import ( + "fmt" + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +var conditions = []struct { + evaluated bool + value interface{} + t sql.Type +}{ + {true, int16(1), sql.Int16}, + {false, int16(0), sql.Int16}, + {true, int32(1), sql.Int32}, + {false, int32(0), sql.Int32}, + {true, int(1), sql.Int64}, + {false, int(0), sql.Int64}, + {true, float32(1), sql.Float32}, + {true, float64(1), sql.Float64}, + {false, float32(0), sql.Float32}, + {false, float64(0), sql.Float64}, + {true, float32(0.5), sql.Float32}, + {true, float64(0.5), sql.Float64}, + {true, "1", sql.Text}, + {false, "0", sql.Text}, + {false, "foo", sql.Text}, + {false, "0.5", sql.Text}, + {false, time.Duration(0), sql.Timestamp}, + {true, time.Duration(1), sql.Timestamp}, + {false, false, sql.Boolean}, + {true, true, sql.Boolean}, +} + +func TestEvaluateCondition(t *testing.T) { + for _, v := range conditions { + t.Run(fmt.Sprint(v.value, " evaluated to ", v.evaluated, " type ", v.t), func(t *testing.T) { + require := require.New(t) + b, err := sql.EvaluateCondition(sql.NewEmptyContext(), expression.NewLiteral(v.value, v.t), sql.NewRow()) + require.NoError(err) + require.Equal(v.evaluated, b) + }) + } +}