Skip to content

Commit

Permalink
more complex visit_Call to parse chained command (#5677)
Browse files Browse the repository at this point in the history
* more complex visit_Call

* add changelog

* traversing all of the tree
  • Loading branch information
ChenyuLInx authored Aug 24, 2022
1 parent 7f8d9a7 commit 436737d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20220817-163642.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Fix unexpected behavior when chaining methods on dbt-ref'ed/sourced dataframes
time: 2022-08-17T16:36:42.678275-07:00
custom:
Author: ChenyuLInx
Issue: "5646"
PR: "5677"
45 changes: 38 additions & 7 deletions core/dbt/parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException


# Final attribute names that *may* belong to a dbt function call. visit_Call
# compares only the last attribute of a call against this set first, so
# unrelated calls are rejected before the full dotted name is flattened.
dbt_function_key_words = {"ref", "source", "config", "get"}

# Fully qualified dotted names that really are dbt function calls. A call whose
# last attribute matched a key word above is recorded only if its flattened
# name is in this set.
dbt_function_full_names = {"dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"}


class PythonValidationVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -108,13 +112,40 @@ def _get_call_literals(self, node):
return arg_literals, kwarg_literals

def visit_Call(self, node: ast.Call) -> None:
# Record dbt function calls (dbt.ref / dbt.source / dbt.config / dbt.config.get)
# found at this call node as (func_name, args, kwargs) tuples in
# self.dbt_function_calls, then keep descending into the call's arguments and
# into the expression the call is chained on.

# NOTE(review): the six lines below look like the pre-change body that this
# commit removes (the diff view lists old and new lines together without +/-
# markers) — confirm before treating them as live code.
func_name = self._flatten_attr(node.func)
if func_name in ["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"]:
# drop the dot-dbt prefix
func_name = func_name.split(".")[-1]
args, kwargs = self._get_call_literals(node)
self.dbt_function_calls.append((func_name, args, kwargs))
# check whether the current call could be a dbt function call: only the last
# attribute name is compared here, before the full dotted name is flattened
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
func_name = self._flatten_attr(node.func)
# check whether the current call really is a dbt function call
if func_name in dbt_function_full_names:
# drop the dot-dbt prefix
func_name = func_name.split(".")[-1]
args, kwargs = self._get_call_literals(node)
self.dbt_function_calls.append((func_name, args, kwargs))

# no matter what happened above, we should keep visiting the rest of the tree
# visit args and kwargs to see if there's a call in them
for obj in node.args + [kwarg.value for kwarg in node.keywords]:
if isinstance(obj, ast.Call):
self.visit_Call(obj)
# support dbt.ref in list and tuple args, kwargs (direct elements only)
elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
for el in obj.elts:
if isinstance(el, ast.Call):
self.visit_Call(el)
# support dbt.ref in dict args, kwargs (dict values only, not keys)
elif isinstance(obj, ast.Dict):
for value in obj.values:
if isinstance(value, ast.Call):
self.visit_Call(value)
# visit node.func.value if we are at a call on an attribute, so that a call
# earlier in a method chain (e.g. dbt.ref("x").limit(2)) is visited as well
if isinstance(node.func, ast.Attribute):
self.attribute_helper(node.func)

def attribute_helper(self, node: ast.Attribute) -> None:
    """Visit the call that an attribute chain is rooted at, if there is one.

    Starting from ``node``, follow ``.value`` through every nested
    ``ast.Attribute`` until a different kind of node is reached. If that
    innermost node is an ``ast.Call`` it is passed to ``self.visit_Call``;
    any other kind of node is ignored.
    """
    innermost: ast.AST = node
    while True:
        if not isinstance(innermost, ast.Attribute):
            break
        innermost = innermost.value

    # Only a call at the root of the chain is of interest here.
    if not isinstance(innermost, ast.Call):
        return
    self.visit_Call(innermost)

def visit_Import(self, node: ast.Import) -> None:
for n in node.names:
Expand Down
16 changes: 12 additions & 4 deletions test/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,16 @@ def model(dbt, session):
from torch import b
import textblob.text
import sklearn
df = dbt.ref("my_sql_model")
df = dbt.ref("my_sql_model")
df = dbt.ref("my_sql_model_2")
df0 = pandas(dbt.ref("a_model"))
df1 = dbt.ref("my_sql_model").task.limit(2)
df2 = dbt.ref("my_sql_model_1")
df3 = dbt.ref("my_sql_model_2")
df4 = dbt.source("test", 'table1').limit(max = [max(dbt.ref('something'))])
df5 = [dbt.ref('test1')]
a_dict = {'test2' : dbt.ref('test2')}
df5 = anotherfunction({'test2' : dbt.ref('test3')})
df6 = [somethingelse.ref(dbt.ref("test4"))]
df = df.limit(2)
return df
Expand Down Expand Up @@ -582,7 +589,8 @@ def model(dbt, session):
checksum=block.file.checksum,
unrendered_config={'materialized': 'table', 'packages':python_packages},
config_call_dict={'materialized': 'table', 'packages':python_packages},
refs=[['my_sql_model'], ['my_sql_model'], ['my_sql_model_2']]
refs=[['a_model'], ['my_sql_model'], ['my_sql_model_1'], ['my_sql_model_2'], ['something'], ['test1'], ['test2'], ['test3'], ['test4']],
sources = [['test', 'table1']],
)
assertEqualNodes(node, expected)
file_id = 'snowplow://' + normalize('models/nested/py_model.py')
Expand Down

0 comments on commit 436737d

Please sign in to comment.