Skip to content

Commit

Permalink
Add selector method capabilities to selectors (#4827)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldiamond authored Apr 26, 2022
1 parent 55af3c7 commit 32e1924
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 28 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20220408-112610.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add selector method when reading selector definitions
time: 2022-04-08T11:26:10.713088+10:00
custom:
Author: danieldiamond
Issue: "4821"
PR: "4827"
28 changes: 19 additions & 9 deletions core/dbt/config/selectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from copy import deepcopy
from typing import Dict, Any, Union
from dbt.clients.yaml_helper import yaml, Loader, Dumper, load_yaml_text # noqa: F401
from dbt.dataclass_schema import ValidationError
Expand Down Expand Up @@ -140,28 +141,33 @@ def validate_selector_default(selector_file: SelectorFile) -> None:
# good to combine the two flows into one at some point.
class SelectorDict:
@classmethod
def parse_dict_definition(cls, definition):
def parse_dict_definition(cls, definition, selector_dict={}):
key = list(definition)[0]
value = definition[key]
if isinstance(value, list):
new_values = []
for sel_def in value:
new_value = cls.parse_from_definition(sel_def)
new_value = cls.parse_from_definition(sel_def, selector_dict=selector_dict)
new_values.append(new_value)
value = new_values
if key == "exclude":
definition = {key: value}
elif len(definition) == 1:
definition = {"method": key, "value": value}
elif key == "method" and value == "selector":
sel_def = definition.get("value")
if sel_def not in selector_dict:
raise DbtSelectorsError(f"Existing selector definition for {sel_def} not found.")
return selector_dict[definition["value"]]["definition"]
return definition

@classmethod
def parse_a_definition(cls, def_type, definition):
def parse_a_definition(cls, def_type, definition, selector_dict={}):
# this definition must be a list
new_dict = {def_type: []}
for sel_def in definition[def_type]:
if isinstance(sel_def, dict):
sel_def = cls.parse_from_definition(sel_def)
sel_def = cls.parse_from_definition(sel_def, selector_dict=selector_dict)
new_dict[def_type].append(sel_def)
elif isinstance(sel_def, str):
sel_def = SelectionCriteria.dict_from_single_spec(sel_def)
Expand All @@ -171,15 +177,17 @@ def parse_a_definition(cls, def_type, definition):
return new_dict

@classmethod
def parse_from_definition(cls, definition):
def parse_from_definition(cls, definition, selector_dict={}):
if isinstance(definition, str):
definition = SelectionCriteria.dict_from_single_spec(definition)
elif "union" in definition:
definition = cls.parse_a_definition("union", definition)
definition = cls.parse_a_definition("union", definition, selector_dict=selector_dict)
elif "intersection" in definition:
definition = cls.parse_a_definition("intersection", definition)
definition = cls.parse_a_definition(
"intersection", definition, selector_dict=selector_dict
)
elif isinstance(definition, dict):
definition = cls.parse_dict_definition(definition)
definition = cls.parse_dict_definition(definition, selector_dict=selector_dict)
return definition

# This is the normal entrypoint of this code. Give it the
Expand All @@ -190,6 +198,8 @@ def parse_from_selectors_list(cls, selectors):
for selector in selectors:
sel_name = selector["name"]
selector_dict[sel_name] = selector
definition = cls.parse_from_definition(selector["definition"])
definition = cls.parse_from_definition(
selector["definition"], selector_dict=deepcopy(selector_dict)
)
selector_dict[sel_name]["definition"] = definition
return selector_dict
44 changes: 28 additions & 16 deletions core/dbt/graph/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# special support for CLI argument parsing.
from dbt import flags
from copy import deepcopy
import itertools
from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401

Expand Down Expand Up @@ -112,9 +113,9 @@ def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
return result


def _parse_exclusions(definition) -> Optional[SelectionSpec]:
def _parse_exclusions(definition, result={}) -> Optional[SelectionSpec]:
exclusions = _get_list_dicts(definition, "exclude")
parsed_exclusions = [parse_from_definition(excl) for excl in exclusions]
parsed_exclusions = [parse_from_definition(excl, result=result) for excl in exclusions]
if len(parsed_exclusions) == 1:
return parsed_exclusions[0]
elif len(parsed_exclusions) > 1:
Expand All @@ -124,7 +125,7 @@ def _parse_exclusions(definition) -> Optional[SelectionSpec]:


def _parse_include_exclude_subdefs(
definitions: List[RawDefinition],
definitions: List[RawDefinition], result={}
) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]:
include_parts: List[SelectionSpec] = []
diff_arg: Optional[SelectionSpec] = None
Expand All @@ -138,16 +139,16 @@ def _parse_include_exclude_subdefs(
f"You cannot provide multiple exclude arguments to the "
f"same selector set operator:\n{yaml_sel_cfg}"
)
diff_arg = _parse_exclusions(definition)
diff_arg = _parse_exclusions(definition, result=result)
else:
include_parts.append(parse_from_definition(definition))
include_parts.append(parse_from_definition(definition, result=result))

