Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Search for temporary token (#98)
Browse files Browse the repository at this point in the history
* Try to get attributes else do not

* Simplify logic

* fix compute engine credentials method (#100)

* Add tests and cleanup

Co-authored-by: Lucien Fregosi <lucien.fregosi@bodyguard.ai>
Co-authored-by: Alexander Streed <desertaxle@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 7, 2022
1 parent 44e954b commit a25c533
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
13 changes: 10 additions & 3 deletions prefect_dbt/cli/configs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
except ImportError:
from typing_extensions import Literal

from prefect.utilities.asyncutils import sync_compatible

from prefect_dbt.cli.configs.base import MissingExtrasRequireError, TargetConfigs

try:
Expand Down Expand Up @@ -84,7 +86,8 @@ class BigQueryTargetConfigs(TargetConfigs):
project: Optional[str] = None
credentials: GcpCredentials

def get_configs(self) -> Dict[str, Any]:
@sync_compatible
async def get_configs(self) -> Dict[str, Any]:
"""
Returns the dbt configs specific to BigQuery profile.
Expand All @@ -106,12 +109,16 @@ def get_configs(self) -> Dict[str, Any]:
configs_json["method"] = "service-account"
configs_json["keyfile"] = str(configs_json.pop("service_account_file"))
else:
configs_json["method"] = "oauth-secrets"
# through gcloud application-default login
google_credentials = (
self_copy.credentials.get_credentials_from_service_account()
)
for key in ("refresh_token", "client_id", "client_secret", "token_uri"):
configs_json[key] = getattr(google_credentials, key)
if hasattr(google_credentials, "token"):
configs_json["token"] = await self_copy.credentials.get_access_token()
else:
for key in ("refresh_token", "client_id", "client_secret", "token_uri"):
configs_json[key] = getattr(google_credentials, key)

if "project" not in configs_json:
raise ValueError(
Expand Down
45 changes: 38 additions & 7 deletions tests/cli/configs/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from unittest.mock import MagicMock
from unittest.mock import MagicMock, seal

import pytest
from prefect_gcp.credentials import GcpCredentials
Expand Down Expand Up @@ -36,12 +36,9 @@ def service_account_file(monkeypatch, tmp_path, service_account_info_dict):

@pytest.fixture
def google_auth(monkeypatch):
google_auth_mock = MagicMock()
google_auth_mock = MagicMock(name="google_auth")
default_credentials_mock = MagicMock(
refresh_token="my_refresh_token",
token_uri="my_token_uri",
client_id="my_client_id",
client_secret="my_client_secret",
name="default_credentials",
quota_project_id="my_project",
)
google_auth_mock.default.side_effect = lambda *args, **kwargs: (
Expand Down Expand Up @@ -85,13 +82,24 @@ def test_get_configs_service_account_info(self, service_account_info_dict):
}
assert actual == expected

def test_get_configs_gcloud_cli(self, google_auth):
def test_get_configs_gcloud_cli_refresh_token(self, google_auth):
gcp_credentials = GcpCredentials()
configs = BigQueryTargetConfigs(
credentials=gcp_credentials, project="my_project", schema="my_schema"
)
google_credentials = MagicMock(
refresh_token="my_refresh_token",
token_uri="my_token_uri",
client_id="my_client_id",
client_secret="my_client_secret",
)
seal(google_credentials)
gcp_credentials.get_credentials_from_service_account = (
lambda: google_credentials
)
actual = configs.get_configs()
expected = {
"method": "oauth-secrets",
"type": "bigquery",
"schema": "my_schema",
"threads": 4,
Expand All @@ -103,6 +111,29 @@ def test_get_configs_gcloud_cli(self, google_auth):
}
assert actual == expected

def test_get_configs_gcloud_cli_temporary_token(self, google_auth):
gcp_credentials = GcpCredentials()
configs = BigQueryTargetConfigs(
credentials=gcp_credentials, project="my_project", schema="my_schema"
)
google_credentials = MagicMock(
token="my_token", refresh=lambda *args, **kwargs: "refreshed"
)
seal(google_credentials)
gcp_credentials.get_credentials_from_service_account = (
lambda: google_credentials
)
actual = configs.get_configs()
expected = {
"method": "oauth-secrets",
"type": "bigquery",
"schema": "my_schema",
"threads": 4,
"project": "my_project",
"token": "my_token",
}
assert actual == expected

def test_get_configs_project_from_service_account_file(self, service_account_file):
gcp_credentials = GcpCredentials(service_account_file=service_account_file)
configs = BigQueryTargetConfigs(credentials=gcp_credentials, schema="schema")
Expand Down

0 comments on commit a25c533

Please sign in to comment.