Skip to content

Commit

Permalink
Feat: allow supplying dialect in diff, conditionally copy ASTs (#4208)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Oct 3, 2024
1 parent 89c0703 commit 977d9e5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
15 changes: 11 additions & 4 deletions sqlglot/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from sqlglot import Dialect, expressions as exp
from sqlglot.helper import ensure_list

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType


@dataclass(frozen=True)
class Insert:
Expand Down Expand Up @@ -63,6 +66,7 @@ def diff(
target: exp.Expression,
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
delta_only: bool = False,
copy: bool = True,
**kwargs: t.Any,
) -> t.List[Edit]:
"""
Expand Down Expand Up @@ -91,6 +95,9 @@ def diff(
Note: expression references in this list must refer to the same node objects that are
referenced in source / target trees.
delta_only: excludes all `Keep` nodes from the diff.
copy: whether to copy the input expressions.
Note: if this is set to false, the caller must ensure that there are no shared references
in the two ASTs, otherwise the diffing algorithm may produce unexpected behavior.
kwargs: additional arguments to pass to the ChangeDistiller instance.
Returns:
Expand All @@ -110,8 +117,8 @@ def compute_node_mappings(
if id(old_node) in matching_ids
}

source_copy = source.copy()
target_copy = target.copy()
source_copy = source.copy() if copy else source
target_copy = target.copy() if copy else target

node_mappings = {
**compute_node_mappings(source, source_copy),
Expand Down Expand Up @@ -149,10 +156,10 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""

def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
def __init__(self, f: float = 0.6, t: float = 0.6, dialect: DialectType = None) -> None:
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
self._sql_generator = Dialect.get_or_raise(dialect).generator()

def diff(
self,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,20 @@ def test_identifier(self):
],
)

def test_dialect_aware_diff(self):
from sqlglot.generator import logger

with self.assertLogs(logger) as cm:
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
logger.warning("Dummy warning")

expression = parse_one("SELECT foo FROM bar FOR UPDATE", dialect="oracle")
self._validate_delta_only(
diff_delta_only(expression, expression.copy(), dialect="oracle"), []
)

self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output)

def _validate_delta_only(self, actual_delta, expected_delta):
self.assertEqual(set(actual_delta), set(expected_delta))

0 comments on commit 977d9e5

Please sign in to comment.