From a25c533b84c04d234b5e7107a4ec625e1a50d188 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Wed, 7 Dec 2022 14:15:31 -0800 Subject: [PATCH] Search for temporary token (#98) * Try to get attributes else do not * Simplify logic * fix compute engine credentials method (#100) * Add tests and cleanup Co-authored-by: Lucien Fregosi Co-authored-by: Alexander Streed --- prefect_dbt/cli/configs/bigquery.py | 13 +++++++-- tests/cli/configs/test_bigquery.py | 45 ++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/prefect_dbt/cli/configs/bigquery.py b/prefect_dbt/cli/configs/bigquery.py index 99d6caa..2fe76e7 100644 --- a/prefect_dbt/cli/configs/bigquery.py +++ b/prefect_dbt/cli/configs/bigquery.py @@ -6,6 +6,8 @@ except ImportError: from typing_extensions import Literal +from prefect.utilities.asyncutils import sync_compatible + from prefect_dbt.cli.configs.base import MissingExtrasRequireError, TargetConfigs try: @@ -84,7 +86,8 @@ class BigQueryTargetConfigs(TargetConfigs): project: Optional[str] = None credentials: GcpCredentials - def get_configs(self) -> Dict[str, Any]: + @sync_compatible + async def get_configs(self) -> Dict[str, Any]: """ Returns the dbt configs specific to BigQuery profile. @@ -106,12 +109,16 @@ def get_configs(self) -> Dict[str, Any]: configs_json["method"] = "service-account" configs_json["keyfile"] = str(configs_json.pop("service_account_file")) else: + configs_json["method"] = "oauth-secrets" # through gcloud application-default login google_credentials = ( self_copy.credentials.get_credentials_from_service_account() ) - for key in ("refresh_token", "client_id", "client_secret", "token_uri"): - configs_json[key] = getattr(google_credentials, key) + if hasattr(google_credentials, "token"): + configs_json["token"] = await self_copy.credentials.get_access_token() + else: + for key in ("refresh_token", "client_id", "client_secret", "token_uri"): + configs_json[key] = getattr(google_credentials, key) if "project" not in configs_json: raise ValueError( diff --git a/tests/cli/configs/test_bigquery.py b/tests/cli/configs/test_bigquery.py index 05210bb..564ae2f 100644 --- a/tests/cli/configs/test_bigquery.py +++ b/tests/cli/configs/test_bigquery.py @@ -1,5 +1,5 @@ import json -from unittest.mock import MagicMock +from unittest.mock import MagicMock, seal import pytest from prefect_gcp.credentials import GcpCredentials @@ -36,12 +36,9 @@ def service_account_file(monkeypatch, tmp_path, service_account_info_dict): @pytest.fixture def google_auth(monkeypatch): - google_auth_mock = MagicMock() + google_auth_mock = MagicMock(name="google_auth") default_credentials_mock = MagicMock( - refresh_token="my_refresh_token", - token_uri="my_token_uri", - client_id="my_client_id", - client_secret="my_client_secret", + name="default_credentials", quota_project_id="my_project", ) google_auth_mock.default.side_effect = lambda *args, **kwargs: ( @@ -85,13 +82,24 @@ def test_get_configs_service_account_info(self, service_account_info_dict): } assert actual == expected - def test_get_configs_gcloud_cli(self, google_auth): + def test_get_configs_gcloud_cli_refresh_token(self, google_auth): gcp_credentials = GcpCredentials() configs = BigQueryTargetConfigs( credentials=gcp_credentials, project="my_project", schema="my_schema" ) + google_credentials = MagicMock( + refresh_token="my_refresh_token", + token_uri="my_token_uri", + client_id="my_client_id", + client_secret="my_client_secret", + ) + seal(google_credentials) + gcp_credentials.get_credentials_from_service_account = ( + lambda: google_credentials + ) actual = configs.get_configs() expected = { + "method": "oauth-secrets", "type": "bigquery", "schema": "my_schema", "threads": 4, @@ -103,6 +111,29 @@ def test_get_configs_gcloud_cli(self, google_auth): } assert actual == expected + def test_get_configs_gcloud_cli_temporary_token(self, google_auth): + gcp_credentials = GcpCredentials() + configs = BigQueryTargetConfigs( + credentials=gcp_credentials, project="my_project", schema="my_schema" + ) + google_credentials = MagicMock( + token="my_token", refresh=lambda *args, **kwargs: "refreshed" + ) + seal(google_credentials) + gcp_credentials.get_credentials_from_service_account = ( + lambda: google_credentials + ) + actual = configs.get_configs() + expected = { + "method": "oauth-secrets", + "type": "bigquery", + "schema": "my_schema", + "threads": 4, + "project": "my_project", + "token": "my_token", + } + assert actual == expected + def test_get_configs_project_from_service_account_file(self, service_account_file): gcp_credentials = GcpCredentials(service_account_file=service_account_file) configs = BigQueryTargetConfigs(credentials=gcp_credentials, schema="schema")