Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add unix-style fqn wildcard selector method #6599

Merged
merged 4 commits into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230112-191705.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: ✨ add unix-style wildcard selector method
time: 2023-01-12T19:17:05.841918-07:00
custom:
Author: z3z1ma
Issue: "6598"
50 changes: 32 additions & 18 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from fnmatch import fnmatch
from itertools import chain
from pathlib import Path
from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type, Optional, Callable
Expand Down Expand Up @@ -46,10 +47,13 @@ class MethodName(StrEnum):
Metric = "metric"
Result = "result"
SourceStatus = "source_status"
Wildcard = "wildcard"


def is_selected_node(fqn: List[str], node_selector: str):

def is_selected_node(
fqn: List[str],
node_selector: str,
) -> bool:
# If qualified_name exactly matches model name (fqn's leaf), return True
if fqn[-1] == node_selector:
return True
Expand All @@ -59,15 +63,26 @@ def is_selected_node(fqn: List[str], node_selector: str):
if len(flat_fqn) < len(node_selector.split(".")):
return False

slurp_from_ix: Optional[int] = None
for i, selector_part in enumerate(node_selector.split(".")):
# if we hit a GLOB, then this node is selected
if selector_part == SELECTOR_GLOB:
return True
if any(wildcard in selector_part for wildcard in ("*", "?", "[", "]")):
slurp_from_ix = i
break
elif flat_fqn[i] == selector_part:
continue
else:
return False

if slurp_from_ix is not None:
# If we hit a wildcard, match the rest of the fqn against the rest of the
# selector using fnmatch. This is 100% backwards compatible with the old
# behavior of encountering a wildcard, but more expressive: it naturally
# allows the rest of the fqn to be matched with more advanced patterns.
return fnmatch(
".".join(flat_fqn[slurp_from_ix:]),
".".join(node_selector.split(".")[slurp_from_ix:]),
)

# if we get all the way down here, then the node is a match
return True

Expand Down Expand Up @@ -195,7 +210,7 @@ class TagSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
"""yields nodes from included that have the specified tag"""
for node, real_node in self.all_nodes(included_nodes):
if selector in real_node.tags:
if any(fnmatch(tag, selector) for tag in real_node.tags):
yield node


Expand All @@ -213,7 +228,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
parts = selector.split(".")
target_package = SELECTOR_GLOB
if len(parts) == 1:
target_source, target_table = parts[0], None
target_source, target_table = parts[0], SELECTOR_GLOB
elif len(parts) == 2:
target_source, target_table = parts
elif len(parts) == 3:
Expand All @@ -228,13 +243,12 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
raise DbtRuntimeError(msg)

for node, real_node in self.source_nodes(included_nodes):
if target_package not in (real_node.package_name, SELECTOR_GLOB):
if not fnmatch(real_node.package_name, target_package):
continue
if target_source not in (real_node.source_name, SELECTOR_GLOB):
if not fnmatch(real_node.source_name, target_source):
continue
if target_table not in (None, real_node.name, SELECTOR_GLOB):
if not fnmatch(real_node.name, target_table):
continue

yield node


Expand All @@ -255,9 +269,9 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
raise DbtRuntimeError(msg)

for node, real_node in self.exposure_nodes(included_nodes):
if target_package not in (real_node.package_name, SELECTOR_GLOB):
if not fnmatch(real_node.package_name, target_package):
continue
if target_name not in (real_node.name, SELECTOR_GLOB):
if not fnmatch(real_node.name, target_name):
continue

yield node
Expand All @@ -280,9 +294,9 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
raise DbtRuntimeError(msg)

for node, real_node in self.metric_nodes(included_nodes):
if target_package not in (real_node.package_name, SELECTOR_GLOB):
if not fnmatch(real_node.package_name, target_package):
continue
if target_name not in (real_node.name, SELECTOR_GLOB):
if not fnmatch(real_node.name, target_name):
continue

yield node
Expand All @@ -306,15 +320,15 @@ class FileSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
"""Yields nodes from included that match the given file name."""
for node, real_node in self.all_nodes(included_nodes):
if Path(real_node.original_file_path).name == selector:
if fnmatch(Path(real_node.original_file_path).name, selector):
yield node


class PackageSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
"""Yields nodes from included that have the specified package"""
for node, real_node in self.all_nodes(included_nodes):
if real_node.package_name == selector:
if fnmatch(real_node.package_name, selector):
yield node


Expand Down Expand Up @@ -395,7 +409,7 @@ class TestNameSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
for node, real_node in self.parsed_nodes(included_nodes):
if real_node.resource_type == NodeType.Test and hasattr(real_node, "test_metadata"):
if real_node.test_metadata.name == selector: # type: ignore[union-attr]
if fnmatch(real_node.test_metadata.name, selector): # type: ignore[union-attr]
yield node


