diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 80bd7bd18..97b879082 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -106,11 +106,14 @@ def _builder(args: t.List) -> E: # https://docs.snowflake.com/en/sql-reference/functions/div0 def _build_if_from_div0(args: t.List) -> exp.If: - cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)).and_( - exp.Is(this=seq_get(args, 0), expression=exp.null()).not_() + lhs = exp._wrap(seq_get(args, 0), exp.Binary) + rhs = exp._wrap(seq_get(args, 1), exp.Binary) + + cond = exp.EQ(this=rhs, expression=exp.Literal.number(0)).and_( + exp.Is(this=lhs, expression=exp.null()).not_() ) true = exp.Literal.number(0) - false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) + false = exp.Div(this=lhs, expression=rhs) return exp.If(this=cond, true=true, false=false) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7bc2f6e4e..679a844e8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6972,7 +6972,15 @@ def _combine( return this -def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: +@t.overload +def _wrap(expression: None, kind: t.Type[Expression]) -> None: ... + + +@t.overload +def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: ... + + +def _wrap(expression: t.Optional[E], kind: t.Type[Expression]) -> t.Optional[E] | Paren: return Paren(this=expression) if isinstance(expression, kind) else expression diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index e2db661e0..515a07c4f 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -605,6 +605,17 @@ def test_snowflake(self): "duckdb": "CASE WHEN bar = 0 AND NOT foo IS NULL THEN 0 ELSE foo / bar END", }, ) + self.validate_all( + "DIV0(a - b, c - d)", + write={ + "snowflake": "IFF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))", + "sqlite": "IIF((c - d) = 0 AND NOT (a - b) IS NULL, 0, CAST((a - b) AS REAL) / (c - d))", + "presto": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, CAST((a - b) AS DOUBLE) / (c - d))", + "spark": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))", + "hive": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))", + "duckdb": "CASE WHEN (c - d) = 0 AND NOT (a - b) IS NULL THEN 0 ELSE (a - b) / (c - d) END", + }, + ) self.validate_all( "ZEROIFNULL(foo)", write={