Skip to content

Commit

Permalink
versioned ref for python models
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Apr 11, 2023
1 parent 3b0e7da commit 1b33416
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
16 changes: 12 additions & 4 deletions core/dbt/include/global_project/macros/python_model/python.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,21 @@

{%- set ref_dict = {} -%}
{%- for _ref in model.refs -%}
{%- set resolved = ref(*_ref) -%}
{%- do ref_dict.update({_ref | join('.'): resolve_model_name(resolved)}) -%}
{% set _ref_args = [_ref.get('package'), _ref['name']] if _ref.get('package') else [_ref['name'],] %}
{%- set resolved = ref(*_ref_args, v=_ref.get('version')) -%}
{%- if _ref.get('version') -%}
{% do _ref_args.extend(["v" ~ _ref['version']]) %}
{%- endif -%}
{%- do ref_dict.update({_ref_args | join('.'): resolve_model_name(resolved)}) -%}
{%- endfor -%}

def ref(*args,dbt_load_df_function):
def ref(*args, **kwargs):
refs = {{ ref_dict | tojson }}
key = '.'.join(args)
version = kwargs.get("v") or kwargs.get("version")
if version:
key += f".v{version}"
dbt_load_df_function = kwargs.get("dbt_load_df_function")
return dbt_load_df_function(refs[key])

{% endmacro %}
Expand Down Expand Up @@ -81,7 +89,7 @@ class this:
class dbtObj:
def __init__(self, load_df_function) -> None:
self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
self.ref = lambda *args, **kwargs: ref(*args, **kwargs, dbt_load_df_function=load_df_function)
self.config = config
self.this = this()
self.is_incremental = {{ is_incremental() }}
Expand Down
14 changes: 12 additions & 2 deletions tests/adapter/dbt/tests/adapter/python_model/test_python_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def model(dbt, _):
materialized='table',
)
df = dbt.ref("my_sql_model")
df2 = dbt.source('test_source', 'test_table')
df2 = dbt.ref("my_versioned_sql_model", v=1)
df3 = dbt.ref("my_versioned_sql_model", version=1)
df4 = dbt.ref("test", "my_versioned_sql_model", v=1)
df5 = dbt.ref("test", "my_versioned_sql_model", version=1)
df6 = dbt.source("test_source", "test_table")
df = df.limit(2)
return df
"""
Expand All @@ -26,6 +30,11 @@ def model(dbt, _):
select * from {{ref('my_python_model')}}
"""
schema_yml = """version: 2
models:
- name: my_versioned_sql_model
versions:
- v: 1
sources:
- name: test_source
loader: custom
Expand Down Expand Up @@ -63,6 +72,7 @@ def models(self):
return {
"schema.yml": schema_yml,
"my_sql_model.sql": basic_sql,
"my_versioned_sql_model_v1.sql": basic_sql,
"my_python_model.py": basic_python,
"second_sql_model.sql": second_sql,
}
Expand All @@ -75,7 +85,7 @@ def test_singular_tests(self, project):

run_dbt(["seed", "--vars", yaml.safe_dump(vars_dict)])
results = run_dbt(["run", "--vars", yaml.safe_dump(vars_dict)])
assert len(results) == 3
assert len(results) == 4


m_1 = """
Expand Down

0 comments on commit 1b33416

Please sign in to comment.