Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support refresh tokens (#2126) #2141

Merged
merged 2 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## dbt 0.15.3 (Release TBD)

This is a bugfix release.

### Fixes
- Use refresh tokens in snowflake instead of access tokens ([#2126](https://github.com/fishtown-analytics/dbt/issues/2126), [#2141](https://github.com/fishtown-analytics/dbt/pull/2141))

## dbt 0.15.2 (February 2, 2020)

This is a bugfix release.
Expand Down
3 changes: 3 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ jobs:
SNOWFLAKE_TEST_PASSWORD: $(SNOWFLAKE_TEST_PASSWORD)
SNOWFLAKE_TEST_USER: $(SNOWFLAKE_TEST_USER)
SNOWFLAKE_TEST_WAREHOUSE: $(SNOWFLAKE_TEST_WAREHOUSE)
SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN: $(SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN)
SNOWFLAKE_TEST_OAUTH_CLIENT_ID: $(SNOWFLAKE_TEST_OAUTH_CLIENT_ID)
SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET: $(SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET)
displayName: Run integration tests

- job: BigQueryIntegrationTest
Expand Down
96 changes: 86 additions & 10 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import datetime
import pytz
import re
Expand All @@ -8,15 +9,22 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
import requests
import snowflake.connector
import snowflake.connector.errors

import dbt.exceptions
from dbt.exceptions import (
InternalException, RuntimeException, FailedToConnectException,
DatabaseException, warn_or_error
)
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger


_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request'


@dataclass
class SnowflakeCredentials(Credentials):
account: str
Expand All @@ -28,15 +36,30 @@ class SnowflakeCredentials(Credentials):
private_key_path: Optional[str]
private_key_passphrase: Optional[str]
token: Optional[str]
oauth_client_id: Optional[str]
oauth_client_secret: Optional[str]
client_session_keep_alive: bool = False

def __post_init__(self):
if (
self.authenticator != 'oauth' and
(self.oauth_client_secret or self.oauth_client_id or self.token)
):
# the user probably forgot to set 'authenticator' like I keep doing
warn_or_error(
'Authenticator is not set to oauth, but an oauth-only '
'parameter is set! Did you mean to set authenticator: oauth?'
)

@property
def type(self):
return 'snowflake'

def _connection_keys(self):
return ('account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive')
return (
'account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive'
)

def auth_args(self):
# Pull all of the optional authentication args for the connector,
Expand All @@ -47,10 +70,63 @@ def auth_args(self):
if self.authenticator:
result['authenticator'] = self.authenticator
if self.authenticator == 'oauth':
result['token'] = self.token
token = self.token
# if we have a client ID/client secret, the token is a refresh
# token, not an access token
if self.oauth_client_id and self.oauth_client_secret:
token = self._get_access_token()
elif self.oauth_client_id:
warn_or_error(
'Invalid profile: got an oauth_client_id, but not an '
'oauth_client_secret!'
)
elif self.oauth_client_secret:
warn_or_error(
'Invalid profile: got an oauth_client_secret, but not '
'an oauth_client_id!'
)

result['token'] = token
result['private_key'] = self._get_private_key()
return result

def _get_access_token(self) -> str:
if self.authenticator != 'oauth':
raise InternalException('Can only get access tokens for oauth')
missing = any(
x is None for x in
(self.oauth_client_id, self.oauth_client_secret, self.token)
)
if missing:
raise InternalException(
'need a client ID a client secret, and a refresh token to get '
'an access token'
)
# should the full url be a config item?
token_url = _TOKEN_REQUEST_URL.format(self.account)
# I think this is only used to redirect on success, which we ignore
# (it does not have to match the integration's settings in snowflake)
redirect_uri = 'http://localhost:9999'
data = {
'grant_type': 'refresh_token',
'refresh_token': self.token,
'redirect_uri': redirect_uri
}

auth = base64.b64encode(
f'{self.oauth_client_id}:{self.oauth_client_secret}'
.encode('ascii')
).decode('ascii')
headers = {
'Authorization': f'Basic {auth}',
'Content-type': 'application/x-www-form-urlencoded;charset=utf-8'
}
result = requests.post(token_url, headers=headers, data=data)
result_json = result.json()
if 'access_token' not in result_json:
raise DatabaseException(f'Did not get a token: {result_json}')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you know what kind of contents would be returned in result_json here? I'm just worried that this could print sensitive-ish info out to the console, like a client id or client secret. If there's a chance that could happen, we should probably stringify the json and mask out creds if we can!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually got error messages about invalid client IDs. I don't think it returns sensitive information, it was actually very annoying to debug because of that (which is good/normal security practice for an endpoint like this):

Encountered an error:
Database Error
  Did not get a token: {'data': None, 'message': 'This is an invalid client.', 'code': None, 'success': False, 'error': 'invalid_client'}

return result_json['access_token']

def _get_private_key(self):
"""Get Snowflake private key by path or None."""
if not self.private_key_path or self.private_key_passphrase is None:
Expand Down Expand Up @@ -84,25 +160,25 @@ def exception_handler(self, sql):
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in msg:
self.release()
raise dbt.exceptions.FailedToConnectException(
raise FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(msg))
else:
self.release()
raise dbt.exceptions.DatabaseException(msg)
raise DatabaseException(msg)
except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.release()
if isinstance(e, dbt.exceptions.RuntimeException):
if isinstance(e, RuntimeException):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise
raise dbt.exceptions.RuntimeException(str(e)) from e
raise RuntimeException(str(e)) from e

@classmethod
def open(cls, connection):
Expand Down Expand Up @@ -136,7 +212,7 @@ def open(cls, connection):
connection.handle = None
connection.state = 'fail'

raise dbt.exceptions.FailedToConnectException(str(e))
raise FailedToConnectException(str(e))

def cancel(self, connection):
handle = connection.handle
Expand Down Expand Up @@ -228,7 +304,7 @@ def add_query(self, sql, auto_begin=True,
else:
conn_name = conn.name

raise dbt.exceptions.RuntimeException(
raise RuntimeException(
"Tried to run an empty query on model '{}'. If you are "
"conditionally running\nsql, eg. in a model hook, make "
"sure your `else` clause contains valid sql!\n\n"
Expand Down
139 changes: 139 additions & 0 deletions scripts/werkzeug-refresh-token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import argparse
import json
import secrets
import textwrap
from base64 import b64encode

import requests
from werkzeug import redirect
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.wrappers import Request, Response
from werkzeug.serving import run_simple
from urllib.parse import urlencode


def _make_rfp_claim_value():
# from https://tools.ietf.org/id/draft-bradley-oauth-jwt-encoded-state-08.html#rfc.section.4 # noqa
# we can do whatever we want really, so just token.urlsafe?
return secrets.token_urlsafe(112)


def _make_response(client_id, client_secret, refresh_token):
return Response(textwrap.dedent(
f'''\
SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN="{refresh_token}"
SNOWFLAKE_TEST_OAUTH_CLIENT_ID="{client_id}"
SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET="{client_secret}"'''
))


class TokenManager:
def __init__(self, account_name, client_id, client_secret):
self.account_name = account_name
self.client_id = client_id
self.client_secret = client_secret
self.token = None
self.rfp_claim = _make_rfp_claim_value()
self.port = 8080

@property
def account_url(self):
return f'https://{self.account_name}.snowflakecomputing.com'

@property
def auth_url(self):
return f'{self.account_url}/oauth/authorize'

@property
def token_url(self):
return f'{self.account_url}/oauth/token-request'

@property
def redirect_uri(self):
return f'http://localhost:{self.port}'

@property
def headers(self):
auth = f'{self.client_id}:{self.client_secret}'.encode('ascii')
encoded_auth = b64encode(auth).decode('ascii')
return {
'Authorization': f'Basic {encoded_auth}',
'Content-type': 'application/x-www-form-urlencoded; charset=utf-8'
}

def _code_to_token(self, code):
data = {
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': self.redirect_uri,
}
# data = urlencode(data)
resp = requests.post(
url=self.token_url,
headers=self.headers,
data=data,
)
try:
refresh_token = resp.json()['refresh_token']
except KeyError:
print(resp.json())
raise
return refresh_token

@Request.application
def auth(self, request):
code = request.args.get('code')
if code:
# we got 303'ed here with a code
state_received = request.args.get('state')
if state_received != self.rfp_claim:
return Response('Invalid RFP claim: MITM?', status=401)
refresh_token = self._code_to_token(code)
return _make_response(
self.client_id,
self.client_secret,
refresh_token,
)
else:
return redirect('/login')

@Request.application
def login(self, request):
# take the auth URL and add the query string to it
query = {
'response_type': 'code',
'client_id': self.client_id,
'redirect_uri': self.redirect_uri,
'state': self.rfp_claim,
}
query = urlencode(query)
return redirect(f'{self.auth_url}?{query}')


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('account_name', help='The account name')
parser.add_argument('json_blob', help='The json auth blob')

return parser.parse_args()


def main():
args = parse_args()
data = json.loads(args.json_blob)
client_id = data['OAUTH_CLIENT_ID']
client_secret = data['OAUTH_CLIENT_SECRET']
token_manager = TokenManager(
account_name=args.account_name,
client_id=client_id,
client_secret=client_secret,
)
app = DispatcherMiddleware(token_manager.auth, {
'/login': token_manager.login,
})

run_simple('localhost', token_manager.port, app)


if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions test.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ SNOWFLAKE_TEST_ALT_DATABASE=
SNOWFLAKE_TEST_QUOTED_DATABASE=
SNOWFLAKE_TEST_WAREHOUSE=
SNOWFLAKE_TEST_ALT_WAREHOUSE=
SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN=
SNOWFLAKE_TEST_OAUTH_CLIENT_ID=
SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET=

BIGQUERY_TYPE=
BIGQUERY_PROJECT_ID=
Expand Down
1 change: 1 addition & 0 deletions test/integration/057_oauth_tests/models/model_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1 as id
1 change: 1 addition & 0 deletions test/integration/057_oauth_tests/models/model_2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 2 as id
3 changes: 3 additions & 0 deletions test/integration/057_oauth_tests/models/model_3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
select * from {{ ref('model_1') }}
union all
select * from {{ ref('model_2') }}
3 changes: 3 additions & 0 deletions test/integration/057_oauth_tests/models/model_4.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
select 1 as id
union all
select 2 as id
Loading