From 6405aed0dbdfb9335b8c790cd5813fd9ce3fc9a3 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:31:45 +0300 Subject: [PATCH] Feat!: Support pre-post statements in python models at creation time (#2977) --- sqlmesh/core/loader.py | 8 +- sqlmesh/core/model/decorator.py | 19 +- sqlmesh/core/model/definition.py | 268 +++++++++++++------------- tests/core/test_model.py | 77 +++++++- tests/core/test_snapshot_evaluator.py | 66 ++++++- 5 files changed, 291 insertions(+), 147 deletions(-) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 43e39472d..cc5a7fab8 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -329,7 +329,7 @@ def _load_models( """ models = self._load_sql_models(macros, jinja_macros, audits) models.update(self._load_external_models(gateway)) - models.update(self._load_python_models()) + models.update(self._load_python_models(macros, jinja_macros)) return models @@ -392,7 +392,9 @@ def _load() -> Model: return models - def _load_python_models(self) -> UniqueKeyDict[str, Model]: + def _load_python_models( + self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry + ) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") registry = model_registry.registry() @@ -418,6 +420,8 @@ def _load_python_models(self) -> UniqueKeyDict[str, Model]: path=path, module_path=context_path, defaults=config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, dialect=config.model_defaults.dialect, time_column_format=config.time_column_format, physical_schema_override=config.physical_schema_override, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 6f4f0c9f1..21a2f0ec1 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -8,8 +8,10 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType +from sqlmesh.core.macros import MacroRegistry +from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.core import constants as c -from sqlmesh.core.dialect import MacroFunc +from sqlmesh.core.dialect import MacroFunc, parse_one from sqlmesh.core.model.definition import ( Model, create_python_model, @@ -75,6 +77,8 @@ def model( module_path: Path, path: Path, defaults: t.Optional[t.Dict[str, t.Any]] = None, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, dialect: t.Optional[str] = None, time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, physical_schema_override: t.Optional[t.Dict[str, str]] = None, @@ -123,7 +127,9 @@ def model( for key in ("pre_statements", "post_statements"): statements = common_kwargs.get(key) if statements: - common_kwargs[key] = [exp.maybe_parse(s, dialect=dialect) for s in statements] + common_kwargs[key] = [ + parse_one(s, dialect=dialect) if isinstance(s, str) else s for s in statements + ] if self.is_sql: query = MacroFunc(this=exp.Anonymous(this=entrypoint)) @@ -132,5 +138,12 @@ def model( ) return create_python_model( - self.name, entrypoint, columns=self.columns, dialect=dialect, **common_kwargs + self.name, + entrypoint, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + columns=self.columns, + dialect=dialect, + **common_kwargs, ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 5d9e60377..aa0f430e1 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -117,6 +117,14 @@ class _Model(ModelMeta, frozen=True): mapping_schema: t.Dict[str, t.Any] = {} _full_depends_on: t.Optional[t.Set[str]] = None + __statement_renderers: t.Dict[int, ExpressionRenderer] = {} + + pre_statements_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="pre_statements" + ) + post_statements_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="post_statements" + ) _expressions_validator = expression_validator @@ -335,7 +343,17 @@ def render_pre_statements( Returns: The list of rendered expressions. """ - return [] + return self._render_statements( + self.pre_statements, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) def render_post_statements( self, @@ -367,7 +385,58 @@ def render_post_statements( Returns: The list of rendered expressions. """ - return [] + return self._render_statements( + self.post_statements, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) + + @property + def pre_statements(self) -> t.List[exp.Expression]: + return self.pre_statements_ or [] + + @property + def post_statements(self) -> t.List[exp.Expression]: + return self.post_statements_ or [] + + @property + def macro_definitions(self) -> t.List[d.MacroDef]: + """All macro definitions from the list of expressions.""" + return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] + + def _render_statements( + self, + statements: t.Iterable[exp.Expression], + **kwargs: t.Any, + ) -> t.List[exp.Expression]: + rendered = ( + self._statement_renderer(statement).render(**kwargs) + for statement in statements + if not isinstance(statement, d.MacroDef) + ) + return [r for expressions in rendered if expressions for r in expressions] + + def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: + expression_key = id(expression) + if expression_key not in self.__statement_renderers: + self.__statement_renderers[expression_key] = ExpressionRenderer( + expression, + self.dialect, + self.macro_definitions, + path=self._path, + jinja_macro_registry=self.jinja_macros, + python_env=self.python_env, + only_execution_time=self.kind.only_execution_time, + default_catalog=self.default_catalog, + model_fqn=self.fqn, + ) + return self.__statement_renderers[expression_key] def render_signals( self, @@ -755,6 +824,18 @@ def _data_hash_values(self) -> t.List[str]: data.append(key) data.append(gen(value)) + for statement in (*self.pre_statements, *self.post_statements): + statement_exprs: t.List[exp.Expression] = [] + if not isinstance(statement, d.MacroDef): + rendered = self._statement_renderer(statement).render() + if self._is_metadata_statement(statement): + continue + if rendered: + statement_exprs = rendered + else: + statement_exprs = [statement] + data.extend(gen(e) for e in statement_exprs) + return data # type: ignore def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: @@ -839,8 +920,24 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) + for statement in (*self.pre_statements, *self.post_statements): + if self._is_metadata_statement(statement): + additional_metadata.append(gen(statement)) + return additional_metadata + def _is_metadata_statement(self, statement: exp.Expression) -> bool: + if isinstance(statement, d.MacroDef): + return True + if isinstance(statement, d.MacroFunc): + target_macro = macro.get_registry().get(statement.name) + if target_macro: + return target_macro.metadata_only + target_macro = self.python_env.get(statement.name) + if target_macro: + return bool(target_macro.is_metadata) + return False + @property def full_depends_on(self) -> t.Set[str]: if not self._full_depends_on: @@ -857,16 +954,8 @@ def full_depends_on(self) -> t.Set[str]: class _SqlBasedModel(_Model): - pre_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="pre_statements" - ) - post_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="post_statements" - ) inline_audits_: t.Dict[str, t.Any] = Field(default={}, alias="inline_audits") - __statement_renderers: t.Dict[int, ExpressionRenderer] = {} - _expression_validator = expression_validator @field_validator("inline_audits_", mode="before") @@ -887,139 +976,10 @@ def _inline_audits_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any return inline_audits - def render_pre_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.pre_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - - def render_post_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.post_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - - @property - def pre_statements(self) -> t.List[exp.Expression]: - return self.pre_statements_ or [] - - @property - def post_statements(self) -> t.List[exp.Expression]: - return self.post_statements_ or [] - - @property - def macro_definitions(self) -> t.List[d.MacroDef]: - """All macro definitions from the list of expressions.""" - return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] - @property def inline_audits(self) -> t.Dict[str, ModelAudit]: return self.inline_audits_ - def _render_statements( - self, - statements: t.Iterable[exp.Expression], - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - rendered = ( - self._statement_renderer(statement).render(**kwargs) - for statement in statements - if not isinstance(statement, d.MacroDef) - ) - return [r for expressions in rendered if expressions for r in expressions] - - def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: - expression_key = id(expression) - if expression_key not in self.__statement_renderers: - self.__statement_renderers[expression_key] = ExpressionRenderer( - expression, - self.dialect, - self.macro_definitions, - path=self._path, - jinja_macro_registry=self.jinja_macros, - python_env=self.python_env, - only_execution_time=self.kind.only_execution_time, - default_catalog=self.default_catalog, - model_fqn=self.fqn, - ) - return self.__statement_renderers[expression_key] - - @property - def _data_hash_values(self) -> t.List[str]: - data_hash_values = super()._data_hash_values - - for statement in (*self.pre_statements, *self.post_statements): - statement_exprs: t.List[exp.Expression] = [] - if not isinstance(statement, d.MacroDef): - rendered = self._statement_renderer(statement).render() - if self._is_metadata_statement(statement): - continue - if rendered: - statement_exprs = rendered - else: - statement_exprs = [statement] - data_hash_values.extend(gen(e) for e in statement_exprs) - - return data_hash_values - - @property - def _additional_metadata(self) -> t.List[str]: - additional_metadata = super()._additional_metadata - - for statement in (*self.pre_statements, *self.post_statements): - if self._is_metadata_statement(statement): - additional_metadata.append(gen(statement)) - - return additional_metadata - - def _is_metadata_statement(self, statement: exp.Expression) -> bool: - if isinstance(statement, d.MacroDef): - return True - if isinstance(statement, d.MacroFunc): - target_macro = macro.get_registry().get(statement.name) - if target_macro: - return target_macro.metadata_only - target_macro = self.python_env.get(statement.name) - if target_macro: - return bool(target_macro.is_metadata) - return False - class SqlModel(_SqlBasedModel): """The model definition which relies on a SQL query to fetch the data. @@ -1868,8 +1828,11 @@ def create_python_model( entrypoint: str, python_env: t.Dict[str, Executable], *, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, defaults: t.Optional[t.Dict[str, t.Any]] = None, path: Path = Path(), + module_path: Path = Path(), time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, depends_on: t.Optional[t.Set[str]] = None, physical_schema_override: t.Optional[t.Dict[str, str]] = None, @@ -1890,6 +1853,32 @@ def create_python_model( """ # Find dependencies for python models by parsing code if they are not explicitly defined # Also remove self-references that are found + + pre_statements = kwargs.get("pre_statements", None) or [] + post_statements = kwargs.get("post_statements", None) or [] + + if pre_statements or post_statements: + jinja_macro_references, used_variables = extract_macro_references_and_variables( + *(gen(e) for e in pre_statements), + *(gen(e) for e in post_statements), + ) + + jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_references) + for jinja_macro in jinja_macros.root_macros.values(): + used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + + python_env.update( + _python_env( + [*pre_statements, *post_statements], + jinja_macro_references, + module_path, + macros or macro.get_registry(), + variables=variables, + used_variables=used_variables, + path=path, + ) + ) + parsed_depends_on, referenced_variables = ( _parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set()) ) @@ -1909,6 +1898,7 @@ def create_python_model( depends_on=depends_on, entrypoint=entrypoint, python_env=python_env, + jinja_macros=jinja_macros, physical_schema_override=physical_schema_override, **kwargs, ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 1e6d1e042..048725021 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -51,7 +51,7 @@ from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable @@ -949,6 +949,66 @@ def test_seed_with_special_characters_in_column(tmp_path, assert_exp_eq): ) +def test_python_model_jinja_pre_post_statements(): + macros = """ + {% macro test_macro(v) %}{{ v }}{% endmacro %} + {% macro extra_macro(v) %}{{ v + 1 }}{% endmacro %} + """ + + jinja_macros = JinjaMacroRegistry() + jinja_macros.add_macros(MacroExtractor().extract(macros)) + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=[ + "JINJA_STATEMENT_BEGIN;\n{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};\nJINJA_END;" + ], + post_statements=[ + "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{test_macro('idx')}} ON db.test_model(id);\nJINJA_END;", + parse_one("DROP TABLE x2;"), + ], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros + ) + + assert len(jinja_macros.root_macros) == 2 + assert len(python_model.jinja_macros.root_macros) == 1 + assert "test_macro" in python_model.jinja_macros.root_macros + assert "extra_macro" not in python_model.jinja_macros.root_macros + + expected_pre = [ + d.jinja_statement( + "{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};" + ), + ] + assert python_model.pre_statements == expected_pre + assert python_model.render_pre_statements()[0].sql() == 'CREATE OR REPLACE TABLE "x2"' + + expected_post = [ + d.jinja_statement("CREATE INDEX {{test_macro('idx')}} ON db.test_model(id);"), + *d.parse("DROP TABLE x2;"), + ] + assert python_model.post_statements == expected_post + assert ( + python_model.render_post_statements()[0].sql() + == 'CREATE INDEX "idx" ON "db"."test_model"("id" NULLS LAST)' + ) + assert python_model.render_post_statements()[1].sql() == 'DROP TABLE "x2"' + + def test_audits(): expressions = d.parse( """ @@ -1640,7 +1700,14 @@ def test_parse(assert_exp_eq): def test_python_model(assert_exp_eq) -> None: from functools import reduce - @model(name="my_model", kind="full", columns={'"COL"': "int"}, enabled=True) + @model( + name="my_model", + kind="full", + columns={'"COL"': "int"}, + pre_statements=["CACHE TABLE x AS SELECT 1;"], + post_statements=["DROP TABLE x;"], + enabled=True, + ) def my_model(context, **kwargs): context.table("foo") context.table(model_name=CONST + ".baz") @@ -1654,6 +1721,12 @@ def my_model(context, **kwargs): dialect="duckdb", ) + assert list(m.pre_statements) == [ + d.parse_one("CACHE TABLE x AS SELECT 1"), + ] + assert list(m.post_statements) == [ + d.parse_one("DROP TABLE x"), + ] assert m.enabled assert m.dialect == "duckdb" assert m.depends_on == {'"foo"', '"bar"."baz"'} diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index cac20b35f..5440f9b83 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -4,6 +4,7 @@ import logging import pytest +import pandas as pd from pathlib import Path from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp @@ -20,7 +21,7 @@ InsertOverwriteStrategy, ) from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.macros import RuntimeStage, macro +from sqlmesh.core.macros import RuntimeStage, macro, MacroEvaluator, MacroFunc from sqlmesh.core.model import ( Model, FullKind, @@ -33,6 +34,7 @@ ViewKind, load_sql_based_model, ExternalModel, + model, ) from sqlmesh.core.model.kind import OnDestructiveChange, ExternalKind from sqlmesh.core.node import IntervalUnit @@ -2215,6 +2217,68 @@ def test_create_post_statements_use_deployable_table( assert post_calls[0].sql(dialect="postgres") == expected_call +def test_create_pre_post_statements_python_model( + mocker: MockerFixture, adapter_mock, make_snapshot +): + evaluator = SnapshotEvaluator(adapter_mock) + + @macro() + def create_index( + evaluator: MacroEvaluator, + index_name: str, + model_name: str, + column: str, + ): + if evaluator.runtime_stage == "creating": + return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=["CREATE INDEX IF NOT EXISTS idx ON db.test_model(id);"], + post_statements=["@CREATE_INDEX('idx', 'db.test_model', id)"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", + ) + + assert len(python_model.python_env) == 3 + assert len(python_model.pre_statements) == 1 + assert len(python_model.post_statements) == 1 + assert isinstance(python_model.python_env["create_index"], Executable) + assert isinstance(python_model.pre_statements[0], exp.Create) + assert isinstance(python_model.post_statements[0], MacroFunc) + + snapshot = make_snapshot(python_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}" /* db.test_model */("id")' + + call_args = adapter_mock.execute.call_args_list + pre_calls = call_args[0][0][0] + assert len(pre_calls) == 1 + assert pre_calls[0].sql(dialect="postgres") == expected_call + + post_calls = call_args[1][0][0] + assert len(post_calls) == 1 + assert post_calls[0].sql(dialect="postgres") == expected_call + + def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock): model = SqlModel( name="test_schema.test_model",