diff --git a/core/dbt/include/global_project/macros/python_model/python.sql b/core/dbt/include/global_project/macros/python_model/python.sql index 8bf1c4b89f2..d658ff185b2 100644 --- a/core/dbt/include/global_project/macros/python_model/python.sql +++ b/core/dbt/include/global_project/macros/python_model/python.sql @@ -10,13 +10,21 @@ {%- set ref_dict = {} -%} {%- for _ref in model.refs -%} - {%- set resolved = ref(*_ref) -%} - {%- do ref_dict.update({_ref | join('.'): resolve_model_name(resolved)}) -%} + {% set _ref_args = [_ref.get('package'), _ref['name']] if _ref.get('package') else [_ref['name'],] %} + {%- set resolved = ref(*_ref_args, v=_ref.get('version')) -%} + {%- if _ref.get('version') -%} + {% do _ref_args.extend(["v" ~ _ref['version']]) %} + {%- endif -%} + {%- do ref_dict.update({_ref_args | join('.'): resolve_model_name(resolved)}) -%} {%- endfor -%} -def ref(*args,dbt_load_df_function): +def ref(*args, **kwargs): refs = {{ ref_dict | tojson }} key = '.'.join(args) + version = kwargs.get("v") or kwargs.get("version") + if version: + key += f".v{version}" + dbt_load_df_function = kwargs.get("dbt_load_df_function") return dbt_load_df_function(refs[key]) {% endmacro %} @@ -81,7 +89,7 @@ class this: class dbtObj: def __init__(self, load_df_function) -> None: self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function) - self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function) + self.ref = lambda *args, **kwargs: ref(*args, **kwargs, dbt_load_df_function=load_df_function) self.config = config self.this = this() self.is_incremental = {{ is_incremental() }} diff --git a/tests/adapter/dbt/tests/adapter/python_model/test_python_model.py b/tests/adapter/dbt/tests/adapter/python_model/test_python_model.py index d7c23730648..259895abde9 100644 --- a/tests/adapter/dbt/tests/adapter/python_model/test_python_model.py +++ b/tests/adapter/dbt/tests/adapter/python_model/test_python_model.py @@ -17,7 +17,11 @@ def model(dbt, _): materialized='table', ) df = dbt.ref("my_sql_model") - df2 = dbt.source('test_source', 'test_table') + df2 = dbt.ref("my_versioned_sql_model", v=1) + df3 = dbt.ref("my_versioned_sql_model", version=1) + df4 = dbt.ref("test", "my_versioned_sql_model", v=1) + df5 = dbt.ref("test", "my_versioned_sql_model", version=1) + df6 = dbt.source("test_source", "test_table") df = df.limit(2) return df """ @@ -26,6 +30,11 @@ def model(dbt, _): select * from {{ref('my_python_model')}} """ schema_yml = """version: 2 +models: + - name: my_versioned_sql_model + versions: + - v: 1 + sources: - name: test_source loader: custom @@ -63,6 +72,7 @@ def models(self): return { "schema.yml": schema_yml, "my_sql_model.sql": basic_sql, + "my_versioned_sql_model_v1.sql": basic_sql, "my_python_model.py": basic_python, "second_sql_model.sql": second_sql, } @@ -75,7 +85,7 @@ def test_singular_tests(self, project): run_dbt(["seed", "--vars", yaml.safe_dump(vars_dict)]) results = run_dbt(["run", "--vars", yaml.safe_dump(vars_dict)]) - assert len(results) == 3 + assert len(results) == 4 m_1 = """