Skip to content

Commit

Permalink
Feat: allow formatter to use CAST over :: syntax (#3173)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 25, 2024
1 parent b9b5e45 commit 3b03f51
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 28 deletions.
2 changes: 2 additions & 0 deletions docs/reference/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ Options:
-t, --transpile TEXT Transpile project models to the specified
dialect.
--append-newline Include a newline at the end of each file.
--no-rewrite-casts Preserve the existing casts, without rewriting
them to use the :: syntax.
--normalize Whether or not to normalize identifiers to
lowercase.
--pad INTEGER Determines the pad size in a formatted string.
Expand Down
1 change: 1 addition & 0 deletions docs/reference/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Formatting settings for the `sqlmesh format` command and UI.
| `leading_comma` | Whether to use leading commas (Default: False) | boolean | N |
| `max_text_width` | The maximum text width in a segment before creating new lines (Default: 80) | int | N |
| `append_newline` | Whether to append a newline to the end of the file (Default: False) | boolean | N |
| `no_rewrite_casts` | Preserve the existing casts, without rewriting them to use the :: syntax. (Default: False) | boolean | N |

## UI

Expand Down
6 changes: 4 additions & 2 deletions docs/reference/notebook.md
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ options:

#### format
```
%format [--transpile TRANSPILE] [--append-newline] [--normalize]
[--pad PAD] [--indent INDENT]
%format [--transpile TRANSPILE] [--append-newline] [--no-rewrite-casts]
[--normalize] [--pad PAD] [--indent INDENT]
[--normalize-functions NORMALIZE_FUNCTIONS] [--leading-comma]
[--max-text-width MAX_TEXT_WIDTH] [--check]
Expand All @@ -453,6 +453,8 @@ options:
Transpile project models to the specified dialect.
--append-newline Whether or not to append a newline to the end of the
file.
--no-rewrite-casts Preserve the existing casts, without rewriting them
to use the :: syntax.
--normalize Whether or not to normalize identifiers to lowercase.
--pad PAD Determines the pad size in a formatted string.
--indent INDENT Determines the indentation size in a formatted string.
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ def evaluate(
help="Include a newline at the end of each file.",
default=None,
)
@click.option(
"--no-rewrite-casts",
is_flag=True,
help="Preserve the existing casts, without rewriting them to use the :: syntax.",
default=None,
)
@click.option(
"--normalize",
is_flag=True,
Expand Down Expand Up @@ -266,6 +272,9 @@ def evaluate(
@cli_analytics
def format(ctx: click.Context, **kwargs: t.Any) -> None:
"""Format all SQL models and audits."""
if kwargs.pop("no_rewrite_casts", None):
kwargs["rewrite_casts"] = False

if not ctx.obj.format(**{k: v for k, v in kwargs.items() if v is not None}):
ctx.exit(1)

Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/config/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class FormatConfig(BaseConfig):
leading_comma: Whether to use leading commas or not.
max_text_width: The maximum text width in a segment before creating new lines.
append_newline: Whether to append a newline to the end of the file or not.
no_rewrite_casts: Preserve the existing casts, without rewriting them to use the :: syntax.
"""

normalize: bool = False
Expand All @@ -25,6 +26,7 @@ class FormatConfig(BaseConfig):
leading_comma: bool = False
max_text_width: int = 80
append_newline: bool = False
no_rewrite_casts: bool = False

@property
def generator_options(self) -> t.Dict[str, t.Any]:
Expand All @@ -33,4 +35,4 @@ def generator_options(self) -> t.Dict[str, t.Any]:
Returns:
The generator options.
"""
return self.dict(exclude={"append_newline"})
return self.dict(exclude={"append_newline", "no_rewrite_casts"})
21 changes: 17 additions & 4 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,7 @@ def evaluate(
def format(
self,
transpile: t.Optional[str] = None,
rewrite_casts: t.Optional[bool] = None,
append_newline: t.Optional[bool] = None,
*,
check: t.Optional[bool] = None,
Expand All @@ -910,6 +911,7 @@ def format(
for target in format_targets.values():
if target._path is None or target._path.suffix != ".sql":
continue

with open(target._path, "r+", encoding="utf-8") as file:
before = file.read()
expressions = parse(before, default_dialect=self.config_for_node(target).dialect)
Expand All @@ -922,13 +924,24 @@ def format(
value=exp.Literal.string(transpile or target.dialect),
)
)
format = self.config_for_node(target).format
opts = {**format.generator_options, **kwargs}
after = format_model_expressions(expressions, transpile or target.dialect, **opts)

format_config = self.config_for_node(target).format
after = format_model_expressions(
expressions,
transpile or target.dialect,
rewrite_casts=(
rewrite_casts
if rewrite_casts is not None
else not format_config.no_rewrite_casts
),
**{**format_config.generator_options, **kwargs},
)

if append_newline is None:
append_newline = format.append_newline
append_newline = format_config.append_newline
if append_newline:
after += "\n"

if not check:
file.seek(0)
file.write(after)
Expand Down
42 changes: 24 additions & 18 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,13 +643,17 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:


def format_model_expressions(
expressions: t.List[exp.Expression], dialect: t.Optional[str] = None, **kwargs: t.Any
expressions: t.List[exp.Expression],
dialect: t.Optional[str] = None,
rewrite_casts: bool = True,
**kwargs: t.Any,
) -> str:
"""Format a model's expressions into a standardized format.
Args:
expressions: The model's expressions, must be at least model def + query.
dialect: The dialect to render the expressions as.
rewrite_casts: Whether to rewrite all casts to use the :: syntax.
**kwargs: Additional keyword arguments to pass to the sql generator.
Returns:
Expand All @@ -660,26 +664,28 @@ def format_model_expressions(

*statements, query = expressions

def cast_to_colon(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Cast) and not any(
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
arg
for name, arg in node.args.items()
if name not in ("this", "to")
):
this = node.this
if rewrite_casts:

if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
cast = DColonCast(this=this, to=node.to)
cast.comments = node.comments
node = cast
def cast_to_colon(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Cast) and not any(
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
arg
for name, arg in node.args.items()
if name not in ("this", "to")
):
this = node.this

exp.replace_children(node, cast_to_colon)
return node
if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
cast = DColonCast(this=this, to=node.to)
cast.comments = node.comments
node = cast

exp.replace_children(node, cast_to_colon)
return node

query = query.copy()
exp.replace_children(query, cast_to_colon)
query = query.copy()
exp.replace_children(query, cast_to_colon)

return ";\n\n".join(
[
Expand Down
18 changes: 15 additions & 3 deletions sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def model(self, context: Context, line: str, sql: t.Optional[str] = None) -> Non
expressions = parse(file.read(), default_dialect=config.dialect)

formatted = format_model_expressions(
expressions, model.dialect, **config.format.generator_options
expressions,
model.dialect,
rewrite_casts=not config.format.no_rewrite_casts,
**config.format.generator_options,
)

self._shell.set_next_input(
Expand Down Expand Up @@ -703,6 +706,12 @@ def rewrite(self, context: Context, line: str, sql: str) -> None:
help="Whether or not to append a newline to the end of the file.",
default=None,
)
@argument(
"--no-rewrite-casts",
action="store_true",
help="Whether or not to preserve the existing casts, without rewriting them to use the :: syntax.",
default=None,
)
@argument(
"--normalize",
action="store_true",
Expand Down Expand Up @@ -745,8 +754,11 @@ def rewrite(self, context: Context, line: str, sql: str) -> None:
@pass_sqlmesh_context
def format(self, context: Context, line: str) -> bool:
"""Format all SQL models and audits."""
args = parse_argstring(self.format, line)
return context.format(**{k: v for k, v in vars(args).items() if v is not None})
format_opts = vars(parse_argstring(self.format, line))
if format_opts.pop("no_rewrite_casts", None):
format_opts["rewrite_casts"] = False

return context.format(**{k: v for k, v in format_opts.items() if v is not None})

@magic_arguments()
@argument("environment", type=str, help="The environment to diff local state against.")
Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ def test_format_model_expressions():
SAFE_CAST('bla' AS INT64) AS FOO"""
)

x = format_model_expressions(
parse(
"""
MODEL(name foo);
SELECT 1::INT AS bla
"""
),
rewrite_casts=False,
)
assert (
x
== """MODEL (
name foo
);
SELECT
CAST(1 AS INT) AS bla"""
)


def test_macro_format():
assert parse_one("@EACH(ARRAY(1,2), x -> x)").sql() == "@EACH(ARRAY(1, 2), x -> x)"
Expand Down

0 comments on commit 3b03f51

Please sign in to comment.