return (include_parts, diff_arg)


def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
def parse_union_definition(definition: Dict[str, Any], result={}) -> SelectionSpec:
union_def_parts = _get_list_dicts(definition, "union")
include, exclude = _parse_include_exclude_subdefs(union_def_parts)
include, exclude = _parse_include_exclude_subdefs(union_def_parts, result=result)

union = SelectionUnion(components=include)

Expand All @@ -158,9 +159,9 @@ def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec:
return SelectionDifference(components=[union, exclude], raw=definition)


def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
def parse_intersection_definition(definition: Dict[str, Any], result={}) -> SelectionSpec:
intersection_def_parts = _get_list_dicts(definition, "intersection")
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts)
include, exclude = _parse_include_exclude_subdefs(intersection_def_parts, result=result)
intersection = SelectionIntersection(components=include)

if exclude is None:
Expand All @@ -170,7 +171,7 @@ def parse_intersection_definition(definition: Dict[str, Any]) -> SelectionSpec:
return SelectionDifference(components=[intersection, exclude], raw=definition)


def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
def parse_dict_definition(definition: Dict[str, Any], result={}) -> SelectionSpec:
diff_arg: Optional[SelectionSpec] = None
if len(definition) == 1:
key = list(definition)[0]
Expand All @@ -183,10 +184,15 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
"method": key,
"value": value,
}
elif definition.get("method") == "selector":
sel_def = definition.get("value")
if sel_def not in result:
raise ValidationException(f"Existing selector definition for {sel_def} not found.")
return result[definition["value"]]["definition"]
elif "method" in definition and "value" in definition:
dct = definition
if "exclude" in definition:
diff_arg = _parse_exclusions(definition)
diff_arg = _parse_exclusions(definition, result=result)
dct = {k: v for k, v in dct.items() if k != "exclude"}
else:
raise ValidationException(
Expand All @@ -202,7 +208,11 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec:
return SelectionDifference(components=[base, diff_arg])


def parse_from_definition(definition: RawDefinition, rootlevel=False) -> SelectionSpec:
def parse_from_definition(
definition: RawDefinition,
rootlevel=False,
result: Dict[str, Dict[str, Union[SelectionSpec, bool]]] = {},
) -> SelectionSpec:

if (
isinstance(definition, dict)
Expand All @@ -218,11 +228,11 @@ def parse_from_definition(definition: RawDefinition, rootlevel=False) -> Selecti
if isinstance(definition, str):
return SelectionCriteria.from_single_spec(definition)
elif "union" in definition:
return parse_union_definition(definition)
return parse_union_definition(definition, result=result)
elif "intersection" in definition:
return parse_intersection_definition(definition)
return parse_intersection_definition(definition, result=result)
elif isinstance(definition, dict):
return parse_dict_definition(definition)
return parse_dict_definition(definition, result=result)
else:
raise ValidationException(
f"Expected to find union, intersection, str or dict, instead "
Expand All @@ -238,6 +248,8 @@ def parse_from_selectors_definition(
for selector in source.selectors:
result[selector.name] = {
"default": selector.default,
"definition": parse_from_definition(selector.definition, rootlevel=True),
"definition": parse_from_definition(
selector.definition, rootlevel=True, result=deepcopy(result)
),
}
return result
64 changes: 63 additions & 1 deletion test/unit/test_graph_selector_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,66 @@ def test_parse_yaml_complex():
),
),
),
) == parsed['test_name']["definition"]
) == parsed['test_name']["definition"]


