Skip to content

Commit

Permalink
Added optional argument 'aws_session_token' to S3Client (#2798)
Browse files Browse the repository at this point in the history
* Added optional argument 'aws_session_token' to S3Client

* Added tests for S3Client with provided session
  • Loading branch information
Bonsanto authored and dlstadther committed Nov 1, 2019
1 parent c08ba03 commit eb2acbc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
10 changes: 6 additions & 4 deletions luigi/contrib/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
logger.warning("Loading S3 module without the python package boto3. "
"Will crash at runtime if S3 functionality is used.")


# two different ways of marking a directory
# with a suffix in S3
S3_DIRECTORY_MARKER_SUFFIX_0 = '_$folder$'
Expand Down Expand Up @@ -101,14 +100,16 @@ class S3Client(FileSystem):
DEFAULT_PART_SIZE = 8388608
DEFAULT_THREADS = 100

def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None,
**kwargs):
options = self._get_s3_config()
options.update(kwargs)
if aws_access_key_id:
options['aws_access_key_id'] = aws_access_key_id
if aws_secret_access_key:
options['aws_secret_access_key'] = aws_secret_access_key
if aws_session_token:
options['aws_session_token'] = aws_session_token

self._options = options

Expand All @@ -129,7 +130,8 @@ def s3(self):
role_arn = options.get('aws_role_arn')
role_session_name = options.get('aws_role_session_name')

aws_session_token = None
# In case the aws_session_token is provided use it
aws_session_token = options.get('aws_session_token')

if role_arn and role_session_name:
sts_client = boto3.client('sts')
Expand All @@ -143,7 +145,7 @@ def s3(self):
.format(role_session_name))

for key in ['aws_access_key_id', 'aws_secret_access_key',
'aws_role_session_name', 'aws_role_arn']:
'aws_role_session_name', 'aws_role_arn', 'aws_session_token']:
if key in options:
options.pop(key)

Expand Down
23 changes: 23 additions & 0 deletions test/contrib/s3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

AWS_ACCESS_KEY = "XXXXXXXXXXXXXXXXXXXX"
AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
AWS_SESSION_TOKEN = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"


def create_bucket():
Expand Down Expand Up @@ -72,6 +73,11 @@ def create_target(self, format=None, **kwargs):
create_bucket()
return S3Target('s3://mybucket/test_file', client=client, format=format, **kwargs)

def create_target_with_session(self, format=None, **kwargs):
client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
create_bucket()
return S3Target('s3://mybucket/test_file', client=client, format=format, **kwargs)

def test_read(self):
client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
create_bucket()
Expand All @@ -81,10 +87,23 @@ def test_read(self):
file_str = read_file.read()
self.assertEqual(self.tempFileContents, file_str.encode('utf-8'))

def test_read_with_session(self):
client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
create_bucket()
client.put(self.tempFilePath, 's3://mybucket/tempfile-with-session')
t = S3Target('s3://mybucket/tempfile-with-session', client=client)
read_file = t.open()
file_str = read_file.read()
self.assertEqual(self.tempFileContents, file_str.encode('utf-8'))

def test_read_no_file(self):
t = self.create_target()
self.assertRaises(FileNotFoundException, t.open)

def test_read_no_file_with_session(self):
t = self.create_target_with_session()
self.assertRaises(FileNotFoundException, t.open)

def test_read_no_file_sse(self):
t = self.create_target(encrypt_key=True)
self.assertRaises(FileNotFoundException, t.open)
Expand Down Expand Up @@ -186,6 +205,10 @@ def test_put(self):
s3_client.put(self.tempFilePath, 's3://mybucket/putMe')
self.assertTrue(s3_client.exists('s3://mybucket/putMe'))

s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
s3_client.put(self.tempFilePath, 's3://mybucket/putMe')
self.assertTrue(s3_client.exists('s3://mybucket/putMe'))

def test_put_no_such_bucket(self):
# intentionally don't create bucket
s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
Expand Down

0 comments on commit eb2acbc

Please sign in to comment.