fix: support all model methods in dbt projects #3161

Merged (2 commits) on Sep 21, 2024
24 changes: 20 additions & 4 deletions sqlmesh/dbt/basemodel.py
@@ -231,9 +231,6 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
}
)

def model_function(self) -> AttributeDict[str, t.Any]:
return AttributeDict({"config": self.config_attribute_dict})

@property
def tests_ref_source_dependencies(self) -> Dependencies:
dependencies = Dependencies()
@@ -280,6 +277,14 @@ def check_for_circular_test_refs(self, context: DbtContext) -> None:
def sqlmesh_config_fields(self) -> t.Set[str]:
return {"description", "owner", "stamp", "storage_format"}

@property
def node_name(self) -> str:
resource_type = getattr(self, "resource_type", "model")
node_name = f"{resource_type}.{self.package_name}.{self.name}"
if self.version:
node_name += f".v{self.version}"
return node_name

def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
"""Get common sqlmesh model parameters"""
self.check_for_circular_test_refs(context)
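
For reference, a standalone sketch of the unique-id format the new node_name property produces, which is how dbt keys entries in manifest.nodes. The function name and the versioned-model example are illustrative, not part of the diff:

import typing as t

def manifest_node_name(
    resource_type: str, package_name: str, name: str, version: t.Optional[int] = None
) -> str:
    # Mirrors the property above: "<resource_type>.<package>.<name>", plus ".v<N>" for versioned models.
    node_name = f"{resource_type}.{package_name}.{name}"
    if version:
        node_name += f".v{version}"
    return node_name

assert manifest_node_name("model", "sushi", "top_waiters") == "model.sushi.top_waiters"
assert manifest_node_name("model", "jaffle_shop", "customers", version=2) == "model.jaffle_shop.customers.v2"
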
@@ -289,10 +294,21 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
jinja_macros = model_context.jinja_macros.trim(
self.dependencies.macros, package=self.package_name
)

model_node: AttributeDict[str, t.Any] = AttributeDict(
Reviewer comment: btw, I don't think you need to convert to AttributeDict since the validator within JinjaMacros should take care of that.

{
k: v
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
if k in self.dependencies.model_attrs
}
if context._manifest and self.node_name in context._manifest._manifest.nodes
else {}
)

jinja_macros.add_globals(
{
"this": self.relation_info,
"model": self.model_function(),
"model": model_node,
"schema": self.table_schema,
"config": self.config_attribute_dict,
**model_context.jinja_globals, # type: ignore
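
A rough sketch of what the new model_node block does: look up the model's entry in the dbt manifest by node_name and keep only the attributes its jinja actually references. The helper name and the abbreviated manifest entry are illustrative:

import typing as t

def trim_node_for_jinja(node_dict: t.Dict[str, t.Any], referenced_attrs: t.Set[str]) -> t.Dict[str, t.Any]:
    # Keep only the manifest-node attributes the model's jinja references (Dependencies.model_attrs),
    # mirroring the dict comprehension in the diff above.
    return {k: v for k, v in node_dict.items() if k in referenced_attrs}

# Made-up, abbreviated manifest entry for "model.sushi.top_waiters":
node_dict = {
    "name": "top_waiters",
    "columns": {"waiter_id": {}, "revenue": {}, "model_columns": {}},
    "config": {"materialized": "view"},
    "raw_code": "...",
}
print(trim_node_for_jinja(node_dict, {"columns", "config"}))  # only "columns" and "config" survive

This trimmed dict is what gets exposed to the model's jinja as the model global, replacing the old model_function(), which only carried config.
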
2 changes: 2 additions & 0 deletions sqlmesh/dbt/common.py
@@ -179,13 +179,15 @@ class Dependencies(PydanticModel):
sources: t.Set[str] = set()
refs: t.Set[str] = set()
variables: t.Set[str] = set()
model_attrs: t.Set[str] = set()

def union(self, other: Dependencies) -> Dependencies:
return Dependencies(
macros=list(set(self.macros) | set(other.macros)),
sources=self.sources | other.sources,
refs=self.refs | other.refs,
variables=self.variables | other.variables,
model_attrs=self.model_attrs | other.model_attrs,
)

@field_validator("macros", mode="after")
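
A pared-down sketch of the union semantics, using a simplified stand-in for the Pydantic model: the new model_attrs field merges like the other set-valued fields, so attributes referenced by a model and by its attached tests both survive the merge:

import typing as t
from dataclasses import dataclass, field

@dataclass
class Deps:  # illustrative stand-in for Dependencies
    refs: t.Set[str] = field(default_factory=set)
    model_attrs: t.Set[str] = field(default_factory=set)

    def union(self, other: "Deps") -> "Deps":
        return Deps(refs=self.refs | other.refs, model_attrs=self.model_attrs | other.model_attrs)

merged = Deps(model_attrs={"columns"}).union(Deps(refs={"waiter_revenue_by_day"}, model_attrs={"config"}))
print(merged.model_attrs)  # {'columns', 'config'} (set order may vary)
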
4 changes: 4 additions & 0 deletions sqlmesh/dbt/manifest.py
@@ -15,6 +15,7 @@
# Override the file name to prevent dbt commands from invalidating the cache.
dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack"

import jinja2
from dbt.adapters.factory import register_adapter, reset_adapters
from dbt.config import Profile, Project, RuntimeConfig
from dbt.config.profile import read_profile
@@ -398,6 +399,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
for call_name, node in extract_call_names(target, cache=self._calls):
if call_name[0] == "config":
continue
elif isinstance(node, jinja2.nodes.Getattr):
if call_name[0] == "model":
dependencies.model_attrs.add(call_name[1])
elif call_name[0] == "source":
args = [jinja_call_arg_name(arg) for arg in node.args]
if args and all(arg for arg in args):
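
A rough, self-contained sketch (not the project's exact helper) of how attribute accesses such as model.columns are pulled out of a model's jinja via the jinja2 AST; this is what ends up in Dependencies.model_attrs:

import jinja2
from jinja2 import nodes

env = jinja2.Environment()
tree = env.parse("{% set columns = model.columns %} {{ model.config.materialized }}")

model_attrs = set()
for getattr_node in tree.find_all(nodes.Getattr):
    # Only consider the base of an attribute chain (model.config.materialized -> "config"),
    # analogous to the `not isinstance(child_node.node, nodes.Getattr)` guard in jinja.py below.
    if isinstance(getattr_node.node, nodes.Name) and getattr_node.node.name == "model":
        model_attrs.add(getattr_node.attr)

print(model_attrs)  # {'columns', 'config'} (set order may vary)
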
5 changes: 4 additions & 1 deletion sqlmesh/utils/__init__.py
@@ -81,7 +81,10 @@ def __setitem__(self, k: KEY, v: VALUE) -> None:


class AttributeDict(dict, t.Mapping[KEY, VALUE]):
__getattr__ = dict.get
def __getattr__(self, key: t.Any) -> t.Optional[VALUE]:
if key.startswith("__") and not hasattr(self, key):
raise AttributeError
return self.get(key)

def set(self, field: str, value: t.Any) -> str:
self[field] = value
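
Context for the change above: with the old __getattr__ = dict.get alias, every missing attribute, including dunder probes, silently resolved to None, so hasattr checks used by Python protocols could report false positives. A minimal stand-in for the old behaviour (the class name is made up):

class OldAttributeDict(dict):
    __getattr__ = dict.get  # the pre-PR behaviour

d = OldAttributeDict({"materialized": "view"})
print(d.materialized)          # view  -- attribute-style access to dict items
print(d.missing)               # None  -- convenient for optional config keys
print(hasattr(d, "__html__"))  # True  -- false positive that the new dunder guard is intended to avoid
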
11 changes: 8 additions & 3 deletions sqlmesh/utils/jinja.py
@@ -16,6 +16,10 @@
from sqlmesh.utils import AttributeDict
from sqlmesh.utils.pydantic import PRIVATE_FIELDS, PydanticModel, field_serializer, field_validator


if t.TYPE_CHECKING:
CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]]

SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja"


@@ -28,8 +32,6 @@ def environment(**kwargs: t.Any) -> Environment:

ENVIRONMENT = environment()

CallNames = t.Tuple[t.Tuple[str, ...], nodes.Call]


class MacroReference(PydanticModel, frozen=True):
package: t.Optional[str] = None
@@ -136,7 +138,9 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]:
elif isinstance(child_node, nodes.Macro):
for arg in child_node.args:
vars_in_scope.add(arg.name)
elif isinstance(child_node, nodes.Call):
elif isinstance(child_node, nodes.Call) or (
isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr)
):
name = call_name(child_node)
if name[0][0] != "'" and name[0] not in vars_in_scope:
yield (name, child_node)
@@ -168,6 +172,7 @@ def extract_macro_references_and_variables(
for jinja_str in jinja_strs:
for call_name, node in extract_call_names(jinja_str):
if call_name[0] == c.VAR:
assert isinstance(node, nodes.Call)
args = [jinja_call_arg_name(arg) for arg in node.args]
if args and args[0]:
variables.add(args[0].lower())
2 changes: 1 addition & 1 deletion tests/core/engine_adapter/test_integration.py
@@ -2315,7 +2315,7 @@ def test_to_time_column(
# specific data type to validate what is returned.
import re

time_column = re.match("^(.*?)\+", time_column).group(1)
time_column = re.match(r"^(.*?)\+", time_column).group(1)
time_column_type = exp.DataType.build("TIMESTAMP('UTC')", dialect="clickhouse")

time_column = to_time_column(time_column, time_column_type, time_column_format)
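
The only change in this test is switching the pattern to a raw string: "\+" inside a plain string literal is an invalid escape sequence that newer Python versions warn about, while the compiled regex is identical. A small illustration with a made-up input value:

import re

time_column = re.match(r"^(.*?)\+", "2023-01-01 00:00:00+00:00").group(1)
print(time_column)  # 2023-01-01 00:00:00 -- everything before the first literal "+"
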
24 changes: 24 additions & 0 deletions tests/core/test_integration.py
@@ -1386,6 +1386,30 @@ def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context):
assert plan.snapshots[snapshot_b_id].change_category == SnapshotChangeCategory.NON_BREAKING


def test_model_attr(sushi_test_dbt_context: Context, assert_exp_eq):
context = sushi_test_dbt_context
model = context.get_model("sushi.top_waiters")
assert_exp_eq(
model.render_query(),
"""
SELECT
CAST("waiter_id" AS INT) AS "waiter_id",
CAST("revenue" AS DOUBLE) AS "revenue",
3 AS "model_columns"
FROM "memory"."sushi"."waiter_revenue_by_day_v2" AS "waiter_revenue_by_day_v2"
WHERE
"ds" = (
SELECT
MAX("ds")
FROM "memory"."sushi"."waiter_revenue_by_day_v2" AS "waiter_revenue_by_day_v2"
)
ORDER BY
"revenue" DESC NULLS FIRST
LIMIT 10
""",
)


@freeze_time("2023-01-08 15:00:00")
def test_incremental_by_partition(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
2 changes: 1 addition & 1 deletion tests/core/test_snapshot_evaluator.py
@@ -2673,7 +2673,7 @@ def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_
)
)

snapshot: Snapshot = make_snapshot(model)
snapshot = make_snapshot(model)
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
snapshot.previous_versions = (
SnapshotDataVersion(
1 change: 1 addition & 0 deletions tests/dbt/test_manifest.py
@@ -31,6 +31,7 @@ def test_manifest_helper(caplog):
assert models["top_waiters"].dependencies == Dependencies(
refs={"sushi.waiter_revenue_by_day", "waiter_revenue_by_day"},
variables={"top_waiters:revenue", "top_waiters:limit"},
model_attrs={"columns", "config"},
macros=[MacroReference(name="ref"), MacroReference(name="var")],
)
assert models["top_waiters"].materialized == "view"
9 changes: 8 additions & 1 deletion tests/fixtures/dbt/sushi_test/models/schema.yml
@@ -2,6 +2,13 @@ version: 2

models:
- name: top_waiters
columns:
- name: waiter_id
data_type: int
- name: revenue
data_type: double
- name: model_columns
data_type: int
config:
dialect: postgres
- name: waiters
@@ -29,4 +36,4 @@ sources:
external_location: "read_parquet('path/to/external/{name}.parquet')"
tables:
- name: items
- name: orders
- name: orders
6 changes: 5 additions & 1 deletion tests/fixtures/dbt/sushi_test/models/top_waiters.sql
@@ -5,9 +5,13 @@
)
}}

{% set columns = model.columns %}
{% set config = model.config %}

SELECT
waiter_id::INT AS waiter_id,
revenue::DOUBLE AS {{ var("top_waiters:revenue") }}
revenue::DOUBLE AS {{ var("top_waiters:revenue") }},
{{ columns | length }} AS model_columns
FROM {{ ref('sushi', 'waiter_revenue_by_day') }}
WHERE
ds = (
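
Putting the pieces together, a hedged sketch of what the template above sees at render time: a model mapping standing in for the trimmed manifest node. The dict contents are illustrative, based on the fixture's schema.yml columns and config:

import jinja2

model = {
    "columns": {
        "waiter_id": {"data_type": "int"},
        "revenue": {"data_type": "double"},
        "model_columns": {"data_type": "int"},
    },
    "config": {"materialized": "view", "dialect": "postgres"},
}

template = jinja2.Template(
    "{% set columns = model.columns %}"
    "SELECT {{ columns | length }} AS model_columns  -- config: {{ model.config.materialized }}"
)
print(template.render(model=model))
# SELECT 3 AS model_columns  -- config: view

The rendered 3 AS model_columns is exactly the literal asserted by test_model_attr in tests/core/test_integration.py.
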