Skip to content

Commit

Permalink
Add redshift copy support for role-based credential string (#1962)
Browse files Browse the repository at this point in the history
* Add support for role-based credentials and refactor key-based creds

* Add test to ensure error is raised when no redshift creds are provided
  • Loading branch information
dlstadther committed Jan 26, 2017
1 parent d7d6a37 commit b86166c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 20 deletions.
56 changes: 43 additions & 13 deletions luigi/contrib/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,34 @@ def s3_load_path(self):
"""
return None

@property
def aws_access_key_id(self):
    """
    Override to return the AWS access key id (key-based credentials).

    Defaults to None so role-based credentials can be used instead;
    no longer an abstractproperty since either credential pair may
    be supplied.
    """
    return None

@property
def aws_secret_access_key(self):
    """
    Override to return the AWS secret access key (key-based credentials).

    Defaults to None so role-based credentials can be used instead;
    no longer an abstractproperty since either credential pair may
    be supplied.
    """
    return None

@property
def aws_account_id(self):
    """Override to return the AWS account id (role-based credentials); None by default."""
    return None

@property
def aws_arn_role_name(self):
    """Override to return the IAM ARN role name (role-based credentials); None by default."""
    return None

@property
def aws_session_token(self):
"""
Expand Down Expand Up @@ -263,22 +277,38 @@ def run(self):
def copy(self, cursor, f):
    """
    Defines copying from s3 into redshift.

    If both key-based and role-based credentials are provided,
    role-based credentials will be used.

    :param cursor: database cursor used to execute the COPY statement.
    :param f: s3 path to load from.
    :raises NotImplementedError: if neither credential pair is configured.
    """
    # Format the credentials string dependent upon which type of
    # credentials were provided; role-based takes precedence.
    if self.aws_account_id and self.aws_arn_role_name:
        cred_str = 'aws_iam_role=arn:aws:iam::{id}:role/{role}'.format(
            id=self.aws_account_id,
            role=self.aws_arn_role_name
        )
    elif self.aws_access_key_id and self.aws_secret_access_key:
        # NOTE(review): Redshift's CREDENTIALS clause requires the key
        # 'aws_secret_access_key' — 'aws_secret_key' is silently rejected.
        cred_str = 'aws_access_key_id={key};aws_secret_access_key={secret}{opt}'.format(
            key=self.aws_access_key_id,
            secret=self.aws_secret_access_key,
            opt=';token={}'.format(self.aws_session_token) if self.aws_session_token else ''
        )
    else:
        raise NotImplementedError("Missing Credentials. "
                                  "Override one of the following pairs of auth-args: "
                                  "'aws_access_key_id' AND 'aws_secret_access_key' OR "
                                  "'aws_account_id' AND 'aws_arn_role_name'")

    logger.info("Inserting file: %s", f)
    cursor.execute("""
     COPY {table} from '{source}'
     CREDENTIALS '{creds}'
     {options}
     ;""".format(
        table=self.table,
        source=f,
        creds=cred_str,
        options=self.copy_options)
    )

def output(self):
"""
Expand Down
40 changes: 33 additions & 7 deletions test/contrib/redshift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
# Dummy key-based credentials used by the copy tests.
AWS_ACCESS_KEY = 'key'
AWS_SECRET_KEY = 'secret'

# Dummy role-based credentials; AWS account ids are 12-digit strings
# (the original value had a stray leading digit, making it 13).
AWS_ACCOUNT_ID = '123456789012'
AWS_ROLE_NAME = 'MyRedshiftRole'

BUCKET = 'bucket'
KEY = 'key'


class DummyS3CopyToTable(luigi.contrib.redshift.S3CopyToTable):
class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable):
# Class attributes taken from `DummyPostgresImporter` in
# `../postgres_test.py`.
host = 'dummy_host'
Expand All @@ -40,8 +43,6 @@ class DummyS3CopyToTable(luigi.contrib.redshift.S3CopyToTable):
('some_int', 'int'),
)

aws_access_key_id = 'AWS_ACCESS_KEY'
aws_secret_access_key = 'AWS_SECRET_KEY'
copy_options = ''
prune_table = ''
prune_column = ''
Expand All @@ -51,7 +52,17 @@ def s3_load_path(self):
return 's3://%s/%s' % (BUCKET, KEY)


class DummyS3CopyToTempTable(DummyS3CopyToTable):
class DummyS3CopyToTableKey(DummyS3CopyToTableBase):
    # Dummy task configured with key-based AWS credentials only.
    aws_access_key_id = AWS_ACCESS_KEY
    aws_secret_access_key = AWS_SECRET_KEY


class DummyS3CopyToTableRole(DummyS3CopyToTableBase):
    # Dummy task configured with role-based AWS credentials only.
    # Use the role-specific constants: the original assigned
    # AWS_ACCESS_KEY/AWS_SECRET_KEY here, leaving AWS_ACCOUNT_ID and
    # AWS_ROLE_NAME defined but unused.
    aws_account_id = AWS_ACCOUNT_ID
    aws_arn_role_name = AWS_ROLE_NAME


class DummyS3CopyToTempTable(DummyS3CopyToTableKey):
# Extend/alter DummyS3CopyToTable for temp table copying
table = luigi.Parameter(default='stage_dummy_table')

Expand All @@ -65,10 +76,25 @@ class DummyS3CopyToTempTable(DummyS3CopyToTable):


class TestS3CopyToTable(unittest.TestCase):
@mock.patch("luigi.contrib.redshift.RedshiftTarget")
def test_copy_missing_creds(self, mock_redshift_target):
    # The base dummy task defines neither key-based nor role-based
    # credentials, so copy() has no way to build a CREDENTIALS string
    # and must raise NotImplementedError.
    task = DummyS3CopyToTableBase()

    # The mocked connection cursor passed to
    # S3CopyToTable.copy(self, cursor, f).
    mock_cursor = (mock_redshift_target.return_value
                   .connect
                   .return_value
                   .cursor
                   .return_value)

    with self.assertRaises(NotImplementedError):
        task.copy(mock_cursor, task.s3_load_path())

@mock.patch("luigi.contrib.redshift.S3CopyToTable.copy")
@mock.patch("luigi.contrib.redshift.RedshiftTarget")
def test_s3_copy_to_table(self, mock_redshift_target, mock_copy):
task = DummyS3CopyToTable()
task = DummyS3CopyToTableKey()
task.run()

# The mocked connection cursor passed to
Expand Down Expand Up @@ -112,7 +138,7 @@ def test_s3_copy_to_missing_table(self,
Test missing table creation
"""
# Ensure `S3CopyToTable.create_table` does not throw an error.
task = DummyS3CopyToTable()
task = DummyS3CopyToTableKey()
task.run()

# Make sure the cursor was successfully used to create the table in
Expand Down Expand Up @@ -171,7 +197,7 @@ class TestS3CopyToSchemaTable(unittest.TestCase):
@mock.patch("luigi.contrib.redshift.S3CopyToTable.copy")
@mock.patch("luigi.contrib.redshift.RedshiftTarget")
def test_s3_copy_to_table(self, mock_redshift_target, mock_copy):
task = DummyS3CopyToTable(table='dummy_schema.dummy_table')
task = DummyS3CopyToTableKey(table='dummy_schema.dummy_table')
task.run()

# The mocked connection cursor passed to
Expand Down

0 comments on commit b86166c

Please sign in to comment.