Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions smdebug/core/access_layer/s3.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Standard Library
import io
import os
import re
import tempfile

# Third Party
import boto3
from boto3.s3.transfer import TransferConfig

# First Party
from smdebug.core.access_layer.base import TSAccessBase
Expand All @@ -28,6 +31,10 @@ def __init__(
self.s3 = boto3.resource("s3", region_name=get_region())
self.s3_client = boto3.client("s3", region_name=get_region())

# Set the desired multipart threshold value (5 MB)
MB = 1024 ** 2
self.transfer_config = TransferConfig(multipart_threshold=5 * MB)

# check if the bucket exists
buckets = [bucket["Name"] for bucket in self.s3_client.list_buckets()["Buckets"]]
if self.bucket_name not in buckets:
Expand All @@ -39,26 +46,38 @@ def _init_data(self):
else:
self.data = ""

def _init_data(self):
    """Reset ``self.data`` to an empty buffer.

    The buffer is an empty ``bytearray`` when ``self.binary`` is truthy and
    an empty ``str`` otherwise.
    """
    self.data = bytearray() if self.binary else ""

def open(self, bucket_name, mode):
    """Unsupported operation: always raises ``NotImplementedError``.

    Both ``bucket_name`` and ``mode`` are accepted but never used.
    """
    raise NotImplementedError()

def write(self, _data):
    """Append ``_data`` to the in-memory buffer ``self.data``.

    Returns:
        A two-element list ``[start, length]``: ``start`` is the size of the
        buffer before the append and ``length`` is ``len(_data)``.
    """
    # Capture the offset before appending so the caller learns where this
    # chunk begins inside the accumulated buffer.
    offset = len(self.data)
    self.data += _data
    return [offset, len(_data)]

def close(self):
if self.flushed:
return
key = self.s3.Object(self.bucket_name, self.key_name)
key.put(Body=self.data)
if self.binary:
self.logger.debug(
f"Sagemaker-Debugger: Writing binary data to s3://{os.path.join(self.bucket_name, self.key_name)}"
)
self.s3_client.upload_fileobj(
io.BytesIO(self.data), self.bucket_name, self.key_name, Config=self.transfer_config
)
else:
f = tempfile.NamedTemporaryFile(mode="w+")
self.logger.debug(
f"Sagemaker-Debugger: Writing string data to s3://{os.path.join(self.bucket_name, self.key_name)}"
)

f.write(self.data)
f.flush()
self.s3_client.upload_file(
f.name, self.bucket_name, self.key_name, Config=self.transfer_config
)

self.logger.debug(
f"Sagemaker-Debugger: Wrote {len(self.data)} bytes to file "
f"s3://{os.path.join(self.bucket_name, self.key_name)}"
Expand Down