From 3a384bf67b63cb06f9a884c1141290c1352db5c7 Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 11 May 2023 11:32:35 -0600 Subject: [PATCH 1/3] squash simplify --dbt installation --- data_diff/dbt_parser.py | 55 +++++++++---------- data_diff/sqeleton/databases/base.py | 2 +- poetry.lock | 6 +- pyproject.toml | 6 +- tests/test_dbt.py | 82 +++++++++++++++++----------- 5 files changed, 82 insertions(+), 69 deletions(-) diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index effd877c..7bb933d8 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -4,9 +4,12 @@ import os from pathlib import Path from typing import List, Dict, Tuple, Set, Optional +import yaml from packaging.version import parse as parse_version import pydantic +from dbt_artifacts_parser.parser import parse_run_results, parse_manifest +from dbt.config.renderer import ProfileRenderer from .utils import getLogger, get_from_dict_with_raise from .version import __version__ @@ -15,23 +18,9 @@ logger = getLogger(__name__) -def import_dbt_dependencies(): - try: - from dbt_artifacts_parser.parser import parse_run_results, parse_manifest - from dbt.config.renderer import ProfileRenderer - import yaml - except ImportError: - raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.") - - # dbt 1.5+ specific stuff to power selection of models - try: - # ProfileRenderer.render_data() fails without instantiating global flag MACRO_DEBUGGING in dbt-core 1.5 - from dbt.flags import set_flags - - set_flags(Namespace(MACRO_DEBUGGING=False)) - except: - pass - +# getting this dbt_runner will only succeed in dbt-core>=1.5 +# it's needed for `--select` functionality +def try_get_dbt_runner(): try: from dbt.cli.main import dbtRunner except ImportError: @@ -42,7 +31,18 @@ def import_dbt_dependencies(): else: dbt_runner = None - return parse_run_results, parse_manifest, ProfileRenderer, yaml, dbt_runner + return dbt_runner + + +# ProfileRenderer.render_data() fails without instantiating global flag MACRO_DEBUGGING in dbt-core 1.5 +# hacky but seems to be a bug on dbt's end +def try_set_dbt_flags(): + try: + from dbt.flags import set_flags + + set_flags(Namespace(MACRO_DEBUGGING=False)) + except: + pass RUN_RESULTS_PATH = "target/run_results.json" @@ -77,13 +77,8 @@ class TDatadiffModelConfig(pydantic.BaseModel): class DbtParser: def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None: - ( - self.parse_run_results, - self.parse_manifest, - self.ProfileRenderer, - self.yaml, - self.dbt_runner, - ) = import_dbt_dependencies() + try_set_dbt_flags() + self.dbt_runner = try_get_dbt_runner() self.profiles_dir = Path(profiles_dir_override or default_profiles_dir()) self.project_dir = Path(project_dir_override or default_project_dir()) self.connection = {} @@ -173,7 +168,7 @@ def get_run_results_models(self): with open(self.project_dir / RUN_RESULTS_PATH) as run_results: logger.info(f"Parsing file {RUN_RESULTS_PATH}") run_results_dict = json.load(run_results) - run_results_obj = self.parse_run_results(run_results=run_results_dict) + run_results_obj = parse_run_results(run_results=run_results_dict) dbt_version = parse_version(run_results_obj.metadata.dbt_version) @@ -199,20 +194,20 @@ def get_manifest_obj(self): with open(self.project_dir / MANIFEST_PATH) as manifest: logger.info(f"Parsing file {MANIFEST_PATH}") manifest_dict = json.load(manifest) - manifest_obj = self.parse_manifest(manifest=manifest_dict) + manifest_obj = parse_manifest(manifest=manifest_dict) return manifest_obj def get_project_dict(self): with open(self.project_dir / PROJECT_FILE) as project: logger.info(f"Parsing file {PROJECT_FILE}") - project_dict = self.yaml.safe_load(project) + project_dict = yaml.safe_load(project) return project_dict def get_connection_creds(self) -> Tuple[Dict[str, str], str]: profiles_path = self.profiles_dir / PROFILES_FILE with open(profiles_path) as profiles: logger.info(f"Parsing file {profiles_path}") - profiles = self.yaml.safe_load(profiles) + profiles = yaml.safe_load(profiles) dbt_profile_var = self.project_dict.get("profile") @@ -220,7 +215,7 @@ def get_connection_creds(self) -> Tuple[Dict[str, str], str]: profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'." ) # values can contain env_vars - rendered_profile = self.ProfileRenderer().render_data(profile) + rendered_profile = ProfileRenderer().render_data(profile) profile_target = get_from_dict_with_raise( rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'." ) diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 8ef01373..f211a549 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -63,7 +63,7 @@ def _inner(): except ModuleNotFoundError as e: s = text if package: - s += f"You can install it using 'pip install data_diff[{package}]'." + s += f"Please complete setup by running 'pip install data_diff[{package}]'." raise ModuleNotFoundError(f"{e}\n\n{s}\n") return _inner diff --git a/poetry.lock b/poetry.lock index 2dc50831..d0ff670b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "agate" @@ -2482,13 +2482,13 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [extras] clickhouse = ["clickhouse-driver"] -dbt = ["dbt-artifacts-parser", "dbt-core"] duckdb = ["duckdb"] mysql = ["mysql-connector-python"] oracle = ["cx_Oracle"] postgresql = ["psycopg2"] preql = ["preql"] presto = ["presto-python-client"] +redshift = ["psycopg2"] snowflake = ["cryptography", "snowflake-connector-python"] trino = ["trino"] vertica = ["vertica-python"] @@ -2496,4 +2496,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "96bcb369e66de27b5ad8e86337bcf7af471c8b2c84b748c3b9ddd4ee8155c001" +content-hash = "719f272f2c581722d09a319586392329963fdd6a4f1aba739e1e7f244744ad80" diff --git a/pyproject.toml b/pyproject.toml index 1f305ebb..d3e0af32 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ trino = {version="^0.314.0", optional=true} presto-python-client = {version="*", optional=true} clickhouse-driver = {version="*", optional=true} duckdb = {version="^0.7.0", optional=true} -dbt-artifacts-parser = {version="^0.3.0", optional=true} -dbt-core = {version="^1.0.0", optional=true} +dbt-artifacts-parser = {version="^0.3.0"} +dbt-core = {version="^1.0.0"} keyring = "*" tabulate = "^0.9.0" preql = {version="^0.2.19", optional=true} @@ -71,6 +71,7 @@ dbt-core = "^1.0.0" preql = ["preql"] mysql = ["mysql-connector-python"] postgresql = ["psycopg2"] +redshift = ["psycopg2"] snowflake = ["snowflake-connector-python", "cryptography"] presto = ["presto-python-client"] oracle = ["cx_Oracle"] @@ -79,7 +80,6 @@ trino = ["trino"] clickhouse = ["clickhouse-driver"] vertica = ["vertica-python"] duckdb = ["duckdb"] -dbt = ["dbt-core", "dbt-artifacts-parser"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 6e3dbeb0..59317190 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -100,15 +100,16 @@ def test_get_models_no_selection(self): mock_self.get_run_results_models.assert_called() self.assertEqual(models, mock_return_value) + @patch("data_diff.dbt_parser.parse_run_results") @patch("builtins.open", new_callable=mock_open, read_data="{}") - def test_get_run_results_models(self, mock_open): + def test_get_run_results_models(self, mock_open, mock_artifact_parser): mock_model = {"success_unique_id": "expected_value"} mock_self = Mock() mock_self.project_dir = Path() mock_run_results = Mock() mock_success_result = Mock() mock_failed_result = Mock() - mock_self.parse_run_results.return_value = mock_run_results + mock_artifact_parser.return_value = mock_run_results mock_run_results.metadata.dbt_version = "1.0.0" mock_success_result.unique_id = "success_unique_id" mock_failed_result.unique_id = "failed_unique_id" @@ -121,32 +122,34 @@ def test_get_run_results_models(self, mock_open): self.assertEqual(mock_model, models[0]) mock_open.assert_any_call(Path(RUN_RESULTS_PATH)) - mock_self.parse_run_results.assert_called_once_with(run_results={}) + mock_artifact_parser.assert_called_once_with(run_results={}) + @patch("data_diff.dbt_parser.parse_run_results") @patch("builtins.open", new_callable=mock_open, read_data="{}") - def test_get_run_results_models_bad_lower_dbt_version(self, mock_open): + def test_get_run_results_models_bad_lower_dbt_version(self, mock_open, mock_artifact_parser): mock_self = Mock() mock_self.project_dir = Path() mock_run_results = Mock() - mock_self.parse_run_results.return_value = mock_run_results + mock_artifact_parser.return_value = mock_run_results mock_run_results.metadata.dbt_version = "0.19.0" with self.assertRaises(Exception) as ex: DbtParser.get_run_results_models(mock_self) mock_open.assert_called_once_with(Path(RUN_RESULTS_PATH)) - mock_self.parse_run_results.assert_called_once_with(run_results={}) + mock_artifact_parser.assert_called_once_with(run_results={}) mock_self.parse_manifest.assert_not_called() self.assertIn("version to be", ex.exception.args[0]) + @patch("data_diff.dbt_parser.parse_run_results") @patch("builtins.open", new_callable=mock_open, read_data="{}") - def test_get_run_results_models_no_success(self, mock_open): + def test_get_run_results_models_no_success(self, mock_open, mock_artifact_parser): mock_self = Mock() mock_self.project_dir = Path() mock_run_results = Mock() mock_success_result = Mock() mock_failed_result = Mock() - mock_self.parse_run_results.return_value = mock_run_results + mock_artifact_parser.return_value = mock_run_results mock_run_results.metadata.dbt_version = "1.0.0" mock_failed_result.unique_id = "failed_unique_id" mock_success_result.status.name = "success" @@ -157,15 +160,16 @@ def test_get_run_results_models_no_success(self, mock_open): DbtParser.get_run_results_models(mock_self) mock_open.assert_any_call(Path(RUN_RESULTS_PATH)) - mock_self.parse_run_results.assert_called_once_with(run_results={}) + mock_artifact_parser.assert_called_once_with(run_results={}) + @patch("data_diff.dbt_parser.yaml") @patch("builtins.open", new_callable=mock_open, read_data="key:\n value") - def test_get_project_dict(self, mock_open): + def test_get_project_dict(self, mock_open, mock_yaml): expected_dict = {"key1": "value1"} mock_self = Mock() mock_self.project_dir = Path() - mock_self.yaml.safe_load.return_value = expected_dict + mock_yaml.safe_load.return_value = expected_dict project_dict = DbtParser.get_project_dict(mock_self) self.assertEqual(project_dict, expected_dict) @@ -300,8 +304,10 @@ def test_set_connection_not_implemented(self): self.assertNotIsInstance(mock_self.connection, dict) + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_creds_success(self, mock_open): + def test_get_connection_creds_success(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = { "a_profile": { "outputs": { @@ -315,26 +321,30 @@ def test_get_connection_creds_success(self, mock_open): mock_self = Mock() mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "a_profile"} - mock_self.yaml.safe_load.return_value = profiles_dict - mock_self.ProfileRenderer().render_data.return_value = profile + mock_yaml.safe_load.return_value = profiles_dict + mock_profile_renderer().render_data.return_value = profile credentials, conn_type = DbtParser.get_connection_creds(mock_self) self.assertEqual(credentials, expected_credentials) self.assertEqual(conn_type, "type1") + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_matching_profile(self, mock_open): + def test_get_connection_no_matching_profile(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = {"a_profile": {}} mock_self = Mock() mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "wrong_profile"} - mock_self.yaml.safe_load.return_value = profiles_dict + mock_yaml.safe_load.return_value = profiles_dict profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile + mock_profile_renderer().render_data.return_value = profile with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_target(self, mock_open): + def test_get_connection_no_target(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = { "a_profile": { "outputs": { @@ -345,9 +355,9 @@ def test_get_connection_no_target(self, mock_open): mock_self = Mock() mock_self.profiles_dir = Path() profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile + mock_profile_renderer().render_data.return_value = profile mock_self.project_dict = {"profile": "a_profile"} - mock_self.yaml.safe_load.return_value = profiles_dict + mock_yaml.safe_load.return_value = profiles_dict with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) @@ -356,20 +366,24 @@ def test_get_connection_no_target(self, mock_open): target: a_target """ + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_outputs(self, mock_open): + def test_get_connection_no_outputs(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = {"a_profile": {"target": "a_target"}} mock_self = Mock() mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "a_profile"} profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile - mock_self.yaml.safe_load.return_value = profiles_dict + mock_profile_renderer().render_data.return_value = profile + mock_yaml.safe_load.return_value = profiles_dict with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_credentials(self, mock_open): + def test_get_connection_no_credentials(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = { "a_profile": { "outputs": {"a_target": {}}, @@ -379,14 +393,16 @@ def test_get_connection_no_credentials(self, mock_open): mock_self = Mock() mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "a_profile"} - mock_self.yaml.safe_load.return_value = profiles_dict + mock_yaml.safe_load.return_value = profiles_dict profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile + mock_profile_renderer().render_data.return_value = profile with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_target_credentials(self, mock_open): + def test_get_connection_no_target_credentials(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = { "a_profile": { "outputs": { @@ -399,13 +415,15 @@ def test_get_connection_no_target_credentials(self, mock_open): mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "a_profile"} profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile - mock_self.yaml.safe_load.return_value = profiles_dict + mock_profile_renderer().render_data.return_value = profile + mock_yaml.safe_load.return_value = profiles_dict with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) + @patch("data_diff.dbt_parser.yaml") + @patch("data_diff.dbt_parser.ProfileRenderer") @patch("builtins.open", new_callable=mock_open, read_data="") - def test_get_connection_no_type(self, mock_open): + def test_get_connection_no_type(self, mock_open, mock_profile_renderer, mock_yaml): profiles_dict = { "a_profile": { "outputs": {"a_target": {"credential_1": "credential_1", "credential_2": "credential_2"}}, @@ -415,9 +433,9 @@ def test_get_connection_no_type(self, mock_open): mock_self = Mock() mock_self.profiles_dir = Path() mock_self.project_dict = {"profile": "a_profile"} - mock_self.yaml.safe_load.return_value = profiles_dict + mock_yaml.safe_load.return_value = profiles_dict profile = profiles_dict["a_profile"] - mock_self.ProfileRenderer().render_data.return_value = profile + mock_profile_renderer().render_data.return_value = profile with self.assertRaises(ValueError): _, _ = DbtParser.get_connection_creds(mock_self) From 9914a6d8303f6dac7b8e4e6e803357f29c5619ed Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 11 May 2023 11:33:13 -0600 Subject: [PATCH 2/3] format black -l 120 --- data_diff/sqeleton/databases/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index f211a549..174e4bc8 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -36,7 +36,7 @@ DbTime, DbPath, Boolean, - JSON + JSON, ) from ..abcs.mixins import Compilable from ..abcs.mixins import ( From 8a34aca469c53b482e78c925b1adb444738af5a7 Mon Sep 17 00:00:00 2001 From: Dan Date: Thu, 11 May 2023 11:35:50 -0600 Subject: [PATCH 3/3] add quotes in pip install instruction --- data_diff/sqeleton/databases/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 174e4bc8..73c69424 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -63,7 +63,7 @@ def _inner(): except ModuleNotFoundError as e: s = text if package: - s += f"Please complete setup by running 'pip install data_diff[{package}]'." + s += f"Please complete setup by running: pip install 'data_diff[{package}]'." raise ModuleNotFoundError(f"{e}\n\n{s}\n") return _inner