Skip to content

Commit

Permalink
new node selectors
Browse files Browse the repository at this point in the history
- added a bunch of unit tests
- added some new integration tests
- cleaned up some selection logic
  • Loading branch information
Jacob Beck committed Jul 15, 2020
1 parent 6e161ab commit 37c82d5
Show file tree
Hide file tree
Showing 20 changed files with 1,108 additions and 262 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

### Features
- Added support for Snowflake query tags at the connection and model level ([#1030](https://github.com/fishtown-analytics/dbt/issues/1030), [#2555](https://github.com/fishtown-analytics/dbt/pull/2555/))
- Added new node selector methods (`config`, `test_type`, `test_name`, `package`) ([#2425](https://github.com/fishtown-analytics/dbt/issues/2425), [#2629](https://github.com/fishtown-analytics/dbt/pull/2629))
- Added option to specify profile when connecting to Redshift via IAM ([#2437](https://github.com/fishtown-analytics/dbt/issues/2437), [#2581](https://github.com/fishtown-analytics/dbt/pull/2581))

### Fixes
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def insensitive_patterns(*patterns: str):


Severity = NewType('Severity', str)

register_pattern(Severity, insensitive_patterns('warn', 'error'))


Expand Down
6 changes: 6 additions & 0 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ def __init__(self, thread_id, known, node=None):
)


class InvalidSelectorException(RuntimeException):
def __init__(self, name: str):
self.name = name
super().__init__(name)


def raise_compiler_error(msg, node=None) -> NoReturn:
raise CompilationException(msg, node)

Expand Down
5 changes: 4 additions & 1 deletion core/dbt/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
ResourceTypeSelector,
NodeSelector,
)
from .cli import parse_difference # noqa: F401
from .cli import ( # noqa: F401
parse_difference,
parse_test_selectors,
)
from .queue import GraphQueue # noqa: F401
from .graph import Graph # noqa: F401
33 changes: 33 additions & 0 deletions core/dbt/graph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*']
DEFAULT_EXCLUDES: List[str] = []
DATA_TEST_SELECTOR: str = 'test_type:data'
SCHEMA_TEST_SELECTOR: str = 'test_type:schema'


def parse_union(
Expand Down Expand Up @@ -64,3 +66,34 @@ def parse_difference(
included = parse_union_from_default(include, DEFAULT_INCLUDES)
excluded = parse_union_from_default(exclude, DEFAULT_EXCLUDES)
return SelectionDifference(components=[included, excluded])


def parse_test_selectors(
data: bool, schema: bool, base: SelectionSpec
) -> SelectionSpec:
union_components = []

if data:
union_components.append(
SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR)
)
if schema:
union_components.append(
SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR)
)

intersect_with: SelectionSpec
if not union_components:
return base
elif len(union_components) == 1:
intersect_with = union_components[0]
else: # data and schema tests
intersect_with = SelectionUnion(
components=union_components,
expect_exists=True,
raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR],
)

return SelectionIntersection(
components=[base, intersect_with], expect_exists=True
)
111 changes: 37 additions & 74 deletions core/dbt/graph/selector.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@

from typing import (
Set, List, Dict, Union, Type
)
from typing import Set, List, Union

from .graph import Graph, UniqueId
from .queue import GraphQueue
from .selector_methods import (
MethodName,
SelectorMethod,
QualifiedNameSelectorMethod,
TagSelectorMethod,
SourceSelectorMethod,
PathSelectorMethod,
)
from .selector_methods import MethodManager
from .selector_spec import SelectionCriteria, SelectionSpec

from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.exceptions import InternalException, warn_or_error
from dbt.exceptions import (
InternalException,
InvalidSelectorException,
warn_or_error,
)
from dbt.contracts.graph.compiled import NonSourceNode, CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSourceDefinition
Expand All @@ -35,15 +30,9 @@ def alert_non_existence(raw_spec, nodes):
)


class InvalidSelectorError(Exception):
# this internal exception should never escape the module.
pass


class NodeSelector:
class NodeSelector(MethodManager):
"""The node selector is aware of the graph and manifest,
"""
SELECTOR_METHODS: Dict[str, Type[SelectorMethod]] = {}

def __init__(
self,
Expand All @@ -53,52 +42,50 @@ def __init__(
self.full_graph = graph
self.manifest = manifest

@classmethod
def register_method(cls, name: MethodName, method: Type[SelectorMethod]):
cls.SELECTOR_METHODS[name] = method

def get_method(self, method: str) -> SelectorMethod:
if method in self.SELECTOR_METHODS:
cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method]
return cls(self.manifest)
else:
raise InvalidSelectorError(method)
# build a subgraph containing only non-empty, enabled nodes and enabled
# sources.
graph_members = {
unique_id for unique_id in self.full_graph.nodes()
if self._is_graph_member(unique_id)
}
self.graph = self.full_graph.subgraph(graph_members)

def select_included(
self, included_nodes: Set[str], spec: SelectionCriteria,
) -> Set[str]:
"""Select the explicitly included nodes, using the given spec. Return
the selected set of unique IDs.
"""
method = self.get_method(spec.method)
method = self.get_method(spec.method, spec.method_arguments)
return set(method.search(included_nodes, spec.value))

def get_nodes_from_criteria(
self, graph: Graph, spec: SelectionCriteria
self, spec: SelectionCriteria
) -> Set[str]:
"""Given a Graph, get all nodes specified by the spec.
"""Get all nodes specified by the single selection criteria.
- collect the directly included nodes
- find their specified relatives
- perform any selector-specific expansion
"""
nodes = graph.nodes()

nodes = self.graph.nodes()
try:
collected = self.select_included(nodes, spec)
except InvalidSelectorError:
except InvalidSelectorException:
valid_selectors = ", ".join(self.SELECTOR_METHODS)
logger.info(
f"The '{spec.method}' selector specified in {spec.raw} is "
f"invalid. Must be one of [{valid_selectors}]"
)
return set()

extras = self.collect_specified_neighbors(spec, graph, collected)
result = self.expand_selection(graph, collected | extras)
extras = self.collect_specified_neighbors(spec, collected)
result = self.expand_selection(collected | extras)
return result

def collect_specified_neighbors(
self, spec: SelectionCriteria, graph: Graph, selected: Set[UniqueId]
self, spec: SelectionCriteria, selected: Set[UniqueId]
) -> Set[UniqueId]:
"""Given the set of models selected by the explicit part of the
selector (like "tag:foo"), apply the modifiers on the spec ("+"/"@").
Expand All @@ -107,35 +94,32 @@ def collect_specified_neighbors(
"""
additional: Set[UniqueId] = set()
if spec.select_childrens_parents:
additional.update(graph.select_childrens_parents(selected))
additional.update(self.graph.select_childrens_parents(selected))

if spec.select_parents:
additional.update(
graph.select_parents(selected, spec.select_parents_max_depth)
)
depth = spec.select_parents_max_depth
additional.update(self.graph.select_parents(selected, depth))

if spec.select_children:
additional.update(
graph.select_children(selected, spec.select_children_max_depth)
)
depth = spec.select_children_max_depth
additional.update(self.graph.select_children(selected, depth))
return additional

def select_nodes(self, graph: Graph, spec: SelectionSpec) -> Set[str]:
def select_nodes(self, spec: SelectionSpec) -> Set[str]:
"""Select the nodes in the graph according to the spec.
If the spec is a composite spec (a union, difference, or intersection),
recurse into its selections and combine them. If the spec is a concrete
selection criteria, resolve that using the given graph.
"""
if isinstance(spec, SelectionCriteria):
result = self.get_nodes_from_criteria(graph, spec)
result = self.get_nodes_from_criteria(spec)
else:
node_selections = [
self.select_nodes(graph, component)
self.select_nodes(component)
for component in spec
]
if node_selections:
result = spec.combine_selections(node_selections)
else:
result = set()
result = spec.combined(node_selections)
if spec.expect_exists:
alert_non_existence(spec.raw, result)
return result
Expand Down Expand Up @@ -168,16 +152,6 @@ def _is_match(self, unique_id: str) -> bool:
)
return self.node_is_match(node)

def build_graph_member_subgraph(self) -> Graph:
"""Build a subgraph of all enabled, non-empty nodes based on the full
graph.
"""
graph_members = {
unique_id for unique_id in self.full_graph.nodes()
if self._is_graph_member(unique_id)
}
return self.full_graph.subgraph(graph_members)

def filter_selection(self, selected: Set[str]) -> Set[str]:
"""Return the subset of selected nodes that is a match for this
selector.
Expand All @@ -186,17 +160,13 @@ def filter_selection(self, selected: Set[str]) -> Set[str]:
unique_id for unique_id in selected if self._is_match(unique_id)
}

def expand_selection(
self, filtered_graph: Graph, selected: Set[str]
) -> Set[str]:
def expand_selection(self, selected: Set[str]) -> Set[str]:
"""Perform selector-specific expansion."""
return selected

def get_selected(self, spec: SelectionSpec) -> Set[str]:
"""get_selected runs trhough the node selection process:
- build a subgraph containing only non-empty, enabled nodes and
enabled sources.
- node selection. Based on the include/exclude sets, the set
of matched unique IDs is returned
- expand the graph at each leaf node, before combination
Expand All @@ -206,8 +176,7 @@ def get_selected(self, spec: SelectionSpec) -> Set[str]:
- selectors can filter the nodes after all of them have been
selected
"""
filtered_graph = self.build_graph_member_subgraph()
selected_nodes = self.select_nodes(filtered_graph, spec)
selected_nodes = self.select_nodes(spec)
filtered_nodes = self.filter_selection(selected_nodes)
return filtered_nodes

Expand Down Expand Up @@ -236,9 +205,3 @@ def __init__(

def node_is_match(self, node):
return node.resource_type in self.resource_types


NodeSelector.register_method(MethodName.FQN, QualifiedNameSelectorMethod)
NodeSelector.register_method(MethodName.Tag, TagSelectorMethod)
NodeSelector.register_method(MethodName.Source, SourceSelectorMethod)
NodeSelector.register_method(MethodName.Path, PathSelectorMethod)
Loading

0 comments on commit 37c82d5

Please sign in to comment.