diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 5f98048..7031cd1 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -2,10 +2,10 @@ import os import sys from pathlib import Path +import re from typing import List, Set -from sqlalchemy import MetaData - +from sqlalchemy.schema import MetaData from .transformers.dot import Dot from .transformers.mermaid import Mermaid @@ -83,15 +83,18 @@ def resolve_included_tables( case 0, 0: return all_tables case 0, int(): - return all_tables - exclude_tables + excluded = {table for table in all_tables if any(re.match(pattern, table) for pattern in exclude_tables)} + return all_tables - excluded case int(), 0: - if not include_tables.issubset(all_tables): + included = {table for table in all_tables if any(re.match(pattern, table) for pattern in include_tables)} + + if not included: non_existent_tables = include_tables - all_tables raise ValueError( f"Some tables to include ({non_existent_tables}) don't exist" "withinthe found tables ({all_tables})." ) - return include_tables + return included case _: raise ValueError( f"Only one or none of include_tables ({include_tables}) or exclude_tables" diff --git a/tests/test_cli.py b/tests/test_cli.py index 667a253..032bf3e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,5 @@ +from pathlib import Path +from typing import Literal import pytest from typer.testing import CliRunner @@ -9,7 +11,7 @@ runner = CliRunner() -def test_graph(package_path): +def test_graph(package_path: Path): result = runner.invoke( app, ["graph", "example.base:Base", "--import-module", "example.models", "--python-dir", str(package_path)], @@ -20,7 +22,7 @@ def test_graph(package_path): @pytest.mark.parametrize("column_sort_arg", ["key-based", "preserve-order"]) -def test_graph_column_sort(package_path, column_sort_arg): +def test_graph_column_sort(package_path: Path, column_sort_arg: Literal["key-based"] | Literal["preserve-order"]): result = runner.invoke( app, [ @@ -39,7 +41,7 @@ def test_graph_column_sort(package_path, column_sort_arg): mermaid_assert(result.stdout) -def test_graph_with_exclusion(package_path): +def test_graph_with_exclusion(package_path: Path): result = runner.invoke( app, [ @@ -58,7 +60,7 @@ def test_graph_with_exclusion(package_path): assert "comments {" not in result.stdout -def test_graph_with_inclusion(package_path): +def test_graph_with_inclusion(package_path: Path): result = runner.invoke( app, [ @@ -77,7 +79,7 @@ def test_graph_with_inclusion(package_path): assert "comments {" in result.stdout -def test_inject_check(package_path): +def test_inject_check(package_path: Path): result = runner.invoke( app, [ @@ -94,7 +96,7 @@ def test_inject_check(package_path): assert result.exit_code == 1 -def test_inject(package_path): +def test_inject(package_path: Path): result = runner.invoke( app, [ @@ -115,7 +117,7 @@ def test_inject(package_path): @pytest.mark.parametrize("column_sort_arg", ["key-based", "preserve-order"]) -def test_inject_column_sort(package_path, column_sort_arg): +def test_inject_column_sort(package_path: Path, column_sort_arg: Literal["key-based"] | Literal["preserve-order"]): result = runner.invoke( app, [ @@ -140,3 +142,43 @@ def test_inject_column_sort(package_path, column_sort_arg): def test_version(): result = runner.invoke(app, ["version"]) assert result.exit_code == 0 + + +def test_graph_with_inclusion_regex(package_path: Path): + result = runner.invoke( + app, + [ + "graph", + "example.base:Base", + "--import-module", + "example.models", + "--python-dir", + str(package_path), + "--include-tables", + "^com.*", + ], + ) + assert result.exit_code == 0 + assert "comments {" in result.stdout + assert "users {" not in result.stdout + assert "post{" not in result.stdout + + +def test_graph_with_exclusion_regex(package_path: Path): + result = runner.invoke( + app, + [ + "graph", + "example.base:Base", + "--import-module", + "example.models", + "--python-dir", + str(package_path), + "--exclude-tables", + "^pos*.", + ], + ) + assert result.exit_code == 0 + assert "comments {" in result.stdout + assert "users {" in result.stdout + assert "post {" not in result.stdout