Skip to content

Commit

Permalink
Merge branch 'main' into issue-427-task-generator
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana authored Sep 11, 2023
2 parents 78598e4 + dd0dcb6 commit 9db8c8e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 17 deletions.
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
select=select,
dbt_cmd=dbt_executable_path,
profile_config=profile_config,
operator_args=operator_args,
dbt_deps=dbt_deps,
)
dbt_graph.load(method=load_mode, execution_mode=execution_mode)
Expand Down
3 changes: 3 additions & 0 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@ def __init__(
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
profile_config: ProfileConfig | None = None,
operator_args: dict[str, Any] | None = None,
dbt_deps: bool | None = True,
):
self.project = project
self.exclude = exclude or []
self.select = select or []
self.profile_config = profile_config
self.operator_args = operator_args or {}
self.dbt_deps = dbt_deps

# specific to loading using ls
Expand Down Expand Up @@ -282,6 +284,7 @@ def load_via_custom_parser(self) -> None:
dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None,
dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None,
project_name=self.project.name,
operator_args=self.operator_args,
)
nodes = {}
models = itertools.chain(project.models.items(), project.snapshots.items(), project.seeds.items())
Expand Down
58 changes: 41 additions & 17 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class DbtModel:
name: str
type: DbtModelType
path: Path
operator_args: Dict[str, Any] = field(default_factory=dict)
config: DbtModelConfig = field(default_factory=DbtModelConfig)

def __post_init__(self) -> None:
Expand All @@ -137,6 +138,7 @@ def __post_init__(self) -> None:
"""
# first, get an empty config
config = DbtModelConfig()
var_args: Dict[str, Any] = self.operator_args.get("vars", {})

if self.type == DbtModelType.DBT_MODEL:
# get the code from the file
Expand Down Expand Up @@ -165,23 +167,40 @@ def __post_init__(self) -> None:
# iterate over the jinja nodes to extract info
for base_node in jinja2_ast.find_all(jinja2.nodes.Call):
if hasattr(base_node.node, "name"):
# check we have a ref - this indicates a dependency
if base_node.node.name == "ref":
# if it is, get the first argument
first_arg = base_node.args[0]
if isinstance(first_arg, jinja2.nodes.Const):
# and add it to the config
config.upstream_models.add(first_arg.value)

# check if we have a config - this could contain tags
if base_node.node.name == "config":
# if it is, check if any kwargs are tags
for kwarg in base_node.kwargs:
for selector in self.config.config_types:
extracted_config = self._extract_config(kwarg=kwarg, config_name=selector)
config.config_selectors |= (
set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
)
try:
# check we have a ref - this indicates a dependency
if base_node.node.name == "ref":
# if it is, get the first argument
first_arg = base_node.args[0]
# if it contains vars, render the value of the var
if isinstance(first_arg, jinja2.nodes.Concat):
value = ""
for node in first_arg.nodes:
if isinstance(node, jinja2.nodes.Const):
value += node.value
elif (
isinstance(node, jinja2.nodes.Call)
and isinstance(node.node, jinja2.nodes.Name)
and isinstance(node.args[0], jinja2.nodes.Const)
and node.node.name == "var"
):
value += var_args[node.args[0].value]
config.upstream_models.add(value)
elif isinstance(first_arg, jinja2.nodes.Const):
# and add it to the config
config.upstream_models.add(first_arg.value)

# check if we have a config - this could contain tags
if base_node.node.name == "config":
# if it is, check if any kwargs are tags
for kwarg in base_node.kwargs:
for selector in self.config.config_types:
extracted_config = self._extract_config(kwarg=kwarg, config_name=selector)
config.config_selectors |= (
set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
)
except KeyError as e:
logger.warning(f"Could not add upstream model for config in {self.path}: {e}")

# set the config and set the parsed file flag to true
self.config = config
Expand Down Expand Up @@ -236,6 +255,8 @@ class DbtProject:
snapshots_dir: Path = field(init=False)
seeds_dir: Path = field(init=False)

operator_args: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
"""
Initializes the parser.
Expand Down Expand Up @@ -287,6 +308,7 @@ def _handle_csv_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_SEED,
path=path,
operator_args=self.operator_args,
)
# add the model to the project
self.seeds[model_name] = model
Expand All @@ -304,6 +326,7 @@ def _handle_sql_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_MODEL,
path=path,
operator_args=self.operator_args,
)
# add the model to the project
self.models[model.name] = model
Expand All @@ -313,6 +336,7 @@ def _handle_sql_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_SNAPSHOT,
path=path,
operator_args=self.operator_args,
)
# add the snapshot to the project
self.snapshots[model.name] = model
Expand Down
15 changes: 15 additions & 0 deletions tests/dbt/parser/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,18 @@ def test_dbtmodelconfig_with_sources(tmp_path):

dbt_model = DbtModel(name="some_name", type=DbtModelType.DBT_MODEL, path=path_with_sources)
assert "sample_source" not in dbt_model.config.upstream_models


def test_dbtmodelconfig_with_vars(tmp_path):
model_sql = SAMPLE_MODEL_SQL_PATH.read_text()
model_with_vars_sql = model_sql.replace("ref('stg_customers')", "ref('stg_customers_'~ var('country_code'))")
path_with_sources = tmp_path / "customers_with_sources.sql"
path_with_sources.write_text(model_with_vars_sql)

dbt_model = DbtModel(
name="some_name",
type=DbtModelType.DBT_MODEL,
path=path_with_sources,
operator_args={"vars": {"country_code": "us"}},
)
assert "stg_customers_us" in dbt_model.config.upstream_models

0 comments on commit 9db8c8e

Please sign in to comment.