From 6e8d4a6fa6526ed397ede022d0d7cde30c08411f Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 26 Feb 2020 10:37:12 -0700 Subject: [PATCH] actually update the connection --- .../dbt/adapters/snowflake/connections.py | 96 +++++++++++++++++-- 1 file changed, 86 insertions(+), 10 deletions(-) diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index e6b60d1fafb..826dfb154ce 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -1,3 +1,4 @@ +import base64 import datetime import pytz import re @@ -8,15 +9,22 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +import requests import snowflake.connector import snowflake.connector.errors -import dbt.exceptions +from dbt.exceptions import ( + InternalException, RuntimeException, FailedToConnectException, + DatabaseException, warn_or_error +) from dbt.adapters.base import Credentials from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger +_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request' + + @dataclass class SnowflakeCredentials(Credentials): account: str @@ -28,15 +36,30 @@ class SnowflakeCredentials(Credentials): private_key_path: Optional[str] private_key_passphrase: Optional[str] token: Optional[str] + oauth_client_id: Optional[str] + oauth_client_secret: Optional[str] client_session_keep_alive: bool = False + def __post_init__(self): + if ( + self.authenticator != 'oauth' and + (self.oauth_client_secret or self.oauth_client_id or self.token) + ): + # the user probably forgot to set 'authenticator' like I keep doing + warn_or_error( + 'Authenticator is not set to oauth, but an oauth-only ' + 'parameter is set! Did you mean to set authenticator: oauth?' + ) + @property def type(self): return 'snowflake' def _connection_keys(self): - return ('account', 'user', 'database', 'schema', 'warehouse', 'role', - 'client_session_keep_alive') + return ( + 'account', 'user', 'database', 'schema', 'warehouse', 'role', + 'client_session_keep_alive' + ) def auth_args(self): # Pull all of the optional authentication args for the connector, @@ -47,10 +70,63 @@ def auth_args(self): if self.authenticator: result['authenticator'] = self.authenticator if self.authenticator == 'oauth': - result['token'] = self.token + token = self.token + # if we have a client ID/client secret, the token is a refresh + # token, not an access token + if self.oauth_client_id and self.oauth_client_secret: + token = self._get_access_token() + elif self.oauth_client_id: + warn_or_error( + 'Invalid profile: got an oauth_client_id, but not an ' + 'oauth_client_secret!' + ) + elif self.oauth_client_secret: + warn_or_error( + 'Invalid profile: got an oauth_client_secret, but not ' + 'an oauth_client_id!' + ) + + result['token'] = token result['private_key'] = self._get_private_key() return result + def _get_access_token(self) -> str: + if self.authenticator != 'oauth': + raise InternalException('Can only get access tokens for oauth') + missing = any( + x is None for x in + (self.oauth_client_id, self.oauth_client_secret, self.token) + ) + if missing: + raise InternalException( + 'need a client ID a client secret, and a refresh token to get ' + 'an access token' + ) + # should the full url be a config item? + token_url = _TOKEN_REQUEST_URL.format(self.account) + # I think this is only used to redirect on success, which we ignore + # (it does not have to match the integration's settings in snowflake) + redirect_uri = 'http://localhost:9999' + data = { + 'grant_type': 'refresh_token', + 'refresh_token': self.token, + 'redirect_uri': redirect_uri + } + + auth = base64.b64encode( + f'{self.oauth_client_id}:{self.oauth_client_secret}' + .encode('ascii') + ).decode('ascii') + headers = { + 'Authorization': f'Basic {auth}', + 'Content-type': 'application/x-www-form-urlencoded;charset=utf-8' + } + result = requests.post(token_url, headers=headers, data=data) + result_json = result.json() + if 'access_token' not in result_json: + raise DatabaseException(f'Did not get a token: {result_json}') + return result_json['access_token'] + def _get_private_key(self): """Get Snowflake private key by path or None.""" if not self.private_key_path or self.private_key_passphrase is None: @@ -84,7 +160,7 @@ def exception_handler(self, sql): logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: self.release() - raise dbt.exceptions.FailedToConnectException( + raise FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' 'does not have access to use the specified database. ' @@ -92,17 +168,17 @@ def exception_handler(self, sql): .format(msg)) else: self.release() - raise dbt.exceptions.DatabaseException(msg) + raise DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: {}", sql) logger.debug("Rolling back transaction.") self.release() - if isinstance(e, dbt.exceptions.RuntimeException): + if isinstance(e, RuntimeException): # during a sql query, an internal to dbt exception was raised. # this sounds a lot like a signal handler and probably has # useful information, so raise it without modification. raise - raise dbt.exceptions.RuntimeException(str(e)) from e + raise RuntimeException(str(e)) from e @classmethod def open(cls, connection): @@ -136,7 +212,7 @@ def open(cls, connection): connection.handle = None connection.state = 'fail' - raise dbt.exceptions.FailedToConnectException(str(e)) + raise FailedToConnectException(str(e)) def cancel(self, connection): handle = connection.handle @@ -228,7 +304,7 @@ def add_query(self, sql, auto_begin=True, else: conn_name = conn.name - raise dbt.exceptions.RuntimeException( + raise RuntimeException( "Tried to run an empty query on model '{}'. If you are " "conditionally running\nsql, eg. in a model hook, make " "sure your `else` clause contains valid sql!\n\n"