diff --git a/.changes/unreleased/Features-20220408-112610.yaml b/.changes/unreleased/Features-20220408-112610.yaml new file mode 100644 index 00000000000..21d366bcae9 --- /dev/null +++ b/.changes/unreleased/Features-20220408-112610.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Add selector method when reading selector definitions +time: 2022-04-08T11:26:10.713088+10:00 +custom: + Author: danieldiamond + Issue: "4821" + PR: "4827" diff --git a/core/dbt/config/selectors.py b/core/dbt/config/selectors.py index 996a9fb9ead..193a1bb70a8 100644 --- a/core/dbt/config/selectors.py +++ b/core/dbt/config/selectors.py @@ -1,4 +1,5 @@ from pathlib import Path +from copy import deepcopy from typing import Dict, Any, Union from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text # noqa: F401 from dbt.dataclass_schema import ValidationError @@ -140,28 +141,33 @@ def validate_selector_default(selector_file: SelectorFile) -> None: # good to combine the two flows into one at some point. class SelectorDict: @classmethod - def parse_dict_definition(cls, definition): + def parse_dict_definition(cls, definition, selector_dict={}): key = list(definition)[0] value = definition[key] if isinstance(value, list): new_values = [] for sel_def in value: - new_value = cls.parse_from_definition(sel_def) + new_value = cls.parse_from_definition(sel_def, selector_dict=selector_dict) new_values.append(new_value) value = new_values if key == "exclude": definition = {key: value} elif len(definition) == 1: definition = {"method": key, "value": value} + elif key == "method" and value == "selector": + sel_def = definition.get("value") + if sel_def not in selector_dict: + raise DbtSelectorsError(f"Existing selector definition for {sel_def} not found.") + return selector_dict[definition["value"]]["definition"] return definition @classmethod - def parse_a_definition(cls, def_type, definition): + def parse_a_definition(cls, def_type, definition, selector_dict={}): # this definition must be a list new_dict = {def_type: []} for sel_def in definition[def_type]: if isinstance(sel_def, dict): - sel_def = cls.parse_from_definition(sel_def) + sel_def = cls.parse_from_definition(sel_def, selector_dict=selector_dict) new_dict[def_type].append(sel_def) elif isinstance(sel_def, str): sel_def = SelectionCriteria.dict_from_single_spec(sel_def) @@ -171,15 +177,17 @@ def parse_a_definition(cls, def_type, definition): return new_dict @classmethod - def parse_from_definition(cls, definition): + def parse_from_definition(cls, definition, selector_dict={}): if isinstance(definition, str): definition = SelectionCriteria.dict_from_single_spec(definition) elif "union" in definition: - definition = cls.parse_a_definition("union", definition) + definition = cls.parse_a_definition("union", definition, selector_dict=selector_dict) elif "intersection" in definition: - definition = cls.parse_a_definition("intersection", definition) + definition = cls.parse_a_definition( + "intersection", definition, selector_dict=selector_dict + ) elif isinstance(definition, dict): - definition = cls.parse_dict_definition(definition) + definition = cls.parse_dict_definition(definition, selector_dict=selector_dict) return definition # This is the normal entrypoint of this code. Give it the @@ -190,6 +198,8 @@ def parse_from_selectors_list(cls, selectors): for selector in selectors: sel_name = selector["name"] selector_dict[sel_name] = selector - definition = cls.parse_from_definition(selector["definition"]) + definition = cls.parse_from_definition( + selector["definition"], selector_dict=deepcopy(selector_dict) + ) selector_dict[sel_name]["definition"] = definition return selector_dict diff --git a/core/dbt/graph/cli.py b/core/dbt/graph/cli.py index 1154457c151..6059de6b042 100644 --- a/core/dbt/graph/cli.py +++ b/core/dbt/graph/cli.py @@ -1,5 +1,6 @@ # special support for CLI argument parsing. from dbt import flags +from copy import deepcopy import itertools from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401 @@ -112,9 +113,9 @@ def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]: return result -def _parse_exclusions(definition) -> Optional[SelectionSpec]: +def _parse_exclusions(definition, result={}) -> Optional[SelectionSpec]: exclusions = _get_list_dicts(definition, "exclude") - parsed_exclusions = [parse_from_definition(excl) for excl in exclusions] + parsed_exclusions = [parse_from_definition(excl, result=result) for excl in exclusions] if len(parsed_exclusions) == 1: return parsed_exclusions[0] elif len(parsed_exclusions) > 1: @@ -124,7 +125,7 @@ def _parse_exclusions(definition) -> Optional[SelectionSpec]: def _parse_include_exclude_subdefs( - definitions: List[RawDefinition], + definitions: List[RawDefinition], result={} ) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]: include_parts: List[SelectionSpec] = [] diff_arg: Optional[SelectionSpec] = None @@ -138,16 +139,16 @@ def _parse_include_exclude_subdefs( f"You cannot provide multiple exclude arguments to the " f"same selector set operator:\n{yaml_sel_cfg}" ) - diff_arg = _parse_exclusions(definition) + diff_arg = _parse_exclusions(definition, result=result) else: - include_parts.append(parse_from_definition(definition)) + include_parts.append(parse_from_definition(definition, result=result)) return (include_parts, diff_arg) -def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec: +def parse_union_definition(definition: Dict[str, Any], result={}) -> SelectionSpec: union_def_parts = _get_list_dicts(definition, "union") - include, exclude = _parse_include_exclude_subdefs(union_def_parts) + include, exclude = _parse_include_exclude_subdefs(union_def_parts, result=result) union = SelectionUnion(components=include) @@ -158,9 +159,9 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec: return SelectionDifference(components=[union, exclude], raw=definition) -def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec: +def parse_intersection_definition(definition: Dict[str, Any], result={}) -> SelectionSpec: intersection_def_parts = _get_list_dicts(definition, "intersection") - include, exclude = _parse_include_exclude_subdefs(intersection_def_parts) + include, exclude = _parse_include_exclude_subdefs(intersection_def_parts, result=result) intersection = SelectionIntersection(components=include) if exclude is None: @@ -170,7 +171,7 @@ def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec: return SelectionDifference(components=[intersection, exclude], raw=definition) -def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec: +def parse_dict_definition(definition: Dict[str, Any], result={}) -> SelectionSpec: diff_arg: Optional[SelectionSpec] = None if len(definition) == 1: key = list(definition)[0] @@ -183,10 +184,15 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec: "method": key, "value": value, } + elif definition.get("method") == "selector": + sel_def = definition.get("value") + if sel_def not in result: + raise ValidationException(f"Existing selector definition for {sel_def} not found.") + return result[definition["value"]]["definition"] elif "method" in definition and "value" in definition: dct = definition if "exclude" in definition: - diff_arg = _parse_exclusions(definition) + diff_arg = _parse_exclusions(definition, result=result) dct = {k: v for k, v in dct.items() if k != "exclude"} else: raise ValidationException( @@ -202,7 +208,11 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec: return SelectionDifference(components=[base, diff_arg]) -def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec: +def parse_from_definition( + definition: RawDefinition, + rootlevel=False, + result: Dict[str, Dict[str, Union[SelectionSpec, bool]]] = {}, +) -> SelectionSpec: if ( isinstance(definition, dict) @@ -218,11 +228,11 @@ def parse_from_definition(definition: RawDefinition, rootlevel=False) -> Selecti if isinstance(definition, str): return SelectionCriteria.from_single_spec(definition) elif "union" in definition: - return parse_union_definition(definition) + return parse_union_definition(definition, result=result) elif "intersection" in definition: - return parse_intersection_definition(definition) + return parse_intersection_definition(definition, result=result) elif isinstance(definition, dict): - return parse_dict_definition(definition) + return parse_dict_definition(definition, result=result) else: raise ValidationException( f"Expected to find union, intersection, str or dict, instead " @@ -238,6 +248,8 @@ def parse_from_selectors_definition( for selector in source.selectors: result[selector.name] = { "default": selector.default, - "definition": parse_from_definition(selector.definition, rootlevel=True), + "definition": parse_from_definition( + selector.definition, rootlevel=True, result=deepcopy(result) + ), } return result diff --git a/test/unit/test_graph_selector_parsing.py b/test/unit/test_graph_selector_parsing.py index 435e04c3709..9d98e9c8bc5 100644 --- a/test/unit/test_graph_selector_parsing.py +++ b/test/unit/test_graph_selector_parsing.py @@ -300,4 +300,66 @@ def test_parse_yaml_complex(): ), ), ), - ) == parsed['test_name']["definition"] \ No newline at end of file + ) == parsed['test_name']["definition"] + + +def test_parse_selection(): + sf = parse_file('''\ + selectors: + - name: default + definition: + union: + - tag: foo + - tag: bar + - name: inherited + definition: + method: selector + value: default + ''') + assert len(sf.selectors) == 2 + parsed = cli.parse_from_selectors_definition(sf) + assert 'default' in parsed + assert 'inherited' in parsed + assert Union( + Criteria(method=MethodName.Tag, value='foo'), + Criteria(method=MethodName.Tag, value='bar'), + ) == parsed['default']["definition"] + assert Union( + Criteria(method=MethodName.Tag, value='foo'), + Criteria(method=MethodName.Tag, value='bar'), + ) == parsed['inherited']["definition"] + + +def test_parse_selection_with_exclusion(): + sf = parse_file('''\ + selectors: + - name: default + definition: + union: + - tag: foo + - tag: bar + - name: inherited + definition: + union: + - method: selector + value: default + - exclude: + - tag: bar + ''') + assert len(sf.selectors) == 2 + parsed = cli.parse_from_selectors_definition(sf) + assert 'default' in parsed + assert 'inherited' in parsed + assert Union( + Criteria(method=MethodName.Tag, value='foo'), + Criteria(method=MethodName.Tag, value='bar'), + ) == parsed['default']["definition"] + assert Difference( + Union( + Union( + Criteria(method=MethodName.Tag, value='foo'), + Criteria(method=MethodName.Tag, value='bar'), + ) + ), + Criteria(method=MethodName.Tag, value='bar'), + ) == parsed['inherited']["definition"] diff --git a/test/unit/test_manifest_selectors.py b/test/unit/test_manifest_selectors.py index eea2ade8edf..e410755ef16 100644 --- a/test/unit/test_manifest_selectors.py +++ b/test/unit/test_manifest_selectors.py @@ -1,13 +1,15 @@ import dbt.exceptions import textwrap import yaml +from collections import OrderedDict import unittest from dbt.config.selectors import SelectorDict +from dbt.exceptions import DbtSelectorsError -def get_selector_dict(txt: str) -> dict: +def get_selector_dict(txt: str) -> OrderedDict: txt = textwrap.dedent(txt) - dct = yaml.safe_load(txt) + dct = OrderedDict(yaml.safe_load(txt)) return dct @@ -113,3 +115,73 @@ def test_plus_definition(self): expected = {'method': 'fqn', 'value': 'my_model', 'children': True, 'children_depth': '2'} definition = sel_dict['my_model_children_selector']['definition'] self.assertEqual(expected, definition) + + def test_selector_definition(self): + dct = get_selector_dict('''\ + selectors: + - name: default + definition: + union: + - intersection: + - tag: foo + - tag: bar + - name: inherited + definition: + method: selector + value: default + ''') + + sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors']) + assert(sel_dict) + definition = sel_dict['default']['definition'] + expected = sel_dict['inherited']['definition'] + self.assertEqual(expected, definition) + + def test_selector_definition_with_exclusion(self): + dct = get_selector_dict('''\ + selectors: + - name: default + definition: + union: + - intersection: + - tag: foo + - tag: bar + - name: inherited + definition: + union: + - method: selector + value: default + - exclude: + - tag: bar + - name: comparison + definition: + union: + - union: + - intersection: + - tag: foo + - tag: bar + - exclude: + - tag: bar + ''') + + sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors'])) + assert(sel_dict) + definition = sel_dict['inherited']['definition'] + expected = sel_dict['comparison']['definition'] + self.assertEqual(expected, definition) + + def test_missing_selector(self): + dct = get_selector_dict('''\ + selectors: + - name: inherited + definition: + method: selector + value: default + ''') + with self.assertRaises(DbtSelectorsError) as err: + sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors'])) + + self.assertEqual( + 'Existing selector definition for default not found.', + str(err.exception.msg) + )