Skip to content

Commit

Permalink
Fix: make some SQL builders pure (#1526)
Browse files Browse the repository at this point in the history
* Fix: make some SQL builders pure

* Fixup

* Fixup

* Fixup

* Comment

* PR feedback, replace slice syntax with .between(low, high, ..)

* Add copy arg to some methods, fix tests

* Add copy flag to paren too

* Use convert instead of maybe_parse for between helper

* Fix test

* Change isin so that it receives t.Any for expressions
  • Loading branch information
georgesittas authored May 4, 2023
1 parent 911e4e9 commit 2a6a3e7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 50 deletions.
120 changes: 72 additions & 48 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def load(cls, obj):


class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
def and_(self, *expressions, dialect=None, copy=True, **opts):
"""
AND this condition with one or multiple expressions.
Expand All @@ -662,14 +662,15 @@ def and_(self, *expressions, dialect=None, **opts):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
And: the new condition.
"""
return and_(self, *expressions, dialect=dialect, **opts)
return and_(self, *expressions, dialect=dialect, copy=copy, **opts)

def or_(self, *expressions, dialect=None, **opts):
def or_(self, *expressions, dialect=None, copy=True, **opts):
"""
OR this condition with one or multiple expressions.
Expand All @@ -681,50 +682,59 @@ def or_(self, *expressions, dialect=None, **opts):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
Or: the new condition.
"""
return or_(self, *expressions, dialect=dialect, **opts)
return or_(self, *expressions, dialect=dialect, copy=copy, **opts)

def not_(self):
def not_(self, copy=True):
"""
Wrap this condition with NOT.
Example:
>>> condition("x=1").not_().sql()
'NOT x = 1'
Args:
copy (bool): whether or not to copy this object.
Returns:
Not: the new condition.
"""
return not_(self)
return not_(self, copy=copy)

def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
this = self.copy()
other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)

def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
return Bracket(
this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
)

def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
)

def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
return Between(
this=_maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)

def like(self, other: ExpOrStr) -> Like:
Expand Down Expand Up @@ -809,10 +819,10 @@ def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)

def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))
return Neg(this=_wrap(self.copy(), Binary))

def __invert__(self) -> Not:
return not_(self)
return not_(self.copy())


class Predicate(Condition):
Expand Down Expand Up @@ -2611,7 +2621,7 @@ def join(
join.set("kind", kind.text)

if on:
on = and_(*ensure_collection(on), dialect=dialect, **opts)
on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)

if using:
Expand Down Expand Up @@ -3540,14 +3550,20 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}

def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
return this
instance = _maybe_copy(self, copy)
instance.append(
"ifs",
If(
this=maybe_parse(condition, copy=copy, **opts),
true=maybe_parse(then, copy=copy, **opts),
),
)
return instance

def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.set("default", maybe_parse(condition, **opts))
return this
instance = _maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance


class Cast(Func):
Expand Down Expand Up @@ -4407,14 +4423,16 @@ def _apply_conjunction_builder(
if append and existing is not None:
expressions = [existing.this if into else existing] + list(expressions)

node = and_(*expressions, dialect=dialect, **opts)
node = and_(*expressions, dialect=dialect, copy=copy, **opts)

inst.set(arg, into(this=node) if into else node)
return inst


def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
def _combine(expressions, operator, dialect=None, copy=True, **opts):
expressions = [
condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
]
this = expressions[0]
if expressions[1:]:
this = _wrap(this, Connector)
Expand Down Expand Up @@ -4628,7 +4646,7 @@ def delete(
return delete_expr


def condition(expression, dialect=None, **opts) -> Condition:
def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
Expand All @@ -4647,6 +4665,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
copy (bool): Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
Expand All @@ -4657,11 +4676,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
expression,
into=Condition,
dialect=dialect,
copy=copy,
**opts,
)


def and_(*expressions, dialect=None, **opts) -> And:
def and_(*expressions, dialect=None, copy=True, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
Expand All @@ -4673,15 +4693,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
And: the new condition
"""
return _combine(expressions, And, dialect, **opts)
return _combine(expressions, And, dialect, copy=copy, **opts)


def or_(*expressions, dialect=None, **opts) -> Or:
def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
Expand All @@ -4693,15 +4714,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
Or: the new condition
"""
return _combine(expressions, Or, dialect, **opts)
return _combine(expressions, Or, dialect, copy=copy, **opts)


def not_(expression, dialect=None, **opts) -> Not:
def not_(expression, dialect=None, copy=True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
Expand All @@ -4721,13 +4743,14 @@ def not_(expression, dialect=None, **opts) -> Not:
this = condition(
expression,
dialect=dialect,
copy=copy,
**opts,
)
return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:
return Paren(this=expression)
def paren(expression, copy=True) -> Paren:
return Paren(this=_maybe_copy(expression, copy))


SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
Expand Down Expand Up @@ -5070,19 +5093,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)


def convert(value) -> Expression:
def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
value (Any): a python object
value: A python object.
copy: Whether or not to copy `value` (only applies to Expressions and collections).
Returns:
Expression: the equivalent expression object
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
return value
return _maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
Expand All @@ -5100,13 +5124,13 @@ def convert(value) -> Expression:
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v) for v in value])
return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
keys=[convert(k, copy=copy) for k in value],
values=[convert(v, copy=copy) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/eliminate_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def extract_condition(condition):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

for condition in on.flatten():
if isinstance(condition, exp.EQ):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None

join.args.pop("using")
join.set("on", exp.and_(*conditions))
join.set("on", exp.and_(*conditions, copy=False))

if column_tables:
for column in scope.columns:
Expand Down
1 change: 1 addition & 0 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
return expression

Expand Down
2 changes: 2 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3585,7 +3585,9 @@ def _parse_decode(self) -> t.Optional[exp.Expression]:
exp.and_(
exp.Is(this=expression.copy(), expression=exp.Null()),
exp.Is(this=search.copy(), expression=exp.Null()),
copy=False,
),
copy=False,
)
ifs.append(exp.If(this=cond, true=result))

Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6385,6 +6385,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
Expand Down Expand Up @@ -7589,6 +7597,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
Expand Down
6 changes: 6 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")
x_plus_one = x + 1

# Make sure we're not mutating x by changing its parent to be x_plus_one
self.assertIsNone(x.parent)
self.assertNotEqual(id(x_plus_one.this), id(x))

for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
Expand Down Expand Up @@ -51,6 +56,7 @@ def test_build(self):
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
(lambda: x.between(1, 2), "x BETWEEN 1 AND 2"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
Expand Down

0 comments on commit 2a6a3e7

Please sign in to comment.