Skip to content

Commit

Permalink
Feat!: Support pre-post statements in python models at creation time (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Themiscodes authored Aug 8, 2024
1 parent 9c575d3 commit 6405aed
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 147 deletions.
8 changes: 6 additions & 2 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _load_models(
"""
models = self._load_sql_models(macros, jinja_macros, audits)
models.update(self._load_external_models(gateway))
models.update(self._load_python_models())
models.update(self._load_python_models(macros, jinja_macros))

return models

Expand Down Expand Up @@ -392,7 +392,9 @@ def _load() -> Model:

return models

def _load_python_models(self) -> UniqueKeyDict[str, Model]:
def _load_python_models(
self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
) -> UniqueKeyDict[str, Model]:
"""Loads the python models into a Dict"""
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
registry = model_registry.registry()
Expand All @@ -418,6 +420,8 @@ def _load_python_models(self) -> UniqueKeyDict[str, Model]:
path=path,
module_path=context_path,
defaults=config.model_defaults.dict(),
macros=macros,
jinja_macros=jinja_macros,
dialect=config.model_defaults.dialect,
time_column_format=config.time_column_format,
physical_schema_override=config.physical_schema_override,
Expand Down
19 changes: 16 additions & 3 deletions sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from sqlglot import exp
from sqlglot.dialects.dialect import DialectType

from sqlmesh.core.macros import MacroRegistry
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.core import constants as c
from sqlmesh.core.dialect import MacroFunc
from sqlmesh.core.dialect import MacroFunc, parse_one
from sqlmesh.core.model.definition import (
Model,
create_python_model,
Expand Down Expand Up @@ -75,6 +77,8 @@ def model(
module_path: Path,
path: Path,
defaults: t.Optional[t.Dict[str, t.Any]] = None,
macros: t.Optional[MacroRegistry] = None,
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
dialect: t.Optional[str] = None,
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT,
physical_schema_override: t.Optional[t.Dict[str, str]] = None,
Expand Down Expand Up @@ -123,7 +127,9 @@ def model(
for key in ("pre_statements", "post_statements"):
statements = common_kwargs.get(key)
if statements:
common_kwargs[key] = [exp.maybe_parse(s, dialect=dialect) for s in statements]
common_kwargs[key] = [
parse_one(s, dialect=dialect) if isinstance(s, str) else s for s in statements
]

if self.is_sql:
query = MacroFunc(this=exp.Anonymous(this=entrypoint))
Expand All @@ -132,5 +138,12 @@ def model(
)

return create_python_model(
self.name, entrypoint, columns=self.columns, dialect=dialect, **common_kwargs
self.name,
entrypoint,
module_path=module_path,
macros=macros,
jinja_macros=jinja_macros,
columns=self.columns,
dialect=dialect,
**common_kwargs,
)
Loading

0 comments on commit 6405aed

Please sign in to comment.