Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Restrict target configs keys (#112)
Browse files Browse the repository at this point in the history
* Restrict target configs keys

* Undo async

* Remove async

* Support secret dict (#107)

* Updates `DbtCloudCredentials` to implement `CredentialsBlock` interface (#109)

* Adds metadata client

* Enhances docstrings

* Adds tests

* Adds changelog entry

* prep v0.2.7 release (#110)

* Fix private_keys

* Add keys to restrict

* Add schema to rename

* Fix tests

* Update chagnelog

* Add tests

* Add examples

* FIx reqs

* Maybe fix reqs

* Update CHANGELOG.md

* Move breaking change in a separate PR

* Remove breaking change in this PR

* Update .pre-commit-config.yaml

---------

Co-authored-by: Alexander Streed <desertaxle@users.noreply.github.com>
  • Loading branch information
ahuang11 and desertaxle authored Jan 31, 2023
1 parent dc5f504 commit c409166
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
language_version: python3
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

### Changed
- The minimum version of `prefect-snowflake` - [#112](https://github.com/PrefectHQ/prefect-dbt/pull/112)
- Decoupled fields of blocks from external Collections from the created dbt profile - [#112](https://github.com/PrefectHQ/prefect-dbt/pull/112)

### Deprecated

Expand Down
56 changes: 50 additions & 6 deletions prefect_dbt/cli/configs/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module containing models for BigQuery configs"""
from typing import Any, Dict, Optional

from google.auth.transport.requests import Request

try:
from typing import Literal
except ImportError:
Expand Down Expand Up @@ -74,6 +76,7 @@ class BigQueryTargetConfigs(TargetConfigs):
schema="schema",
project="project",
credentials=credentials,
extras={"execution_project": "my_exe_project"},
)
```
"""
Expand All @@ -98,24 +101,65 @@ async def get_configs(self) -> Dict[str, Any]:
self_copy = self.copy()
if self_copy.project is not None:
self_copy.credentials.project = None
configs_json = self._populate_configs_json(
all_configs_json = self._populate_configs_json(
{}, self_copy.__fields__, model=self_copy
)

if "service_account_info" in configs_json:
# decouple prefect-gcp from prefect-dbt
# by mapping all the keys dbt gcp accepts
# https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup
rename_keys = {
# dbt
"type": "type",
"schema": "schema",
"threads": "threads",
# general
"dataset": "schema",
"method": "method",
"project": "project",
# service-account
"service_account_file": "keyfile",
# service-account json
"service_account_info": "keyfile_json",
# oauth secrets
"refresh_token": "refresh_token",
"client_id": "client_id",
"client_secret": "client_secret",
"token_uri": "token_uri",
# optional
"priority": "priority",
"timeout_seconds": "timeout_seconds",
"location": "location",
"maximum_bytes_billed": "maximum_bytes_billed",
"scopes": "scopes",
"impersonate_service_account": "impersonate_service_account",
"execution_project": "execution_project",
}
configs_json = {}
extras = self.extras or {}
for key in all_configs_json.keys():
if key not in rename_keys and key not in extras:
# skip invalid keys
continue
# rename key to something dbt profile expects
dbt_key = rename_keys.get(key) or key
configs_json[dbt_key] = all_configs_json[key]

if "keyfile_json" in configs_json:
configs_json["method"] = "service-account-json"
configs_json["keyfile_json"] = configs_json.pop("service_account_info")
elif "service_account_file" in configs_json:
elif "keyfile" in configs_json:
configs_json["method"] = "service-account"
configs_json["keyfile"] = str(configs_json.pop("service_account_file"))
configs_json["keyfile"] = str(configs_json["keyfile"])
else:
configs_json["method"] = "oauth-secrets"
# through gcloud application-default login
google_credentials = (
self_copy.credentials.get_credentials_from_service_account()
)
if hasattr(google_credentials, "token"):
configs_json["token"] = await self_copy.credentials.get_access_token()
request = Request()
google_credentials.refresh(request)
configs_json["token"] = google_credentials.token
else:
for key in ("refresh_token", "client_id", "client_secret", "token_uri"):
configs_json[key] = getattr(google_credentials, key)
Expand Down
44 changes: 42 additions & 2 deletions prefect_dbt/cli/configs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class SnowflakeTargetConfigs(TargetConfigs):
credentials=credentials,
)
target_configs = SnowflakeTargetConfigs(
connector=connector
connector=connector,
extras={"retry_on_database_errors": True},
)
```
"""
Expand All @@ -75,5 +76,44 @@ def get_configs(self) -> Dict[str, Any]:
Returns:
A configs JSON.
"""
configs_json = super().get_configs()
all_configs_json = super().get_configs()

# decouple prefect-snowflake from prefect-dbt
# by mapping all the keys dbt snowflake accepts
# https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup
rename_keys = {
# dbt
"type": "type",
"schema": "schema",
"threads": "threads",
# general
"account": "account",
"user": "user",
"role": "role",
"database": "database",
"warehouse": "warehouse",
# user and password
"password": "password",
# duo mfa / sso
"authenticator": "authenticator",
# key pair
"private_key_path": "private_key_path",
"private_key_passphrase": "private_key_passphrase",
# optional
"client_session_keep_alive": "client_session_keep_alive",
"query_tag": "query_tag",
"connect_retries": "connect_retries",
"connect_timeout": "connect_timeout",
"retry_on_database_errors": "retry_on_database_errors",
"retry_all": "retry_all",
}
configs_json = {}
extras = self.extras or {}
for key in all_configs_json.keys():
if key not in rename_keys and key not in extras:
# skip invalid keys, like fetch_size + poll_frequency_s
continue
# rename key to something dbt profile expects
dbt_key = rename_keys.get(key) or key
configs_json[dbt_key] = all_configs_json[key]
return configs_json
1 change: 0 additions & 1 deletion prefect_dbt/cloud/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ class DbtCloudMetadataClient:
Args:
api_key: API key to authenticate with the dbt Cloud administrative API.
account_id: ID of dbt Cloud account with which to interact.
domain: Domain at which the dbt Cloud API is hosted.
"""

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

extras_require = {
"cli": ["dbt_core>=1.1.1"],
"snowflake": ["prefect-snowflake>=0.2.0"],
"bigquery": ["prefect-gcp>=0.1.8"],
"snowflake": ["prefect-snowflake>=0.2.4"],
"bigquery": ["prefect-gcp[bigquery]>=0.1.8"],
"postgres": ["prefect-sqlalchemy>=0.2.1"],
}
extras_require["all_extras"] = sorted(
Expand Down
20 changes: 20 additions & 0 deletions tests/cli/configs/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ def test_get_configs_service_account_info(self, service_account_info_dict):
}
assert actual == expected

def test_get_configs_service_account_info_extras(self, service_account_info_dict):
gcp_credentials = GcpCredentials(service_account_info=service_account_info_dict)
configs = BigQueryTargetConfigs(
credentials=gcp_credentials,
project="my_project",
schema="my_schema",
extras={"execution_project": "my_exe_project"},
)
actual = configs.get_configs()
expected = {
"type": "bigquery",
"schema": "my_schema",
"threads": 4,
"project": "my_project",
"execution_project": "my_exe_project",
"method": "service-account-json",
"keyfile_json": service_account_info_dict,
}
assert actual == expected

def test_get_configs_gcloud_cli_refresh_token(self, google_auth):
gcp_credentials = GcpCredentials()
configs = BigQueryTargetConfigs(
Expand Down
10 changes: 5 additions & 5 deletions tests/cli/configs/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@ def test_snowflake_target_configs_get_configs():
user="user",
password="password",
)
connector_kwargs = dict(
snowflake_connector = SnowflakeConnector(
schema="schema",
database="database",
warehouse="warehouse",
credentials=credentials,
)
configs = SnowflakeTargetConfigs(
connector=snowflake_connector, extras={"retry_on_database_errors": True}
)

snowflake_connector = SnowflakeConnector(**connector_kwargs)
configs_kwargs = {"connector": snowflake_connector}

configs = SnowflakeTargetConfigs(**configs_kwargs)
actual = configs.get_configs()
expected = dict(
account="account",
Expand All @@ -32,6 +31,7 @@ def test_snowflake_target_configs_get_configs():
database="database",
warehouse="warehouse",
authenticator="snowflake",
retry_on_database_errors=True,
threads=4,
)
for k, v in actual.items():
Expand Down

0 comments on commit c409166

Please sign in to comment.