Add optional ServerSideEncryption setting for S3 uploads

zendesk-klross committed May 29, 2023
1 parent 74a8790 commit 75ed721
Showing 8 changed files with 183 additions and 81 deletions.
3 changes: 3 additions & 0 deletions metaflow/metaflow_config.py
@@ -74,6 +74,9 @@
S3_ENDPOINT_URL = from_conf("S3_ENDPOINT_URL")
S3_VERIFY_CERTIFICATE = from_conf("S3_VERIFY_CERTIFICATE")

# Set ServerSideEncryption for S3 uploads
S3_SERVER_SIDE_ENCRYPTION = from_conf("S3_SERVER_SIDE_ENCRYPTION")

# S3 retry configuration
# This is useful if you want to "fail fast" on S3 operations; use with caution
# though as this may increase failures. Note that this is the number of *retries*
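For context, a minimal sketch (not part of this commit) of how the new setting resolves: from_conf prefixes the name with METAFLOW_, so exporting METAFLOW_S3_SERVER_SIDE_ENCRYPTION before Metaflow is imported populates S3_SERVER_SIDE_ENCRYPTION. The value "aws:kms" is illustrative; "AES256" works the same way.

import os

# Must be set before metaflow.metaflow_config is first imported.
os.environ["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = "aws:kms"

from metaflow.metaflow_config import S3_SERVER_SIDE_ENCRYPTION
print(S3_SERVER_SIDE_ENCRYPTION)  # -> "aws:kms"
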
2 changes: 1 addition & 1 deletion metaflow/metaflow_config_funcs.py
@@ -59,8 +59,8 @@ def from_conf(name, default=None, validate_fn=None):
validate_fn should accept (name, value).
If the value validates, return None, else raise a MetaflowException.
"""
env_name = "METAFLOW_%s" % name
is_default = True
env_name = "METAFLOW_%s" % name
value = os.environ.get(env_name, METAFLOW_CONFIG.get(env_name, default))
if validate_fn and value is not None:
validate_fn(env_name, value)
6 changes: 6 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
@@ -38,6 +38,7 @@
S3_ENDPOINT_URL,
SERVICE_HEADERS,
SERVICE_INTERNAL_URL,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
from metaflow.parameters import deploy_time_eval
@@ -1100,12 +1101,17 @@ def _container_templates(self):
"METAFLOW_ARGO_EVENT_PAYLOAD_%s_%s"
% (event["type"], event["sanitized_name"])
] = ("{{workflow.parameters.%s}}" % event["sanitized_name"])

# Map the S3 server-side encryption setting to an environment variable
if S3_SERVER_SIDE_ENCRYPTION is not None:
env["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = S3_SERVER_SIDE_ENCRYPTION

metaflow_version = self.environment.get_environment_info()
metaflow_version["flow_name"] = self.graph.name
metaflow_version["production_token"] = self.production_token
env["METAFLOW_VERSION"] = json.dumps(metaflow_version)


# Set the template inputs and outputs for passing state. Very simply,
# the container template takes in input-paths as input and outputs
# the task-id (which feeds in as input-paths to the subsequent task).
4 changes: 4 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
@@ -21,6 +21,7 @@
S3_ENDPOINT_URL,
DEFAULT_SECRETS_BACKEND_TYPE,
AWS_SECRETS_MANAGER_DEFAULT_REGION,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import (
export_mflog_env_vars,
@@ -263,6 +264,9 @@ def create_job(
if tmpfs_enabled and tmpfs_tempdir:
job.environment_variable("METAFLOW_TEMPDIR", tmpfs_path)

if S3_SERVER_SIDE_ENCRYPTION is not None:
job.environment_variable("METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION)

# Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync between the local user
# instance and the remote AWS Batch instance assumes metadata is stored in DATASTORE_LOCAL_DIR
# on the remote AWS Batch instance; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL
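A hedged sketch of the round trip this block enables: the launcher injects METAFLOW_S3_SERVER_SIDE_ENCRYPTION into the Batch job environment, and on the remote instance from_conf("S3_SERVER_SIDE_ENCRYPTION") reads the same variable back, so the S3 client created inside the task uses it as its default encryption argument. The helper name below is hypothetical.

import os

def remote_sse_default(default=None):
    # Hypothetical helper mirroring what from_conf("S3_SERVER_SIDE_ENCRYPTION")
    # does on the remote Batch (or Kubernetes) instance.
    return os.environ.get("METAFLOW_S3_SERVER_SIDE_ENCRYPTION", default)

print(remote_sse_default())  # -> "aws:kms" when the launcher set the variable
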
56 changes: 46 additions & 10 deletions metaflow/plugins/datatools/s3/s3.py
@@ -17,6 +17,7 @@
DATATOOLS_S3ROOT,
S3_RETRY_COUNT,
S3_TRANSIENT_RETRY_COUNT,
S3_SERVER_SIDE_ENCRYPTION,
TEMPDIR,
)
from metaflow.util import (
@@ -83,9 +84,10 @@ def ensure_unicode(x):
("value", Optional[PutValue]),
("path", Optional[str]),
("content_type", Optional[str]),
("encryption", Optional[str]),
("metadata", Optional[Dict[str, str]]),
],
defaults=(None, None, None, None),
defaults=(None, None, None, None, None),
)
S3PutObject.__module__ = __name__

@@ -142,6 +144,7 @@ def __init__(
metadata: Optional[Dict[str, str]] = None,
range_info: Optional[RangeInfo] = None,
last_modified: int = None,
encryption: Optional[str] = None,
):
# all fields of S3Object should return a unicode object
prefix, url, path = map(ensure_unicode, (prefix, url, path))
@@ -176,6 +179,8 @@ def __init__(
self._key = url[len(prefix.rstrip("/")) + 1 :].rstrip("/")
self._prefix = prefix

self._encryption = encryption

@property
def exists(self) -> bool:
"""
@@ -321,6 +326,7 @@ def has_info(self) -> bool:
self._content_type is not None
or self._metadata is not None
or self._range_info is not None
or self._encryption is not None
)

@property
@@ -348,6 +354,18 @@ def content_type(self) -> Optional[str]:
"""
return self._content_type

@property
def encryption(self) -> Optional[str]:
"""
Returns the encryption type of the S3 object or None if it is not defined.
Returns
-------
str
Server-side-encryption type or None if parameter is not set.
"""
return self._encryption

@property
def range_info(self) -> Optional[RangeInfo]:
"""
@@ -486,6 +504,7 @@ def __init__(
prefix: Optional[str] = None,
run: Optional[Union[FlowSpec, "Run"]] = None,
s3root: Optional[str] = None,
encryption: Optional[str] = S3_SERVER_SIDE_ENCRYPTION,
**kwargs
):
if not boto_found:
@@ -539,6 +558,7 @@ def __init__(
"inject_failure_rate", TEST_INJECT_RETRYABLE_FAILURES
)
self._tmpdir = mkdtemp(dir=tmproot, prefix="metaflow.s3.")
self._encryption = encryption

def __enter__(self) -> "S3":
return self
@@ -739,6 +759,7 @@ def _info(s3, tmp):
"metadata": resp["Metadata"],
"size": resp["ContentLength"],
"last_modified": get_timestamp(resp["LastModified"]),
"encryption": resp["ServerSideEncryption"],
}

info_results = None
@@ -758,6 +779,7 @@ def _info(s3, tmp):
content_type=info_results["content_type"],
metadata=info_results["metadata"],
last_modified=info_results["last_modified"],
encryption=info_results["encryption"],
)
return S3Object(self._s3root, url, None)

@@ -811,7 +833,9 @@ def _head():
else:
yield self._s3root, s3url, None, info["size"], info[
"content_type"
], info["metadata"], None, info["last_modified"]
], info["metadata"], None, info["last_modified"], info[
"encryption"
]
else:
# This should not happen; we should always get a response
# even if it contains an error inside it
@@ -886,6 +910,7 @@ def _download(s3, tmp):
if return_info:
return {
"content_type": resp["ContentType"],
"encryption": resp["ServerSideEncryption"],
"metadata": resp["Metadata"],
"range_result": range_result,
"last_modified": get_timestamp(resp["LastModified"]),
@@ -906,6 +931,7 @@ def _download(s3, tmp):
url,
path,
content_type=addl_info["content_type"],
encryption=addl_info["encryption"],
metadata=addl_info["metadata"],
range_info=addl_info["range_result"],
last_modified=addl_info["last_modified"],
@@ -967,13 +993,15 @@ def _get():
- range_info["start"]
+ 1,
)
yield self._s3root, s3url, os.path.join(
self._tmpdir, fname
), None, info["content_type"], info[
"metadata"
], range_info, info[
"last_modified"
]
yield self._s3root, s3url, os.path.join(
self._tmpdir, fname
), None, info["content_type"], info[
"metadata"
], range_info, info[
"last_modified"
], info[
"encryption"
]
else:
yield self._s3root, s3prefix, None
else:
@@ -1033,6 +1061,8 @@ def _get():
self._tmpdir, fname
), None, info["content_type"], info["metadata"], range_info, info[
"last_modified"
], info[
"encryption"
]
else:
yield s3prefix, s3url, os.path.join(self._tmpdir, fname)
@@ -1120,14 +1150,16 @@ def put(
url = self._url(key)
src = urlparse(url)
extra_args = None
if content_type or metadata:
if content_type or metadata or self._encryption:
extra_args = {}
if content_type:
extra_args["ContentType"] = content_type
if metadata:
extra_args["Metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
extra_args["ServerSideEncryption"] = self._encryption

def _upload(s3, _):
# We make sure we are at the beginning in case we are retrying
@@ -1200,6 +1232,8 @@ def _store():
store_info["metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
store_info["encryption"] = self._encryption
if isinstance(obj, (RawIOBase, BufferedIOBase)):
if not obj.readable() or not obj.seekable():
raise MetaflowS3InvalidObject(
@@ -1272,6 +1306,8 @@ def _check():
store_info["metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
store_info["encryption"] = self._encryption
if not os.path.exists(path):
raise MetaflowS3NotFound("Local file not found: %s" % path)
yield path, self._url(key), store_info
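A minimal usage sketch of the client-side behavior added above, assuming Metaflow with this change, valid AWS credentials, and a writable bucket (bucket name and keys below are placeholders): when `encryption` is set, explicitly or via S3_SERVER_SIDE_ENCRYPTION, put() adds ServerSideEncryption to the boto3 ExtraArgs and info() reports the value back on the returned S3Object.

from metaflow import S3

# "AES256" is illustrative; "aws:kms" works the same way if the bucket policy allows it.
with S3(s3root="s3://example-bucket/some/prefix", encryption="AES256") as s3:
    url = s3.put("greeting", "hello world")  # upload carries ServerSideEncryption="AES256"
    obj = s3.info("greeting")                # HEAD request on the same key
    print(obj.encryption)                    # -> "AES256", surfaced from the HEAD response
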
19 changes: 15 additions & 4 deletions metaflow/plugins/datatools/s3/s3op.py
@@ -63,6 +63,7 @@ def __init__(
local,
prefix,
content_type=None,
encryption=None,
metadata=None,
range=None,
idx=None,
@@ -77,6 +78,7 @@ def __init__(
self.metadata = metadata
self.range = range
self.idx = idx
self.encryption = encryption

def __str__(self):
return self.url
@@ -171,6 +173,7 @@ def op_info(url):
"error": None,
"size": head["ContentLength"],
"content_type": head["ContentType"],
"encryption": head["ServerSideEncryption"],
"metadata": head["Metadata"],
"last_modified": get_timestamp(head["LastModified"]),
}
@@ -276,6 +279,8 @@ def op_info(url):
args["content_type"] = resp["ContentType"]
if resp["Metadata"] is not None:
args["metadata"] = resp["Metadata"]
if resp["ServerSideEncryption"] is not None:
args["encryption"] = resp["ServerSideEncryption"]
if resp["LastModified"]:
args["last_modified"] = get_timestamp(
resp["LastModified"]
@@ -299,12 +304,14 @@ def op_info(url):
do_upload = True
if do_upload:
extra = None
if url.content_type or url.metadata:
if url.content_type or url.metadata or url.encryption:
extra = {}
if url.content_type:
extra["ContentType"] = url.content_type
if url.metadata is not None:
extra["Metadata"] = url.metadata
if url.encryption is not None:
extra["ServerSideEncryption"] = url.encryption
try:
s3.upload_file(
url.local, url.bucket, url.path, ExtraArgs=extra
@@ -461,6 +468,7 @@ def get_info(self, url):
prefix=url.prefix,
content_type=head["ContentType"],
metadata=head["Metadata"],
encryption=head.get("ServerSideEncryption"),
range=url.range,
),
head["ContentLength"],
@@ -578,7 +586,7 @@ def verify_results(urls, verbose=False):
raise
if expected != got:
exit(ERROR_VERIFY_FAILED, url)
if url.content_type or url.metadata:
if url.content_type or url.metadata or url.encryption:
# Verify that we also have a metadata file present
try:
os.stat("%s_meta" % url.local)
@@ -838,11 +846,12 @@ def _files():
url = r["url"]
content_type = r.get("content_type", None)
metadata = r.get("metadata", None)
encryption = r.get("encryption", None)
if not os.path.exists(local):
exit(ERROR_LOCAL_FILE_NOT_FOUND, local)
yield input_line_idx, local, url, content_type, metadata
yield input_line_idx, local, url, content_type, metadata, encryption

def _make_url(idx, local, user_url, content_type, metadata):
def _make_url(idx, local, user_url, content_type, metadata, encryption):
src = urlparse(user_url)
url = S3Url(
url=user_url,
@@ -853,6 +862,7 @@ def _make_url(idx, local, user_url, content_type, metadata):
content_type=content_type,
metadata=metadata,
idx=idx,
encryption=encryption,
)
if src.scheme != "s3":
exit(ERROR_INVALID_URL, url)
@@ -896,6 +906,7 @@ def _make_url(idx, local, user_url, content_type, metadata):
"local": url.local,
"content_type": url.content_type,
"metadata": url.metadata,
"encryption": url.encryption,
}
)
+ "\n"
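At the worker level the upload path reduces to a plain boto3 call; a sketch of roughly what it issues when an encryption value is present (bucket, key, and local path are placeholders):

import boto3

s3_client = boto3.client("s3")
extra_args = {
    "ContentType": "text/plain",
    # Only included when the caller supplied an encryption value.
    "ServerSideEncryption": "aws:kms",
}
s3_client.upload_file(
    "/tmp/example.txt", "example-bucket", "path/to/key", ExtraArgs=extra_args
)
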
4 changes: 4 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
@@ -32,6 +32,7 @@
S3_ENDPOINT_URL,
SERVICE_HEADERS,
SERVICE_INTERNAL_URL,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import (
BASH_SAVE_LOGS,
@@ -258,6 +259,9 @@ def create_job(
# see get_datastore_root_from_config in datastore/local.py).
)

if S3_SERVER_SIDE_ENCRYPTION is not None:
job.environment_variable("METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION)

# Set environment variables to support metaflow.integrations.ArgoEvent
job.environment_variable(
"METAFLOW_ARGO_EVENTS_WEBHOOK_URL", ARGO_EVENTS_INTERNAL_WEBHOOK_URL
Expand Down