Skip to content

Commit

Permalink
sem: unify division by zero check and fix it in a few places
Browse files Browse the repository at this point in the history
Release note (bug fix): Previously, in some cases, CockroachDB didn't
check whether the right argument of `Div` (`/`), `FloorDiv` (`//`),
or `Mod` (`%`) operations was zero, so instead of correctly returning
a "division by zero" error, we were returning `NaN`, and this is now
fixed. Additionally, the error message of "modulus by zero" has been
changed to "division by zero" to be inline with Postgres.
  • Loading branch information
yuzefovich committed Jun 5, 2020
1 parent c9b9c01 commit b973a82
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 33 deletions.
24 changes: 23 additions & 1 deletion pkg/sql/colexec/execgen/cmd/execgen/overloads_bin.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,29 @@ func (decimalCustomizer) getBinOpAssignFunc() assignFunc {

func (c floatCustomizer) getBinOpAssignFunc() assignFunc {
return func(op *lastArgWidthOverload, targetElem, leftElem, rightElem, targetCol, leftCol, rightCol string) string {
return fmt.Sprintf("%s = float64(%s) %s float64(%s)", targetElem, leftElem, op.overloadBase.OpStr, rightElem)
binOp := op.overloadBase.BinOp
computeBinOp := fmt.Sprintf("float64(%s) %s float64(%s)", leftElem, binOp, rightElem)
args := map[string]interface{}{
"CheckRightIsZero": binOp == tree.Div,
"Target": targetElem,
"Right": rightElem,
"ComputeBinOp": computeBinOp,
}
buf := strings.Builder{}
t := template.Must(template.New("").Parse(`
{
{{if .CheckRightIsZero}}
if {{.Right}} == 0.0 {
colexecerror.ExpectedError(tree.ErrDivByZero)
}
{{end}}
{{.Target}} = {{.ComputeBinOp}}
}
`))
if err := t.Execute(&buf, args); err != nil {
colexecerror.InternalError(err)
}
return buf.String()
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/logictest/testdata/logic_test/builtin_function
Original file line number Diff line number Diff line change
Expand Up @@ -877,10 +877,10 @@ SELECT mod(5.0::float, 2.0), mod(1.0::float, 0.0), mod(5, 2), mod(19.3::decimal,
# mod returns the same results as PostgreSQL 9.4.4
# in tests below (except for the error message).

query error mod\(\): zero modulus
query error mod\(\): division by zero
SELECT mod(5, 0)

query error mod\(\): zero modulus
query error mod\(\): division by zero
SELECT mod(5::decimal, 0::decimal)

query II
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/sem/builtins/math_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ var mathBuiltins = map[string]builtinDefinition{
}, "Calculates `x`%`y`.", tree.VolatilityImmutable),
decimalOverload2("x", "y", func(x, y *apd.Decimal) (tree.Datum, error) {
if y.Sign() == 0 {
return nil, tree.ErrZeroModulus
return nil, tree.ErrDivByZero
}
dd := &tree.DDecimal{}
_, err := tree.HighPrecisionCtx.Rem(&dd.Decimal, x, y)
Expand All @@ -346,7 +346,7 @@ var mathBuiltins = map[string]builtinDefinition{
Fn: func(_ *tree.EvalContext, args tree.Datums) (tree.Datum, error) {
y := tree.MustBeDInt(args[1])
if y == 0 {
return nil, tree.ErrZeroModulus
return nil, tree.ErrDivByZero
}
x := tree.MustBeDInt(args[0])
return tree.NewDInt(x % y), nil
Expand Down
60 changes: 39 additions & 21 deletions pkg/sql/sem/tree/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ var (
// ErrDivByZero is reported on a division by zero.
ErrDivByZero = pgerror.New(pgcode.DivisionByZero, "division by zero")
errSqrtOfNegNumber = pgerror.New(pgcode.InvalidArgumentForPowerFunction, "cannot take square root of a negative number")
// ErrZeroModulus is reported when computing the rest of a division by zero.
ErrZeroModulus = pgerror.New(pgcode.DivisionByZero, "zero modulus")

big10E6 = big.NewInt(1e6)
big10E10 = big.NewInt(1e10)
Expand Down Expand Up @@ -1310,13 +1308,13 @@ var BinOps = map[BinaryOperator]binOpOverload{
ReturnType: types.Decimal,
Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) {
rInt := MustBeDInt(right)
if rInt == 0 {
return nil, ErrDivByZero
}
div := ctx.getTmpDec().SetFinite(int64(rInt), 0)
dd := &DDecimal{}
dd.SetFinite(int64(MustBeDInt(left)), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, div)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, div)
return dd, err
},
Volatility: VolatilityImmutable,
Expand All @@ -1341,11 +1339,11 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
dd := &DDecimal{}
cond, err := DecimalCtx.Quo(&dd.Decimal, l, r)
if cond.DivisionByZero() {
return dd, ErrDivByZero
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := DecimalCtx.Quo(&dd.Decimal, l, r)
return dd, err
},
Volatility: VolatilityImmutable,
Expand All @@ -1357,12 +1355,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := MustBeDInt(right)
if r == 0 {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(r), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, l, &dd.Decimal)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, l, &dd.Decimal)
return dd, err
},
Volatility: VolatilityImmutable,
Expand All @@ -1374,12 +1372,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(l), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, r)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, r)
return dd, err
},
Volatility: VolatilityImmutable,
Expand Down Expand Up @@ -1433,6 +1431,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := float64(*left.(*DFloat))
r := float64(*right.(*DFloat))
if r == 0.0 {
return nil, ErrDivByZero
}
return NewDFloat(DFloat(math.Trunc(l / r))), nil
},
Volatility: VolatilityImmutable,
Expand All @@ -1444,6 +1445,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := HighPrecisionCtx.QuoInteger(&dd.Decimal, l, r)
return dd, err
Expand Down Expand Up @@ -1474,7 +1478,7 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.Sign() == 0 {
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
Expand All @@ -1494,7 +1498,7 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
r := MustBeDInt(right)
if r == 0 {
return nil, ErrZeroModulus
return nil, ErrDivByZero
}
return NewDInt(MustBeDInt(left) % r), nil
},
Expand All @@ -1505,7 +1509,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
RightType: types.Float,
ReturnType: types.Float,
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
return NewDFloat(DFloat(math.Mod(float64(*left.(*DFloat)), float64(*right.(*DFloat))))), nil
l := float64(*left.(*DFloat))
r := float64(*right.(*DFloat))
if r == 0.0 {
return nil, ErrDivByZero
}
return NewDFloat(DFloat(math.Mod(l, r))), nil
},
Volatility: VolatilityImmutable,
},
Expand All @@ -1516,6 +1525,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := HighPrecisionCtx.Rem(&dd.Decimal, l, r)
return dd, err
Expand All @@ -1529,6 +1541,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := MustBeDInt(right)
if r == 0 {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(r), 0)
_, err := HighPrecisionCtx.Rem(&dd.Decimal, l, &dd.Decimal)
Expand All @@ -1543,6 +1558,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(l), 0)
_, err := HighPrecisionCtx.Rem(&dd.Decimal, &dd.Decimal, r)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/tree/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func TestEvalError(t *testing.T) {
expr string
expected string
}{
{`1 % 0`, `zero modulus`},
{`1 % 0`, `division by zero`},
{`1 / 0`, `division by zero`},
{`1::float / 0::float`, `division by zero`},
{`1 // 0`, `division by zero`},
Expand Down
6 changes: 1 addition & 5 deletions pkg/sql/sem/tree/eval_test/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func TestEval(t *testing.T) {
result.Op,
typs,
nil, /* output */
nil, /* metadataSourcesQueue */
result.MetadataSources,
nil, /* toClose */
nil, /* outputStatsToTrace */
nil, /* cancelFlow */
Expand All @@ -228,10 +228,6 @@ func TestEval(t *testing.T) {
t.Fatalf("unexpected metadata: %+v", meta)
}
if row == nil {
// Might be some metadata.
if meta := mat.DrainHelper(); meta.Err != nil {
t.Fatalf("unexpected error: %s", meta.Err)
}
t.Fatal("unexpected end of input")
}
return row[0].Datum.String()
Expand Down
50 changes: 50 additions & 0 deletions pkg/sql/sem/tree/testdata/eval/arithmetic_operators
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,31 @@ eval
----
2

eval
1 // 0
----
division by zero

eval
-4.5 // 1.2
----
-3

eval
1.0 // 0.0
----
division by zero

eval
1.0 // 0
----
division by zero

eval
1 // 0.0
----
division by zero

eval
3.1 % 2.0
----
Expand All @@ -118,6 +138,11 @@ eval
----
2

eval
1 % 0
----
division by zero

eval
1 + NULL
----
Expand Down Expand Up @@ -148,11 +173,36 @@ eval
----
1

eval
1.0 % 0.0
----
division by zero

eval
1.0 % 0
----
division by zero

eval
1 % 0.0
----
division by zero

eval
-4.5:::float // 1.2:::float
----
-3.0

eval
1:::float // 0:::float
----
division by zero

eval
1:::float % 0:::float
----
division by zero

eval
2 ^ 3
----
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/tests/rsg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ var ignoredErrorPatterns = []string{
"overflow",
"requested length too large",
"division by zero",
"zero modulus",
"is out of range",

// Type checking
Expand Down

0 comments on commit b973a82

Please sign in to comment.