Expand Down
41 changes: 35 additions & 6 deletions test/unit/test_graph_selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,26 @@ def test_select_fqn(manifest):
'mynamespace.union_model', 'mynamespace.ephemeral_model', 'mynamespace.seed'}
assert search_manifest_using_method(
manifest, method, 'ext') == {'ext_model'}
# wildcards
assert search_manifest_using_method(manifest, method, '*.*.*_model') == {
'mynamespace.union_model', 'mynamespace.ephemeral_model', 'union_model'}
# multiple wildcards
assert search_manifest_using_method(
manifest, method, '*unions*') == {'union_model', 'mynamespace.union_model'}
# negation
assert not search_manifest_using_method(manifest, method, '!pkg*')
# single wildcard
assert search_manifest_using_method(manifest, method, 'pkg.t*') == {
'table_model', 'table_model_py', 'table_model_csv'}
# wildcard and ? (matches exactly one character)
assert search_manifest_using_method(
manifest, method, '*ext_m?del') == {'ext_model'}
# multiple ?
assert search_manifest_using_method(manifest, method, '*.?????_model') == {
'union_model', 'table_model', 'mynamespace.union_model'}
# multiple ranges
assert search_manifest_using_method(manifest, method, '*.[t-u][a-n][b-i][l-o][e-n]_model') == {
'union_model', 'table_model', 'mynamespace.union_model'}


def test_select_tag(manifest):
Expand All @@ -692,7 +712,8 @@ def test_select_tag(manifest):
assert search_manifest_using_method(manifest, method, 'uses_ephemeral') == {
'view_model', 'table_model'}
assert not search_manifest_using_method(manifest, method, 'missing')

assert search_manifest_using_method(manifest, method, 'uses_eph*') == {
'view_model', 'table_model'}

def test_select_group(manifest, view_model):
group_name = 'my_group'
Expand Down Expand Up @@ -740,7 +761,8 @@ def test_select_source(manifest):
assert search_manifest_using_method(
manifest, method, 'ext.ext_raw.*') == {'ext_raw.ext_source', 'ext_raw.ext_source_2'}
assert not search_manifest_using_method(manifest, method, 'pkg.ext_raw.*')

assert search_manifest_using_method(manifest, method, '*.ext_[s]ourc?') == {
'ext_raw.ext_source', 'raw.ext_source'}

# TODO: this requires writing out files
@pytest.mark.skip('TODO: write manifest files to disk')
Expand Down Expand Up @@ -781,7 +803,8 @@ def test_select_file(manifest):
manifest, method, 'missing.sql')
assert not search_manifest_using_method(
manifest, method, 'missing.py')

assert search_manifest_using_method(
manifest, method, 'table_*.csv') == {'table_model_csv'}

def test_select_package(manifest):
methods = MethodManager(manifest, None)
Expand All @@ -798,6 +821,9 @@ def test_select_package(manifest):

assert not search_manifest_using_method(manifest, method, 'missing')

assert search_manifest_using_method(manifest, method, 'ex*') == {
'ext_model', 'ext_raw.ext_source', 'ext_raw.ext_source_2', 'raw.ext_source', 'raw.ext_source_2', 'unique_ext_raw_ext_source_id'}


def test_select_config_materialized(manifest):
methods = MethodManager(manifest, None)
Expand Down Expand Up @@ -844,7 +870,8 @@ def test_select_test_name(manifest):
assert search_manifest_using_method(manifest, method, 'not_null') == {
'not_null_table_model_id'}
assert not search_manifest_using_method(manifest, method, 'notatest')

assert search_manifest_using_method(manifest, method, 'not_*') == {
'not_null_table_model_id'}

def test_select_test_type(manifest):
methods = MethodManager(manifest, None)
Expand Down Expand Up @@ -872,7 +899,8 @@ def test_select_exposure(manifest):
manifest, method, 'my_exposure') == {'my_exposure'}
assert not search_manifest_using_method(
manifest, method, 'not_my_exposure')

assert search_manifest_using_method(
manifest, method, 'my_e*e') == {'my_exposure'}

def test_select_metric(manifest):
metric = make_metric('test', 'my_metric')
Expand All @@ -884,7 +912,8 @@ def test_select_metric(manifest):
manifest, method, 'my_metric') == {'my_metric'}
assert not search_manifest_using_method(
manifest, method, 'not_my_metric')

assert search_manifest_using_method(
manifest, method, '*_metric') == {'my_metric'}

@pytest.fixture
def previous_state(manifest):
Expand Down