From 977d9e5a854b58b4469be1af6aa14a5bf5a4b8c6 Mon Sep 17 00:00:00 2001 From: Jo <46752250+georgesittas@users.noreply.github.com> Date: Thu, 3 Oct 2024 19:50:09 +0300 Subject: [PATCH] Feat: allow supplying dialect in diff, conditionally copy ASTs (#4208) --- sqlglot/diff.py | 15 +++++++++++---- tests/test_diff.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sqlglot/diff.py b/sqlglot/diff.py index eaca8b8a7..023139ce2 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -14,6 +14,9 @@ from sqlglot import Dialect, expressions as exp from sqlglot.helper import ensure_list +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + @dataclass(frozen=True) class Insert: @@ -63,6 +66,7 @@ def diff( target: exp.Expression, matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, delta_only: bool = False, + copy: bool = True, **kwargs: t.Any, ) -> t.List[Edit]: """ @@ -91,6 +95,9 @@ def diff( Note: expression references in this list must refer to the same node objects that are referenced in source / target trees. delta_only: excludes all `Keep` nodes from the diff. + copy: whether to copy the input expressions. + Note: if this is set to false, the caller must ensure that there are no shared references + in the two ASTs, otherwise the diffing algorithm may produce unexpected behavior. kwargs: additional arguments to pass to the ChangeDistiller instance. Returns: @@ -110,8 +117,8 @@ def compute_node_mappings( if id(old_node) in matching_ids } - source_copy = source.copy() - target_copy = target.copy() + source_copy = source.copy() if copy else source + target_copy = target.copy() if copy else target node_mappings = { **compute_node_mappings(source, source_copy), @@ -149,10 +156,10 @@ class ChangeDistiller: Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. """ - def __init__(self, f: float = 0.6, t: float = 0.6) -> None: + def __init__(self, f: float = 0.6, t: float = 0.6, dialect: DialectType = None) -> None: self.f = f self.t = t - self._sql_generator = Dialect().generator() + self._sql_generator = Dialect.get_or_raise(dialect).generator() def diff( self, diff --git a/tests/test_diff.py b/tests/test_diff.py index a83452329..edd3b267f 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -225,5 +225,20 @@ def test_identifier(self): ], ) + def test_dialect_aware_diff(self): + from sqlglot.generator import logger + + with self.assertLogs(logger) as cm: + # We want to assert there are no warnings, but the 'assertLogs' method does not support that. + # Therefore, we are adding a dummy warning, and then we will assert it is the only warning. + logger.warning("Dummy warning") + + expression = parse_one("SELECT foo FROM bar FOR UPDATE", dialect="oracle") + self._validate_delta_only( + diff_delta_only(expression, expression.copy(), dialect="oracle"), [] + ) + + self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output) + def _validate_delta_only(self, actual_delta, expected_delta): self.assertEqual(set(actual_delta), set(expected_delta))