diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80f9b44..2f089e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort language_version: python3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a118b6..e4c51bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added ### Changed +- The minimum version of `prefect-snowflake` - [#112](https://github.com/PrefectHQ/prefect-dbt/pull/112) +- Decoupled fields of blocks from external Collections from the created dbt profile - [#112](https://github.com/PrefectHQ/prefect-dbt/pull/112) ### Deprecated diff --git a/prefect_dbt/cli/configs/bigquery.py b/prefect_dbt/cli/configs/bigquery.py index 2fe76e7..e870cd9 100644 --- a/prefect_dbt/cli/configs/bigquery.py +++ b/prefect_dbt/cli/configs/bigquery.py @@ -1,6 +1,8 @@ """Module containing models for BigQuery configs""" from typing import Any, Dict, Optional +from google.auth.transport.requests import Request + try: from typing import Literal except ImportError: @@ -74,6 +76,7 @@ class BigQueryTargetConfigs(TargetConfigs): schema="schema", project="project", credentials=credentials, + extras={"execution_project": "my_exe_project"}, ) ``` """ @@ -98,16 +101,55 @@ async def get_configs(self) -> Dict[str, Any]: self_copy = self.copy() if self_copy.project is not None: self_copy.credentials.project = None - configs_json = self._populate_configs_json( + all_configs_json = self._populate_configs_json( {}, self_copy.__fields__, model=self_copy ) - if "service_account_info" in configs_json: + # decouple prefect-gcp from prefect-dbt + # by mapping all the keys dbt gcp accepts + # https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup + rename_keys = { + # dbt + "type": "type", + "schema": "schema", + "threads": "threads", + # general + "dataset": "schema", + "method": "method", + "project": "project", + # service-account + "service_account_file": "keyfile", + # service-account json + "service_account_info": "keyfile_json", + # oauth secrets + "refresh_token": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + "token_uri": "token_uri", + # optional + "priority": "priority", + "timeout_seconds": "timeout_seconds", + "location": "location", + "maximum_bytes_billed": "maximum_bytes_billed", + "scopes": "scopes", + "impersonate_service_account": "impersonate_service_account", + "execution_project": "execution_project", + } + configs_json = {} + extras = self.extras or {} + for key in all_configs_json.keys(): + if key not in rename_keys and key not in extras: + # skip invalid keys + continue + # rename key to something dbt profile expects + dbt_key = rename_keys.get(key) or key + configs_json[dbt_key] = all_configs_json[key] + + if "keyfile_json" in configs_json: configs_json["method"] = "service-account-json" - configs_json["keyfile_json"] = configs_json.pop("service_account_info") - elif "service_account_file" in configs_json: + elif "keyfile" in configs_json: configs_json["method"] = "service-account" - configs_json["keyfile"] = str(configs_json.pop("service_account_file")) + configs_json["keyfile"] = str(configs_json["keyfile"]) else: configs_json["method"] = "oauth-secrets" # through gcloud application-default login @@ -115,7 +157,9 @@ async def get_configs(self) -> Dict[str, Any]: self_copy.credentials.get_credentials_from_service_account() ) if hasattr(google_credentials, "token"): - configs_json["token"] = await self_copy.credentials.get_access_token() + request = Request() + google_credentials.refresh(request) + configs_json["token"] = google_credentials.token else: for key in ("refresh_token", "client_id", "client_secret", "token_uri"): configs_json[key] = getattr(google_credentials, key) diff --git a/prefect_dbt/cli/configs/snowflake.py b/prefect_dbt/cli/configs/snowflake.py index acf26de..4b387ae 100644 --- a/prefect_dbt/cli/configs/snowflake.py +++ b/prefect_dbt/cli/configs/snowflake.py @@ -56,7 +56,8 @@ class SnowflakeTargetConfigs(TargetConfigs): credentials=credentials, ) target_configs = SnowflakeTargetConfigs( - connector=connector + connector=connector, + extras={"retry_on_database_errors": True}, ) ``` """ @@ -75,5 +76,44 @@ def get_configs(self) -> Dict[str, Any]: Returns: A configs JSON. """ - configs_json = super().get_configs() + all_configs_json = super().get_configs() + + # decouple prefect-snowflake from prefect-dbt + # by mapping all the keys dbt snowflake accepts + # https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup + rename_keys = { + # dbt + "type": "type", + "schema": "schema", + "threads": "threads", + # general + "account": "account", + "user": "user", + "role": "role", + "database": "database", + "warehouse": "warehouse", + # user and password + "password": "password", + # duo mfa / sso + "authenticator": "authenticator", + # key pair + "private_key_path": "private_key_path", + "private_key_passphrase": "private_key_passphrase", + # optional + "client_session_keep_alive": "client_session_keep_alive", + "query_tag": "query_tag", + "connect_retries": "connect_retries", + "connect_timeout": "connect_timeout", + "retry_on_database_errors": "retry_on_database_errors", + "retry_all": "retry_all", + } + configs_json = {} + extras = self.extras or {} + for key in all_configs_json.keys(): + if key not in rename_keys and key not in extras: + # skip invalid keys, like fetch_size + poll_frequency_s + continue + # rename key to something dbt profile expects + dbt_key = rename_keys.get(key) or key + configs_json[dbt_key] = all_configs_json[key] return configs_json diff --git a/prefect_dbt/cloud/clients.py b/prefect_dbt/cloud/clients.py index 54a5ce4..cb9ec25 100644 --- a/prefect_dbt/cloud/clients.py +++ b/prefect_dbt/cloud/clients.py @@ -197,7 +197,6 @@ class DbtCloudMetadataClient: Args: api_key: API key to authenticate with the dbt Cloud administrative API. - account_id: ID of dbt Cloud account with which to interact. domain: Domain at which the dbt Cloud API is hosted. """ diff --git a/setup.py b/setup.py index 9ca94a5..57cbf65 100644 --- a/setup.py +++ b/setup.py @@ -13,8 +13,8 @@ extras_require = { "cli": ["dbt_core>=1.1.1"], - "snowflake": ["prefect-snowflake>=0.2.0"], - "bigquery": ["prefect-gcp>=0.1.8"], + "snowflake": ["prefect-snowflake>=0.2.4"], + "bigquery": ["prefect-gcp[bigquery]>=0.1.8"], "postgres": ["prefect-sqlalchemy>=0.2.1"], } extras_require["all_extras"] = sorted( diff --git a/tests/cli/configs/test_bigquery.py b/tests/cli/configs/test_bigquery.py index 564ae2f..05a0e29 100644 --- a/tests/cli/configs/test_bigquery.py +++ b/tests/cli/configs/test_bigquery.py @@ -82,6 +82,26 @@ def test_get_configs_service_account_info(self, service_account_info_dict): } assert actual == expected + def test_get_configs_service_account_info_extras(self, service_account_info_dict): + gcp_credentials = GcpCredentials(service_account_info=service_account_info_dict) + configs = BigQueryTargetConfigs( + credentials=gcp_credentials, + project="my_project", + schema="my_schema", + extras={"execution_project": "my_exe_project"}, + ) + actual = configs.get_configs() + expected = { + "type": "bigquery", + "schema": "my_schema", + "threads": 4, + "project": "my_project", + "execution_project": "my_exe_project", + "method": "service-account-json", + "keyfile_json": service_account_info_dict, + } + assert actual == expected + def test_get_configs_gcloud_cli_refresh_token(self, google_auth): gcp_credentials = GcpCredentials() configs = BigQueryTargetConfigs( diff --git a/tests/cli/configs/test_snowflake.py b/tests/cli/configs/test_snowflake.py index 34d02d6..337bb47 100644 --- a/tests/cli/configs/test_snowflake.py +++ b/tests/cli/configs/test_snowflake.py @@ -11,17 +11,16 @@ def test_snowflake_target_configs_get_configs(): user="user", password="password", ) - connector_kwargs = dict( + snowflake_connector = SnowflakeConnector( schema="schema", database="database", warehouse="warehouse", credentials=credentials, ) + configs = SnowflakeTargetConfigs( + connector=snowflake_connector, extras={"retry_on_database_errors": True} + ) - snowflake_connector = SnowflakeConnector(**connector_kwargs) - configs_kwargs = {"connector": snowflake_connector} - - configs = SnowflakeTargetConfigs(**configs_kwargs) actual = configs.get_configs() expected = dict( account="account", @@ -32,6 +31,7 @@ def test_snowflake_target_configs_get_configs(): database="database", warehouse="warehouse", authenticator="snowflake", + retry_on_database_errors=True, threads=4, ) for k, v in actual.items():