Skip to content

Commit

Permalink
fix(snowflake): Wrap DIV0 operands if they're binary expressions (#4393)
Browse files Browse the repository at this point in the history
* fix(snowflake): Wrap DIV0 operands if they're binary ops

* Add _wrap variants

* Simplify _wrap

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
  • Loading branch information
VaggelisD and georgesittas authored Nov 14, 2024
1 parent 37c4809 commit 79f6783
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
9 changes: 6 additions & 3 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def _builder(args: t.List) -> E:

# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)).and_(
exp.Is(this=seq_get(args, 0), expression=exp.null()).not_()
lhs = exp._wrap(seq_get(args, 0), exp.Binary)
rhs = exp._wrap(seq_get(args, 1), exp.Binary)

cond = exp.EQ(this=rhs, expression=exp.Literal.number(0)).and_(
exp.Is(this=lhs, expression=exp.null()).not_()
)
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
false = exp.Div(this=lhs, expression=rhs)
return exp.If(this=cond, true=true, false=false)


Expand Down
10 changes: 9 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6972,7 +6972,15 @@ def _combine(
return this


def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
@t.overload
def _wrap(expression: None, kind: t.Type[Expression]) -> None: ...


@t.overload
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: ...


def _wrap(expression: t.Optional[E], kind: t.Type[Expression]) -> t.Optional[E] | Paren:
return Paren(this=expression) if isinstance(expression, kind) else expression


Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,17 @@ def test_snowflake(self):
"duckdb": "CASE WHEN bar = 0 AND NOT foo IS NULL THEN 0 ELSE foo / bar END",
},
)
self.validate_all(
"DIV0(a - b, c - d)",
write={
"snowflake": "IFF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))",
"sqlite": "IIF((c - d) = 0 AND NOT (a - b) IS NULL, 0, CAST((a - b) AS REAL) / (c - d))",
"presto": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, CAST((a - b) AS DOUBLE) / (c - d))",
"spark": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))",
"hive": "IF((c - d) = 0 AND NOT (a - b) IS NULL, 0, (a - b) / (c - d))",
"duckdb": "CASE WHEN (c - d) = 0 AND NOT (a - b) IS NULL THEN 0 ELSE (a - b) / (c - d) END",
},
)
self.validate_all(
"ZEROIFNULL(foo)",
write={
Expand Down

0 comments on commit 79f6783

Please sign in to comment.