diff --git a/CHANGELOG.md b/CHANGELOG.md index f710998d614..ae9e5794d68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Features - Add a "docs" field to models, with a "show" subfield ([#1671](https://github.com/fishtown-analytics/dbt/issues/1671), [#2107](https://github.com/fishtown-analytics/dbt/pull/2107)) +- Use refresh tokens in snowflake instead of access tokens ([#2126](https://github.com/fishtown-analytics/dbt/issues/2126), [#2141](https://github.com/fishtown-analytics/dbt/pull/2141)) ### Fixes - Fix issue where dbt did not give an error in the presence of duplicate doc names ([#2054](https://github.com/fishtown-analytics/dbt/issues/2054), [#2080](https://github.com/fishtown-analytics/dbt/pull/2080)) diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index e6b60d1fafb..826dfb154ce 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -1,3 +1,4 @@ +import base64 import datetime import pytz import re @@ -8,15 +9,22 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +import requests import snowflake.connector import snowflake.connector.errors -import dbt.exceptions +from dbt.exceptions import ( + InternalException, RuntimeException, FailedToConnectException, + DatabaseException, warn_or_error +) from dbt.adapters.base import Credentials from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger +_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request' + + @dataclass class SnowflakeCredentials(Credentials): account: str @@ -28,15 +36,30 @@ class SnowflakeCredentials(Credentials): private_key_path: Optional[str] private_key_passphrase: Optional[str] token: Optional[str] + oauth_client_id: Optional[str] + oauth_client_secret: Optional[str] client_session_keep_alive: bool = False + def __post_init__(self): + if ( + self.authenticator != 'oauth' and + (self.oauth_client_secret or self.oauth_client_id or self.token) + ): + # the user probably forgot to set 'authenticator' like I keep doing + warn_or_error( + 'Authenticator is not set to oauth, but an oauth-only ' + 'parameter is set! Did you mean to set authenticator: oauth?' + ) + @property def type(self): return 'snowflake' def _connection_keys(self): - return ('account', 'user', 'database', 'schema', 'warehouse', 'role', - 'client_session_keep_alive') + return ( + 'account', 'user', 'database', 'schema', 'warehouse', 'role', + 'client_session_keep_alive' + ) def auth_args(self): # Pull all of the optional authentication args for the connector, @@ -47,10 +70,63 @@ def auth_args(self): if self.authenticator: result['authenticator'] = self.authenticator if self.authenticator == 'oauth': - result['token'] = self.token + token = self.token + # if we have a client ID/client secret, the token is a refresh + # token, not an access token + if self.oauth_client_id and self.oauth_client_secret: + token = self._get_access_token() + elif self.oauth_client_id: + warn_or_error( + 'Invalid profile: got an oauth_client_id, but not an ' + 'oauth_client_secret!' + ) + elif self.oauth_client_secret: + warn_or_error( + 'Invalid profile: got an oauth_client_secret, but not ' + 'an oauth_client_id!' + ) + + result['token'] = token result['private_key'] = self._get_private_key() return result + def _get_access_token(self) -> str: + if self.authenticator != 'oauth': + raise InternalException('Can only get access tokens for oauth') + missing = any( + x is None for x in + (self.oauth_client_id, self.oauth_client_secret, self.token) + ) + if missing: + raise InternalException( + 'need a client ID a client secret, and a refresh token to get ' + 'an access token' + ) + # should the full url be a config item? + token_url = _TOKEN_REQUEST_URL.format(self.account) + # I think this is only used to redirect on success, which we ignore + # (it does not have to match the integration's settings in snowflake) + redirect_uri = 'http://localhost:9999' + data = { + 'grant_type': 'refresh_token', + 'refresh_token': self.token, + 'redirect_uri': redirect_uri + } + + auth = base64.b64encode( + f'{self.oauth_client_id}:{self.oauth_client_secret}' + .encode('ascii') + ).decode('ascii') + headers = { + 'Authorization': f'Basic {auth}', + 'Content-type': 'application/x-www-form-urlencoded;charset=utf-8' + } + result = requests.post(token_url, headers=headers, data=data) + result_json = result.json() + if 'access_token' not in result_json: + raise DatabaseException(f'Did not get a token: {result_json}') + return result_json['access_token'] + def _get_private_key(self): """Get Snowflake private key by path or None.""" if not self.private_key_path or self.private_key_passphrase is None: @@ -84,7 +160,7 @@ def exception_handler(self, sql): logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: self.release() - raise dbt.exceptions.FailedToConnectException( + raise FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' 'does not have access to use the specified database. ' @@ -92,17 +168,17 @@ def exception_handler(self, sql): .format(msg)) else: self.release() - raise dbt.exceptions.DatabaseException(msg) + raise DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: {}", sql) logger.debug("Rolling back transaction.") self.release() - if isinstance(e, dbt.exceptions.RuntimeException): + if isinstance(e, RuntimeException): # during a sql query, an internal to dbt exception was raised. # this sounds a lot like a signal handler and probably has # useful information, so raise it without modification. raise - raise dbt.exceptions.RuntimeException(str(e)) from e + raise RuntimeException(str(e)) from e @classmethod def open(cls, connection): @@ -136,7 +212,7 @@ def open(cls, connection): connection.handle = None connection.state = 'fail' - raise dbt.exceptions.FailedToConnectException(str(e)) + raise FailedToConnectException(str(e)) def cancel(self, connection): handle = connection.handle @@ -228,7 +304,7 @@ def add_query(self, sql, auto_begin=True, else: conn_name = conn.name - raise dbt.exceptions.RuntimeException( + raise RuntimeException( "Tried to run an empty query on model '{}'. If you are " "conditionally running\nsql, eg. in a model hook, make " "sure your `else` clause contains valid sql!\n\n" diff --git a/scripts/werkzeug-refresh-token.py b/scripts/werkzeug-refresh-token.py new file mode 100644 index 00000000000..45a9e5104b7 --- /dev/null +++ b/scripts/werkzeug-refresh-token.py @@ -0,0 +1,139 @@ +import argparse +import json +import secrets +import textwrap +from base64 import b64encode + +import requests +from werkzeug import redirect +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from werkzeug.wrappers import Request, Response +from werkzeug.serving import run_simple +from urllib.parse import urlencode + + +def _make_rfp_claim_value(): + # from https://tools.ietf.org/id/draft-bradley-oauth-jwt-encoded-state-08.html#rfc.section.4 # noqa + # we can do whatever we want really, so just token.urlsafe? + return secrets.token_urlsafe(112) + + +def _make_response(client_id, client_secret, refresh_token): + return Response(textwrap.dedent( + f'''\ + SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN="{refresh_token}" + SNOWFLAKE_TEST_OAUTH_CLIENT_ID="{client_id}" + SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET="{client_secret}"''' + )) + + +class TokenManager: + def __init__(self, account_name, client_id, client_secret): + self.account_name = account_name + self.client_id = client_id + self.client_secret = client_secret + self.token = None + self.rfp_claim = _make_rfp_claim_value() + self.port = 8080 + + @property + def account_url(self): + return f'https://{self.account_name}.snowflakecomputing.com' + + @property + def auth_url(self): + return f'{self.account_url}/oauth/authorize' + + @property + def token_url(self): + return f'{self.account_url}/oauth/token-request' + + @property + def redirect_uri(self): + return f'http://localhost:{self.port}' + + @property + def headers(self): + auth = f'{self.client_id}:{self.client_secret}'.encode('ascii') + encoded_auth = b64encode(auth).decode('ascii') + return { + 'Authorization': f'Basic {encoded_auth}', + 'Content-type': 'application/x-www-form-urlencoded; charset=utf-8' + } + + def _code_to_token(self, code): + data = { + 'grant_type': 'authorization_code', + 'code': code, + 'redirect_uri': self.redirect_uri, + } + # data = urlencode(data) + resp = requests.post( + url=self.token_url, + headers=self.headers, + data=data, + ) + try: + refresh_token = resp.json()['refresh_token'] + except KeyError: + print(resp.json()) + raise + return refresh_token + + @Request.application + def auth(self, request): + code = request.args.get('code') + if code: + # we got 303'ed here with a code + state_received = request.args.get('state') + if state_received != self.rfp_claim: + return Response('Invalid RFP claim: MITM?', status=401) + refresh_token = self._code_to_token(code) + return _make_response( + self.client_id, + self.client_secret, + refresh_token, + ) + else: + return redirect('/login') + + @Request.application + def login(self, request): + # take the auth URL and add the query string to it + query = { + 'response_type': 'code', + 'client_id': self.client_id, + 'redirect_uri': self.redirect_uri, + 'state': self.rfp_claim, + } + query = urlencode(query) + return redirect(f'{self.auth_url}?{query}') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('account_name', help='The account name') + parser.add_argument('json_blob', help='The json auth blob') + + return parser.parse_args() + + +def main(): + args = parse_args() + data = json.loads(args.json_blob) + client_id = data['OAUTH_CLIENT_ID'] + client_secret = data['OAUTH_CLIENT_SECRET'] + token_manager = TokenManager( + account_name=args.account_name, + client_id=client_id, + client_secret=client_secret, + ) + app = DispatcherMiddleware(token_manager.auth, { + '/login': token_manager.login, + }) + + run_simple('localhost', token_manager.port, app) + + +if __name__ == '__main__': + main() diff --git a/test.env.sample b/test.env.sample index b0f0e95ce45..9f87a3d9461 100644 --- a/test.env.sample +++ b/test.env.sample @@ -7,6 +7,9 @@ SNOWFLAKE_TEST_ALT_DATABASE= SNOWFLAKE_TEST_QUOTED_DATABASE= SNOWFLAKE_TEST_WAREHOUSE= SNOWFLAKE_TEST_ALT_WAREHOUSE= +SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN= +SNOWFLAKE_TEST_OAUTH_CLIENT_ID= +SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET= BIGQUERY_TYPE= BIGQUERY_PROJECT_ID= diff --git a/test/integration/057_oauth_tests/models/model_1.sql b/test/integration/057_oauth_tests/models/model_1.sql new file mode 100644 index 00000000000..43258a71464 --- /dev/null +++ b/test/integration/057_oauth_tests/models/model_1.sql @@ -0,0 +1 @@ +select 1 as id diff --git a/test/integration/057_oauth_tests/models/model_2.sql b/test/integration/057_oauth_tests/models/model_2.sql new file mode 100644 index 00000000000..33560d6c082 --- /dev/null +++ b/test/integration/057_oauth_tests/models/model_2.sql @@ -0,0 +1 @@ +select 2 as id diff --git a/test/integration/057_oauth_tests/models/model_3.sql b/test/integration/057_oauth_tests/models/model_3.sql new file mode 100644 index 00000000000..c724eec6fa9 --- /dev/null +++ b/test/integration/057_oauth_tests/models/model_3.sql @@ -0,0 +1,3 @@ +select * from {{ ref('model_1') }} +union all +select * from {{ ref('model_2') }} diff --git a/test/integration/057_oauth_tests/models/model_4.sql b/test/integration/057_oauth_tests/models/model_4.sql new file mode 100644 index 00000000000..2e8896da085 --- /dev/null +++ b/test/integration/057_oauth_tests/models/model_4.sql @@ -0,0 +1,3 @@ +select 1 as id +union all +select 2 as id diff --git a/test/integration/057_oauth_tests/test_oauth.py b/test/integration/057_oauth_tests/test_oauth.py new file mode 100644 index 00000000000..a602f5e2522 --- /dev/null +++ b/test/integration/057_oauth_tests/test_oauth.py @@ -0,0 +1,65 @@ +""" +The first time using an account for testing, you should run this: + +``` +CREATE OR REPLACE SECURITY INTEGRATION DBT_INTEGRATION_TEST_OAUTH + TYPE = OAUTH + ENABLED = TRUE + OAUTH_CLIENT = CUSTOM + OAUTH_CLIENT_TYPE = 'CONFIDENTIAL' + OAUTH_REDIRECT_URI = 'http://localhost:8080' + oauth_issue_refresh_tokens = true + OAUTH_ALLOW_NON_TLS_REDIRECT_URI = true + BLOCKED_ROLES_LIST = + oauth_refresh_token_validity = 7776000; +``` + + +Every month (or any amount <90 days): + +Run `select SYSTEM$SHOW_OAUTH_CLIENT_SECRETS('DBT_INTEGRATION_TEST_OAUTH');` + +The only row/column of output should be a json blob, it goes (within single +quotes!) as the second argument to the server script: + +python scripts/werkzeug-refresh-token.py ${acount_name} '${json_blob}' + +Open http://localhost:8080 + +Log in as the test user, get a resonse page with some environment variables. +Update CI providers and test.env with the new values (If you kept the security +integration the same, just the refresh token changed) +""" +from test.integration.base import DBTIntegrationTest, use_profile + + +class TestSnowflakeOauth(DBTIntegrationTest): + @property + def schema(self): + return "simple_copy_001" + + @staticmethod + def dir(path): + return path.lstrip('/') + + @property + def models(self): + return self.dir("models") + + def snowflake_profile(self): + profile = super().snowflake_profile() + profile['test']['target'] = 'oauth' + missing = ', '.join( + k for k in ('oauth_client_id', 'oauth_client_secret', 'token') + if k not in profile['test']['outputs']['oauth'] + ) + if missing: + raise ValueError(f'Cannot run test - {missing} not configured') + del profile['test']['outputs']['default2'] + del profile['test']['outputs']['noaccess'] + return profile + + @use_profile('snowflake') + def test_snowflake_basic(self): + self.run_dbt() + self.assertManyRelationsEqual([['MODEL_3'], ['MODEL_4']]) diff --git a/test/integration/base.py b/test/integration/base.py index 6b2cf545f10..91827df1066 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -210,7 +210,20 @@ def snowflake_profile(self): 'database': os.getenv('SNOWFLAKE_TEST_DATABASE'), 'schema': self.unique_schema(), 'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'), - } + }, + 'oauth': { + 'type': 'snowflake', + 'threads': 4, + 'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'), + 'user': os.getenv('SNOWFLAKE_TEST_USER'), + 'oauth_client_id': os.getenv('SNOWFLAKE_TEST_OAUTH_CLIENT_ID'), + 'oauth_client_secret': os.getenv('SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET'), + 'token': os.getenv('SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN'), + 'database': os.getenv('SNOWFLAKE_TEST_DATABASE'), + 'schema': self.unique_schema(), + 'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'), + 'authenticator': 'oauth', + }, }, 'target': 'default2' }