Skip to content

Commit

Permalink
Support refresh tokens by requesting entirely TMI
Browse files Browse the repository at this point in the history
Add a little webserver that spits out a refresh token
Add tests + notes on test support
  • Loading branch information
Jacob Beck committed Feb 17, 2020
1 parent 85e6d7f commit 9d1df6c
Show file tree
Hide file tree
Showing 10 changed files with 316 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### Features
- Add a "docs" field to models, with a "show" subfield ([#1671](https://github.com/fishtown-analytics/dbt/issues/1671), [#2107](https://github.com/fishtown-analytics/dbt/pull/2107))
- Use refresh tokens in Snowflake instead of access tokens ([#2126](https://github.com/fishtown-analytics/dbt/issues/2126), [#2141](https://github.com/fishtown-analytics/dbt/pull/2141))

### Fixes
- Fix issue where dbt did not give an error in the presence of duplicate doc names ([#2054](https://github.com/fishtown-analytics/dbt/issues/2054), [#2080](https://github.com/fishtown-analytics/dbt/pull/2080))
Expand Down
96 changes: 86 additions & 10 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import datetime
import pytz
import re
Expand All @@ -8,15 +9,22 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
import requests
import snowflake.connector
import snowflake.connector.errors

import dbt.exceptions
from dbt.exceptions import (
InternalException, RuntimeException, FailedToConnectException,
DatabaseException, warn_or_error
)
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger


# URL template for the Snowflake oauth token-request endpoint; the placeholder
# is filled in with the account name when an access token is requested.
_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request'


@dataclass
class SnowflakeCredentials(Credentials):
account: str
Expand All @@ -28,15 +36,30 @@ class SnowflakeCredentials(Credentials):
private_key_path: Optional[str]
private_key_passphrase: Optional[str]
token: Optional[str]
oauth_client_id: Optional[str]
oauth_client_secret: Optional[str]
client_session_keep_alive: bool = False

def __post_init__(self):
    """Warn when oauth-only settings are given without oauth enabled.

    Nothing happens when ``authenticator`` is 'oauth', or when none of
    ``oauth_client_secret``, ``oauth_client_id`` and ``token`` is set.
    """
    if self.authenticator == 'oauth':
        return

    oauth_only_values = (
        self.oauth_client_secret,
        self.oauth_client_id,
        self.token,
    )
    if not any(oauth_only_values):
        return

    # An easy profile mistake: the oauth fields are filled in but
    # 'authenticator' was left unset.
    warn_or_error(
        'Authenticator is not set to oauth, but an oauth-only '
        'parameter is set! Did you mean to set authenticator: oauth?'
    )

@property
def type(self):
    """Return the literal type identifier, 'snowflake'."""
    type_name = 'snowflake'
    return type_name

def _connection_keys(self):
return ('account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive')
return (
'account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive'
)

def auth_args(self):
# Pull all of the optional authentication args for the connector,
Expand All @@ -47,10 +70,63 @@ def auth_args(self):
if self.authenticator:
result['authenticator'] = self.authenticator
if self.authenticator == 'oauth':
result['token'] = self.token
token = self.token
# if we have a client ID/client secret, the token is a refresh
# token, not an access token
if self.oauth_client_id and self.oauth_client_secret:
token = self._get_access_token()
elif self.oauth_client_id:
warn_or_error(
'Invalid profile: got an oauth_client_id, but not an '
'oauth_client_secret!'
)
elif self.oauth_client_secret:
warn_or_error(
'Invalid profile: got an oauth_client_secret, but not '
'an oauth_client_id!'
)

result['token'] = token
result['private_key'] = self._get_private_key()
return result

def _get_access_token(self) -> str:
    """Exchange the configured refresh token for an oauth access token.

    Posts a ``refresh_token`` grant to the account's token-request
    endpoint, authenticating with the client ID and client secret via HTTP
    basic auth, and returns the ``access_token`` field of the JSON reply.

    Raises:
        InternalException: if the authenticator is not 'oauth', or if the
            client ID, client secret, or refresh token is missing/empty.
        DatabaseException: if the reply is not a JSON object or does not
            contain an access token.
    """
    if self.authenticator != 'oauth':
        raise InternalException('Can only get access tokens for oauth')

    # Empty strings are rejected along with None: they can never
    # authenticate, so there is no point sending the request.
    required = (self.oauth_client_id, self.oauth_client_secret, self.token)
    if not all(required):
        raise InternalException(
            'need a client ID a client secret, and a refresh token to get '
            'an access token'
        )

    # should the full url be a config item?
    token_url = _TOKEN_REQUEST_URL.format(self.account)
    # I think this is only used to redirect on success, which we ignore
    # (it does not have to match the integration's settings in snowflake)
    redirect_uri = 'http://localhost:9999'
    data = {
        'grant_type': 'refresh_token',
        'refresh_token': self.token,
        'redirect_uri': redirect_uri
    }

    # HTTP basic auth: base64("<client id>:<client secret>")
    auth = base64.b64encode(
        f'{self.oauth_client_id}:{self.oauth_client_secret}'
        .encode('ascii')
    ).decode('ascii')
    headers = {
        'Authorization': f'Basic {auth}',
        'Content-type': 'application/x-www-form-urlencoded;charset=utf-8'
    }

    # Without a timeout, requests waits forever on an unresponsive server.
    timeout_seconds = 30
    result = requests.post(
        token_url, headers=headers, data=data, timeout=timeout_seconds
    )

    try:
        result_json = result.json()
    except ValueError as exc:
        # e.g. an HTML error page from a proxy or a wrong account name
        raise DatabaseException(
            f'Did not get a token: HTTP {result.status_code} response '
            f'was not valid JSON'
        ) from exc

    if not isinstance(result_json, dict) or 'access_token' not in result_json:
        raise DatabaseException(f'Did not get a token: {result_json}')
    return result_json['access_token']

def _get_private_key(self):
"""Get Snowflake private key by path or None."""
if not self.private_key_path or self.private_key_passphrase is None:
Expand Down Expand Up @@ -84,25 +160,25 @@ def exception_handler(self, sql):
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in msg:
self.release()
raise dbt.exceptions.FailedToConnectException(
raise FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(msg))
else:
self.release()
raise dbt.exceptions.DatabaseException(msg)
raise DatabaseException(msg)
except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.release()
if isinstance(e, dbt.exceptions.RuntimeException):
if isinstance(e, RuntimeException):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise
raise dbt.exceptions.RuntimeException(str(e)) from e
raise RuntimeException(str(e)) from e

@classmethod
def open(cls, connection):
Expand Down Expand Up @@ -136,7 +212,7 @@ def open(cls, connection):
connection.handle = None
connection.state = 'fail'

raise dbt.exceptions.FailedToConnectException(str(e))
raise FailedToConnectException(str(e))

def cancel(self, connection):
handle = connection.handle
Expand Down Expand Up @@ -228,7 +304,7 @@ def add_query(self, sql, auto_begin=True,
else:
conn_name = conn.name

raise dbt.exceptions.RuntimeException(
raise RuntimeException(
"Tried to run an empty query on model '{}'. If you are "
"conditionally running\nsql, eg. in a model hook, make "
"sure your `else` clause contains valid sql!\n\n"
Expand Down
139 changes: 139 additions & 0 deletions scripts/werkzeug-refresh-token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import argparse
import json
import secrets
import textwrap
from base64 import b64encode

import requests
from werkzeug import redirect
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.wrappers import Request, Response
from werkzeug.serving import run_simple
from urllib.parse import urlencode


def _make_rfp_claim_value():
# from https://tools.ietf.org/id/draft-bradley-oauth-jwt-encoded-state-08.html#rfc.section.4 # noqa
# we can do whatever we want really, so just token.urlsafe?
return secrets.token_urlsafe(112)


def _make_response(client_id, client_secret, refresh_token):
    """Build the response listing the three SNOWFLAKE_TEST_OAUTH_* values.

    The body has one ``NAME="value"`` line each for the refresh token, the
    client ID and the client secret, in that order.
    """
    env_listing = textwrap.dedent(
        f'''\
        SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN="{refresh_token}"
        SNOWFLAKE_TEST_OAUTH_CLIENT_ID="{client_id}"
        SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET="{client_secret}"'''
    )
    return Response(env_listing)


class TokenManager:
    """Drive the oauth authorization-code flow for one Snowflake account.

    Holds the client credentials, derives the authorize/token URLs from the
    account name, and exposes two request handlers (``auth`` and ``login``)
    wrapped with ``Request.application``.
    """

    # Seconds to wait on the token endpoint; without a timeout ``requests``
    # would block forever on an unresponsive server.
    request_timeout = 30

    def __init__(self, account_name, client_id, client_secret, port=8080):
        """Store the credentials and generate a fresh state value.

        ``port`` is the local port used in the redirect URI; it defaults to
        8080.
        """
        self.account_name = account_name
        self.client_id = client_id
        self.client_secret = client_secret
        self.token = None
        # random value sent as 'state' and checked when the redirect returns
        self.rfp_claim = _make_rfp_claim_value()
        self.port = port

    @property
    def account_url(self):
        """Base URL of the account."""
        return f'https://{self.account_name}.snowflakecomputing.com'

    @property
    def auth_url(self):
        """URL the user is redirected to in order to authorize."""
        return f'{self.account_url}/oauth/authorize'

    @property
    def token_url(self):
        """URL the authorization code is posted to."""
        return f'{self.account_url}/oauth/token-request'

    @property
    def redirect_uri(self):
        """Local URI sent as ``redirect_uri`` in both requests."""
        return f'http://localhost:{self.port}'

    @property
    def headers(self):
        """Request headers carrying HTTP basic auth for the client."""
        auth = f'{self.client_id}:{self.client_secret}'.encode('ascii')
        encoded_auth = b64encode(auth).decode('ascii')
        return {
            'Authorization': f'Basic {encoded_auth}',
            'Content-type': 'application/x-www-form-urlencoded; charset=utf-8'
        }

    def _code_to_token(self, code):
        """Exchange an authorization code for a refresh token.

        Prints the response payload and re-raises ``KeyError`` when the
        payload has no ``refresh_token`` entry.
        """
        data = {
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': self.redirect_uri,
        }
        resp = requests.post(
            url=self.token_url,
            headers=self.headers,
            data=data,
            timeout=self.request_timeout,
        )
        # parse once; the payload is also what gets printed on failure
        payload = resp.json()
        try:
            refresh_token = payload['refresh_token']
        except KeyError:
            print(payload)
            raise
        return refresh_token

    @Request.application
    def auth(self, request):
        """Handle a request that may carry an authorization ``code``.

        With a ``code`` query argument: check that ``state`` matches the
        value generated at construction (401 otherwise), exchange the code
        for a refresh token and return it with the client credentials.
        Without a ``code``: redirect to ``/login``.
        """
        code = request.args.get('code')
        if code:
            # we got 303'ed here with a code
            state_received = request.args.get('state')
            if state_received != self.rfp_claim:
                return Response('Invalid RFP claim: MITM?', status=401)
            refresh_token = self._code_to_token(code)
            return _make_response(
                self.client_id,
                self.client_secret,
                refresh_token,
            )
        else:
            return redirect('/login')

    @Request.application
    def login(self, request):
        """Redirect to the authorize URL with the oauth query string."""
        # take the auth URL and add the query string to it
        query = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': self.redirect_uri,
            'state': self.rfp_claim,
        }
        query = urlencode(query)
        return redirect(f'{self.auth_url}?{query}')


def parse_args():
    """Parse the command line and return the argparse namespace.

    Two positional arguments are expected: ``account_name`` and
    ``json_blob``.
    """
    positional_args = (
        ('account_name', 'The account name'),
        ('json_blob', 'The json auth blob'),
    )
    cli = argparse.ArgumentParser()
    for arg_name, help_text in positional_args:
        cli.add_argument(arg_name, help=help_text)
    return cli.parse_args()


def main():
    """Read the client credentials from the command line and serve the app.

    The JSON blob must contain ``OAUTH_CLIENT_ID`` and
    ``OAUTH_CLIENT_SECRET``. The server listens on localhost at the token
    manager's port.
    """
    cli_args = parse_args()
    oauth_settings = json.loads(cli_args.json_blob)
    manager = TokenManager(
        account_name=cli_args.account_name,
        client_id=oauth_settings['OAUTH_CLIENT_ID'],
        client_secret=oauth_settings['OAUTH_CLIENT_SECRET'],
    )

    # '/login' is mounted on the login handler; the auth handler is the
    # main application passed to the dispatcher.
    mounts = {'/login': manager.login}
    wsgi_app = DispatcherMiddleware(manager.auth, mounts)

    run_simple('localhost', manager.port, wsgi_app)


# Only run main() when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
3 changes: 3 additions & 0 deletions test.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ SNOWFLAKE_TEST_ALT_DATABASE=
SNOWFLAKE_TEST_QUOTED_DATABASE=
SNOWFLAKE_TEST_WAREHOUSE=
SNOWFLAKE_TEST_ALT_WAREHOUSE=
SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN=
SNOWFLAKE_TEST_OAUTH_CLIENT_ID=
SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET=

BIGQUERY_TYPE=
BIGQUERY_PROJECT_ID=
Expand Down
1 change: 1 addition & 0 deletions test/integration/057_oauth_tests/models/model_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1 as id
1 change: 1 addition & 0 deletions test/integration/057_oauth_tests/models/model_2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 2 as id
3 changes: 3 additions & 0 deletions test/integration/057_oauth_tests/models/model_3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Stack the rows of model_1 and model_2 into a single result set.
select * from {{ ref('model_1') }}
union all
select * from {{ ref('model_2') }}
3 changes: 3 additions & 0 deletions test/integration/057_oauth_tests/models/model_4.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Two literal rows: id = 1 and id = 2.
select 1 as id
union all
select 2 as id
Loading

0 comments on commit 9d1df6c

Please sign in to comment.