Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IAM authentication method for redshift adapter #769

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from dbt.adapters.postgres import PostgresAdapter
from dbt.logger import GLOBAL_LOGGER as logger # noqa

import dbt.exceptions
import boto3
import psycopg2

drop_lock = multiprocessing.Lock()

Expand All @@ -17,6 +19,81 @@ def type(cls):
def date_function(cls):
return 'getdate()'

@classmethod
def get_tmp_cluster_credentials(cls, config):
creds = config.copy()

cluster_id = creds.get('cluster_id')
if not cluster_id:
error = '`cluster_id` must be set in profile if IAM authentication method selected'
raise dbt.exceptions.FailedToConnectException(error)

client = boto3.client('redshift')

# replace username and password with temporary redshift credentials
try:
cluster_creds = client.get_cluster_credentials(DbUser=creds.get('user'),
DbName=creds.get('dbname'),
ClusterIdentifier=creds.get('cluster_id'),
AutoCreate=False)
creds['user'] = cluster_creds.get('DbUser')
creds['pass'] = cluster_creds.get('DbPassword')

return creds

except client.exceptions.ClientError as e:
error = ('Unable to get temporary Redshift cluster credentials: "{}"'.format(str(e)))
raise dbt.exceptions.FailedToConnectException(error)

@classmethod
def get_redshift_credentials(cls, config):
creds = config.copy()

method = creds.get('method')

if method == 'database' or method is None: # Support missing method for backwards compatibility
return creds

elif method == 'iam':
return cls.get_tmp_cluster_credentials(creds)

else:
error = ('Invalid `method` in profile: "{}"'.format(method))
raise dbt.exceptions.FailedToConnectException(error)

@classmethod
def open_connection(cls, connection):
if connection.get('state') == 'open':
logger.debug('Connection is already open, skipping open.')
return connection

result = connection.copy()

try:
credentials = cls.get_redshift_credentials(connection.get('credentials', {}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the suggestion above about get_redshift_credentials should clean up the if/else logic here I believe!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


handle = psycopg2.connect(
dbname=credentials.get('dbname'),
user=credentials.get('user'),
host=credentials.get('host'),
password=credentials.get('pass'),
port=credentials.get('port'),
connect_timeout=10)

result['handle'] = handle
result['state'] = 'open'
except psycopg2.Error as e:
logger.debug("Got an error when attempting to open a postgres "
"connection: '{}'"
.format(e))

result['handle'] = None
result['state'] = 'fail'

raise dbt.exceptions.FailedToConnectException(str(e))

return result

@classmethod
def _get_columns_in_table_sql(cls, schema_name, table_name, database):
# Redshift doesn't support cross-database queries,
Expand All @@ -27,7 +104,7 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database):
table_schema_filter = '1=1'
else:
table_schema_filter = "table_schema = '{schema_name}'".format(
schema_name=schema_name)
schema_name=schema_name)

sql = """
with bound_views as (
Expand Down
15 changes: 13 additions & 2 deletions dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dbt.contracts.common import validate_with
from dbt.logger import GLOBAL_LOGGER as logger # noqa


adapter_types = ['postgres', 'redshift', 'snowflake', 'bigquery']
connection_contract = Schema({
Required('type'): Any(*adapter_types),
Expand All @@ -24,6 +23,18 @@
Required('schema'): basestring,
})

redshift_auth_methods = ['database', 'iam']
redshift_credentials_contract = Schema({
Optional('method'): Any(*redshift_auth_methods),
Required('dbname'): basestring,
Required('host'): basestring,
Required('user'): basestring,
Optional('pass'): basestring, # TODO: require if 'database' method selected
Required('port'): Any(All(int, Range(min=0, max=65535)), basestring),
Required('schema'): basestring,
Optional('cluster_id'): basestring, # TODO: require if 'iam' method selected
})

snowflake_credentials_contract = Schema({
Required('account'): basestring,
Required('user'): basestring,
Expand All @@ -46,7 +57,7 @@

credentials_mapping = {
'postgres': postgres_credentials_contract,
'redshift': postgres_credentials_contract,
'redshift': redshift_credentials_contract,
'snowflake': snowflake_credentials_contract,
'bigquery': bigquery_credentials_contract,
}
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ google-cloud-bigquery==0.29.0
requests>=2.18.0
agate>=1.6,<2
jsonschema==2.6.0
boto3>=1.6.23