def test_parse_selection():
sf = parse_file('''\
selectors:
- name: default
definition:
union:
- tag: foo
- tag: bar
- name: inherited
definition:
method: selector
value: default
''')
assert len(sf.selectors) == 2
parsed = cli.parse_from_selectors_definition(sf)
assert 'default' in parsed
assert 'inherited' in parsed
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['default']["definition"]
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['inherited']["definition"]


def test_parse_selection_with_exclusion():
sf = parse_file('''\
selectors:
- name: default
definition:
union:
- tag: foo
- tag: bar
- name: inherited
definition:
union:
- method: selector
value: default
- exclude:
- tag: bar
''')
assert len(sf.selectors) == 2
parsed = cli.parse_from_selectors_definition(sf)
assert 'default' in parsed
assert 'inherited' in parsed
assert Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['default']["definition"]
assert Difference(
Union(
Union(
Criteria(method=MethodName.Tag, value='foo'),
Criteria(method=MethodName.Tag, value='bar'),
)
),
Criteria(method=MethodName.Tag, value='bar'),
) == parsed['inherited']["definition"]
76 changes: 74 additions & 2 deletions test/unit/test_manifest_selectors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import dbt.exceptions
import textwrap
import yaml
from collections import OrderedDict
import unittest
from dbt.config.selectors import SelectorDict
from dbt.exceptions import DbtSelectorsError


def get_selector_dict(txt: str) -> dict:
def get_selector_dict(txt: str) -> OrderedDict:
txt = textwrap.dedent(txt)
dct = yaml.safe_load(txt)
dct = OrderedDict(yaml.safe_load(txt))
return dct


Expand Down Expand Up @@ -113,3 +115,73 @@ def test_plus_definition(self):
expected = {'method': 'fqn', 'value': 'my_model', 'children': True, 'children_depth': '2'}
definition = sel_dict['my_model_children_selector']['definition']
self.assertEqual(expected, definition)

def test_selector_definition(self):
dct = get_selector_dict('''\
selectors:
- name: default
definition:
union:
- intersection:
- tag: foo
- tag: bar
- name: inherited
definition:
method: selector
value: default
''')

sel_dict = SelectorDict.parse_from_selectors_list(dct['selectors'])
assert(sel_dict)
definition = sel_dict['default']['definition']
expected = sel_dict['inherited']['definition']
self.assertEqual(expected, definition)

def test_selector_definition_with_exclusion(self):
dct = get_selector_dict('''\
selectors:
- name: default
definition:
union:
- intersection:
- tag: foo
- tag: bar
- name: inherited
definition:
union:
- method: selector
value: default
- exclude:
- tag: bar
- name: comparison
definition:
union:
- union:
- intersection:
- tag: foo
- tag: bar
- exclude:
- tag: bar
''')

sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors']))
assert(sel_dict)
definition = sel_dict['inherited']['definition']
expected = sel_dict['comparison']['definition']
self.assertEqual(expected, definition)

def test_missing_selector(self):
dct = get_selector_dict('''\
selectors:
- name: inherited
definition:
method: selector
value: default
''')
with self.assertRaises(DbtSelectorsError) as err:
sel_dict = SelectorDict.parse_from_selectors_list((dct['selectors']))

self.assertEqual(
'Existing selector definition for default not found.',
str(err.exception.msg)
)

0 comments on commit 32e1924

Please sign in to comment.