diff --git a/luigi/contrib/redshift.py b/luigi/contrib/redshift.py index 48456d23e2..fdf1a7aa99 100644 --- a/luigi/contrib/redshift.py +++ b/luigi/contrib/redshift.py @@ -78,20 +78,34 @@ def s3_load_path(self): """ return None - @abc.abstractproperty + @property def aws_access_key_id(self): """ Override to return the key id. """ return None - @abc.abstractproperty + @property def aws_secret_access_key(self): """ Override to return the secret access key. """ return None + @property + def aws_account_id(self): + """ + Override to return the account id. + """ + return None + + @property + def aws_arn_role_name(self): + """ + Override to return the arn role name. + """ + return None + @property def aws_session_token(self): """ @@ -263,22 +277,38 @@ def run(self): def copy(self, cursor, f): """ Defines copying from s3 into redshift. + + If both key-based and role-based credentials are provided, role-based will be used. """ - # if session token is set, create token string - if self.aws_session_token: - token = ';token=%s' % self.aws_session_token - # otherwise, leave token string empty + # format the credentials string dependent upon which type of credentials were provided + if self.aws_account_id and self.aws_arn_role_name: + cred_str = 'aws_iam_role=arn:aws:iam::{id}:role/{role}'.format( + id=self.aws_account_id, + role=self.aws_arn_role_name + ) + elif self.aws_access_key_id and self.aws_secret_access_key: + cred_str = 'aws_access_key_id={key};aws_secret_access_key={secret}{opt}'.format( + key=self.aws_access_key_id, + secret=self.aws_secret_access_key, + opt=';token={}'.format(self.aws_session_token) if self.aws_session_token else '' + ) else: - token = '' + raise NotImplementedError("Missing Credentials. 
" + "Override one of the following pairs of auth-args: " + "'aws_access_key_id' AND 'aws_secret_access_key' OR " + "'aws_account_id' AND 'aws_arn_role_name'") logger.info("Inserting file: %s", f) cursor.execute(""" - COPY %s from '%s' - CREDENTIALS 'aws_access_key_id=%s;aws_secret_access_key=%s%s' - %s - ;""" % (self.table, f, self.aws_access_key_id, - self.aws_secret_access_key, token, - self.copy_options)) + COPY {table} from '{source}' + CREDENTIALS '{creds}' + {options} + ;""".format( + table=self.table, + source=f, + creds=cred_str, + options=self.copy_options) + ) def output(self): """ diff --git a/test/contrib/redshift_test.py b/test/contrib/redshift_test.py index 848cfb6bb9..5180efe946 100644 --- a/test/contrib/redshift_test.py +++ b/test/contrib/redshift_test.py @@ -23,11 +23,14 @@ AWS_ACCESS_KEY = 'key' AWS_SECRET_KEY = 'secret' +AWS_ACCOUNT_ID = '0123456789012' +AWS_ROLE_NAME = 'MyRedshiftRole' + BUCKET = 'bucket' KEY = 'key' -class DummyS3CopyToTable(luigi.contrib.redshift.S3CopyToTable): +class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. 
host = 'dummy_host' @@ -40,8 +43,6 @@ class DummyS3CopyToTable(luigi.contrib.redshift.S3CopyToTable): ('some_int', 'int'), ) - aws_access_key_id = 'AWS_ACCESS_KEY' - aws_secret_access_key = 'AWS_SECRET_KEY' copy_options = '' prune_table = '' prune_column = '' @@ -51,7 +52,17 @@ def s3_load_path(self): return 's3://%s/%s' % (BUCKET, KEY) -class DummyS3CopyToTempTable(DummyS3CopyToTable): +class DummyS3CopyToTableKey(DummyS3CopyToTableBase): + aws_access_key_id = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_KEY + + +class DummyS3CopyToTableRole(DummyS3CopyToTableBase): + aws_account_id = AWS_ACCOUNT_ID + aws_arn_role_name = AWS_ROLE_NAME + + +class DummyS3CopyToTempTable(DummyS3CopyToTableKey): # Extend/alter DummyS3CopyToTable for temp table copying table = luigi.Parameter(default='stage_dummy_table') @@ -65,10 +76,25 @@ class DummyS3CopyToTempTable(DummyS3CopyToTable): class TestS3CopyToTable(unittest.TestCase): + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_missing_creds(self, mock_redshift_target): + task = DummyS3CopyToTableBase() + + # The mocked connection cursor passed to + # S3CopyToTable.copy(self, cursor, f). + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + with self.assertRaises(NotImplementedError): + task.copy(mock_cursor, task.s3_load_path()) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.copy") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_table(self, mock_redshift_target, mock_copy): - task = DummyS3CopyToTable() + task = DummyS3CopyToTableKey() task.run() # The mocked connection cursor passed to @@ -112,7 +138,7 @@ def test_s3_copy_to_missing_table(self, Test missing table creation """ # Ensure `S3CopyToTable.create_table` does not throw an error. 
- task = DummyS3CopyToTable() + task = DummyS3CopyToTableKey() task.run() # Make sure the cursor was successfully used to create the table in @@ -171,7 +197,7 @@ class TestS3CopyToSchemaTable(unittest.TestCase): @mock.patch("luigi.contrib.redshift.S3CopyToTable.copy") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_table(self, mock_redshift_target, mock_copy): - task = DummyS3CopyToTable(table='dummy_schema.dummy_table') + task = DummyS3CopyToTableKey(table='dummy_schema.dummy_table') task.run() # The mocked connection cursor passed to