feat: ability to use ServerSideEncryption for S3 uploads #1436

Merged · 2 commits · Jun 15, 2023
3 changes: 3 additions & 0 deletions metaflow/metaflow_config.py
@@ -74,6 +74,9 @@
S3_ENDPOINT_URL = from_conf("S3_ENDPOINT_URL")
S3_VERIFY_CERTIFICATE = from_conf("S3_VERIFY_CERTIFICATE")

# Set ServerSideEncryption for S3 uploads
S3_SERVER_SIDE_ENCRYPTION = from_conf("S3_SERVER_SIDE_ENCRYPTION")

# S3 retry configuration
# This is useful if you want to "fail fast" on S3 operations; use with caution
# though as this may increase failures. Note that this is the number of *retries*
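With this knob added, enabling server-side encryption is a one-setting change. A minimal sketch of switching it on, assuming `aws:kms` as the desired scheme (`AES256` is the other common value):

```python
import os

# Like every from_conf knob, the value resolves from the METAFLOW_-prefixed
# environment variable (or the Metaflow config JSON). Metaflow reads its
# config at import time, so set the variable before importing it.
os.environ["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = "aws:kms"

from metaflow import S3  # S3 clients created from here on default to SSE
```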
2 changes: 1 addition & 1 deletion metaflow/metaflow_config_funcs.py
@@ -59,8 +59,8 @@ def from_conf(name, default=None, validate_fn=None):
validate_fn should accept (name, value).
If the value validates, return None, else raise a MetaflowException.
"""
-env_name = "METAFLOW_%s" % name
is_default = True
+env_name = "METAFLOW_%s" % name
value = os.environ.get(env_name, METAFLOW_CONFIG.get(env_name, default))
if validate_fn and value is not None:
validate_fn(env_name, value)
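For context, a condensed sketch of `from_conf` as it reads after this reordering; `METAFLOW_CONFIG` stands in for the parsed config file, and the real function also uses `is_default` for bookkeeping:

```python
import os

METAFLOW_CONFIG = {}  # stand-in for the parsed Metaflow config JSON

def from_conf(name, default=None, validate_fn=None):
    # Resolve METAFLOW_<NAME>: environment first, then config file, then default.
    is_default = True
    env_name = "METAFLOW_%s" % name
    value = os.environ.get(env_name, METAFLOW_CONFIG.get(env_name, default))
    if validate_fn and value is not None:
        validate_fn(env_name, value)
    return value
```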
5 changes: 5 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
@@ -38,6 +38,7 @@
S3_ENDPOINT_URL,
SERVICE_HEADERS,
SERVICE_INTERNAL_URL,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
from metaflow.parameters import deploy_time_eval
@@ -1101,6 +1102,10 @@ def _container_templates(self):
% (event["type"], event["sanitized_name"])
] = ("{{workflow.parameters.%s}}" % event["sanitized_name"])

# Propagate the S3 server-side encryption setting to the task environment
if S3_SERVER_SIDE_ENCRYPTION is not None:
env["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = S3_SERVER_SIDE_ENCRYPTION

metaflow_version = self.environment.get_environment_info()
metaflow_version["flow_name"] = self.graph.name
metaflow_version["production_token"] = self.production_token
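The Batch and Kubernetes plugins below apply the same propagation rule: forward the setting into the task container only when it is configured, so an unset knob stays unset remotely. Schematically (names mirror the diffs; values are illustrative):

```python
# Shared pattern across the schedulers: forward the knob only when set.
S3_SERVER_SIDE_ENCRYPTION = "aws:kms"  # illustrative resolved config value
env = {}  # stands in for the task container's environment mapping

if S3_SERVER_SIDE_ENCRYPTION is not None:
    env["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = S3_SERVER_SIDE_ENCRYPTION
```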
6 changes: 6 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
@@ -21,6 +21,7 @@
S3_ENDPOINT_URL,
DEFAULT_SECRETS_BACKEND_TYPE,
AWS_SECRETS_MANAGER_DEFAULT_REGION,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import (
export_mflog_env_vars,
@@ -263,6 +264,11 @@ def create_job(
if tmpfs_enabled and tmpfs_tempdir:
job.environment_variable("METAFLOW_TEMPDIR", tmpfs_path)

if S3_SERVER_SIDE_ENCRYPTION is not None:
job.environment_variable(
"METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
)

# Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync between the local user
# instance and the remote AWS Batch instance assumes metadata is stored in DATASTORE_LOCAL_DIR
# on the remote AWS Batch instance; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL
56 changes: 46 additions & 10 deletions metaflow/plugins/datatools/s3/s3.py
@@ -17,6 +17,7 @@
DATATOOLS_S3ROOT,
S3_RETRY_COUNT,
S3_TRANSIENT_RETRY_COUNT,
S3_SERVER_SIDE_ENCRYPTION,
TEMPDIR,
)
from metaflow.util import (
@@ -83,9 +84,10 @@ def ensure_unicode(x):
("value", Optional[PutValue]),
("path", Optional[str]),
("content_type", Optional[str]),
("encryption", Optional[str]),
("metadata", Optional[Dict[str, str]]),
],
defaults=(None, None, None, None),
defaults=(None, None, None, None, None),
)
S3PutObject.__module__ = __name__

@@ -142,6 +144,7 @@ def __init__(
metadata: Optional[Dict[str, str]] = None,
range_info: Optional[RangeInfo] = None,
last_modified: int = None,
encryption: Optional[str] = None,
):
# all fields of S3Object should return a unicode object
prefix, url, path = map(ensure_unicode, (prefix, url, path))
@@ -176,6 +179,8 @@ def __init__(
self._key = url[len(prefix.rstrip("/")) + 1 :].rstrip("/")
self._prefix = prefix

self._encryption = encryption

@property
def exists(self) -> bool:
"""
@@ -321,6 +326,7 @@ def has_info(self) -> bool:
self._content_type is not None
or self._metadata is not None
or self._range_info is not None
or self._encryption is not None
)

@property
@@ -348,6 +354,18 @@ def content_type(self) -> Optional[str]:
"""
return self._content_type

@property
def encryption(self) -> Optional[str]:
"""
Returns the server-side encryption scheme of the S3 object, or None if it is not set.

Returns
-------
str
Server-side encryption type, or None if not set.
"""
return self._encryption

@property
def range_info(self) -> Optional[RangeInfo]:
"""
@@ -486,6 +504,7 @@ def __init__(
prefix: Optional[str] = None,
run: Optional[Union[FlowSpec, "Run"]] = None,
s3root: Optional[str] = None,
encryption: Optional[str] = S3_SERVER_SIDE_ENCRYPTION,
**kwargs
):
if not boto_found:
@@ -539,6 +558,7 @@ def __init__(
"inject_failure_rate", TEST_INJECT_RETRYABLE_FAILURES
)
self._tmpdir = mkdtemp(dir=tmproot, prefix="metaflow.s3.")
self._encryption = encryption

def __enter__(self) -> "S3":
return self
@@ -739,6 +759,7 @@ def _info(s3, tmp):
"metadata": resp["Metadata"],
"size": resp["ContentLength"],
"last_modified": get_timestamp(resp["LastModified"]),
"encryption": resp["ServerSideEncryption"],
}

info_results = None
@@ -758,6 +779,7 @@ def _info(s3, tmp):
content_type=info_results["content_type"],
metadata=info_results["metadata"],
last_modified=info_results["last_modified"],
encryption=info_results["encryption"],
)
return S3Object(self._s3root, url, None)

@@ -811,7 +833,9 @@ def _head():
else:
yield self._s3root, s3url, None, info["size"], info[
"content_type"
-], info["metadata"], None, info["last_modified"]
+], info["metadata"], None, info["last_modified"], info[
+    "encryption"
+]
else:
# This should not happen; we should always get a response
# even if it contains an error inside it
@@ -886,6 +910,7 @@ def _download(s3, tmp):
if return_info:
return {
"content_type": resp["ContentType"],
"encryption": resp["ServerSideEncryption"],
"metadata": resp["Metadata"],
"range_result": range_result,
"last_modified": get_timestamp(resp["LastModified"]),
@@ -906,6 +931,7 @@ def _download(s3, tmp):
url,
path,
content_type=addl_info["content_type"],
encryption=addl_info["encryption"],
metadata=addl_info["metadata"],
range_info=addl_info["range_result"],
last_modified=addl_info["last_modified"],
@@ -967,13 +993,15 @@ def _get():
- range_info["start"]
+ 1,
)
-yield self._s3root, s3url, os.path.join(
-    self._tmpdir, fname
-), None, info["content_type"], info[
-    "metadata"
-], range_info, info[
-    "last_modified"
-]
+yield self._s3root, s3url, os.path.join(
+    self._tmpdir, fname
+), None, info["content_type"], info[
+    "metadata"
+], range_info, info[
+    "last_modified"
+], info[
+    "encryption"
+]
else:
yield self._s3root, s3prefix, None
else:
@@ -1033,6 +1061,8 @@ def _get():
self._tmpdir, fname
), None, info["content_type"], info["metadata"], range_info, info[
"last_modified"
-]
+], info[
+    "encryption"
+]
else:
yield s3prefix, s3url, os.path.join(self._tmpdir, fname)
@@ -1120,14 +1150,16 @@ def put(
url = self._url(key)
src = urlparse(url)
extra_args = None
-if content_type or metadata:
+if content_type or metadata or self._encryption:
extra_args = {}
if content_type:
extra_args["ContentType"] = content_type
if metadata:
extra_args["Metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
extra_args["ServerSideEncryption"] = self._encryption

def _upload(s3, _):
# We make sure we are at the beginning in case we are retrying
@@ -1200,6 +1232,8 @@ def _store():
store_info["metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
store_info["encryption"] = self._encryption
if isinstance(obj, (RawIOBase, BufferedIOBase)):
if not obj.readable() or not obj.seekable():
raise MetaflowS3InvalidObject(
@@ -1272,6 +1306,8 @@ def _check():
store_info["metadata"] = {
"metaflow-user-attributes": json.dumps(metadata)
}
if self._encryption:
store_info["encryption"] = self._encryption
if not os.path.exists(path):
raise MetaflowS3NotFound("Local file not found: %s" % path)
yield path, self._url(key), store_info
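Taken together, the datatools changes thread the new keyword from the `S3` constructor down to boto3's `ExtraArgs`. A hedged usage sketch (bucket and values are illustrative; a per-client `encryption` argument overrides the configured `S3_SERVER_SIDE_ENCRYPTION` default):

```python
from metaflow import S3

# Explicit per-client override; omit encryption to fall back to the
# S3_SERVER_SIDE_ENCRYPTION config value.
with S3(s3root="s3://my-bucket/experiments", encryption="aws:kms") as s3:
    url = s3.put("greeting", "hello world")

# Under the hood the upload is roughly equivalent to a boto3 call like:
#   client.upload_fileobj(blob, "my-bucket", "experiments/greeting",
#                         ExtraArgs={"ServerSideEncryption": "aws:kms"})
```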
19 changes: 15 additions & 4 deletions metaflow/plugins/datatools/s3/s3op.py
@@ -63,6 +63,7 @@ def __init__(
local,
prefix,
content_type=None,
encryption=None,
metadata=None,
range=None,
idx=None,
@@ -77,6 +78,7 @@
self.metadata = metadata
self.range = range
self.idx = idx
self.encryption = encryption

def __str__(self):
return self.url
@@ -171,6 +173,7 @@ def op_info(url):
"error": None,
"size": head["ContentLength"],
"content_type": head["ContentType"],
"encryption": head["ServerSideEncryption"],
"metadata": head["Metadata"],
"last_modified": get_timestamp(head["LastModified"]),
}
@@ -276,6 +279,8 @@ def op_info(url):
args["content_type"] = resp["ContentType"]
if resp["Metadata"] is not None:
args["metadata"] = resp["Metadata"]
if resp["ServerSideEncryption"] is not None:
args["encryption"] = resp["ServerSideEncryption"]
if resp["LastModified"]:
args["last_modified"] = get_timestamp(
resp["LastModified"]
@@ -299,12 +304,14 @@ def op_info(url):
do_upload = True
if do_upload:
extra = None
-if url.content_type or url.metadata:
+if url.content_type or url.metadata or url.encryption:
extra = {}
if url.content_type:
extra["ContentType"] = url.content_type
if url.metadata is not None:
extra["Metadata"] = url.metadata
if url.encryption is not None:
extra["ServerSideEncryption"] = url.encryption
try:
s3.upload_file(
url.local, url.bucket, url.path, ExtraArgs=extra
@@ -461,6 +468,7 @@ def get_info(self, url):
prefix=url.prefix,
content_type=head["ContentType"],
metadata=head["Metadata"],
encryption=head["ServerSideEncryption"],
range=url.range,
),
head["ContentLength"],
@@ -578,7 +586,7 @@ def verify_results(urls, verbose=False):
raise
if expected != got:
exit(ERROR_VERIFY_FAILED, url)
-if url.content_type or url.metadata:
+if url.content_type or url.metadata or url.encryption:
# Verify that we also have a metadata file present
try:
os.stat("%s_meta" % url.local)
@@ -838,11 +846,12 @@ def _files():
url = r["url"]
content_type = r.get("content_type", None)
metadata = r.get("metadata", None)
encryption = r.get("encryption", None)
if not os.path.exists(local):
exit(ERROR_LOCAL_FILE_NOT_FOUND, local)
-yield input_line_idx, local, url, content_type, metadata
+yield input_line_idx, local, url, content_type, metadata, encryption

-def _make_url(idx, local, user_url, content_type, metadata):
+def _make_url(idx, local, user_url, content_type, metadata, encryption):
src = urlparse(user_url)
url = S3Url(
url=user_url,
@@ -853,6 +862,7 @@ def _make_url(idx, local, user_url, content_type, metadata):
content_type=content_type,
metadata=metadata,
idx=idx,
encryption=encryption,
)
if src.scheme != "s3":
exit(ERROR_INVALID_URL, url)
@@ -896,6 +906,7 @@ def _make_url(idx, local, user_url, content_type, metadata):
"local": url.local,
"content_type": url.content_type,
"metadata": url.metadata,
"encryption": url.encryption,
}
)
+ "\n"
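Since `s3op` runs as a helper subprocess, per-object attributes cross the process boundary as JSON lines; `encryption` is now one of those fields. A sketch of such a line, with illustrative values and only the keys visible in this diff:

```python
import json

# One input/output line as s3op would exchange it (values are made up).
line = json.dumps(
    {
        "url": "s3://my-bucket/prefix/key",
        "local": "/tmp/metaflow.s3.abc123/key",
        "content_type": "application/octet-stream",
        "metadata": {"metaflow-user-attributes": "{}"},
        "encryption": "aws:kms",
    }
)
print(line)
```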
6 changes: 6 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
@@ -32,6 +32,7 @@
S3_ENDPOINT_URL,
SERVICE_HEADERS,
SERVICE_INTERNAL_URL,
S3_SERVER_SIDE_ENCRYPTION,
)
from metaflow.mflog import (
BASH_SAVE_LOGS,
@@ -258,6 +259,11 @@ def create_job(
# see get_datastore_root_from_config in datastore/local.py).
)

if S3_SERVER_SIDE_ENCRYPTION is not None:
job.environment_variable(
"METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
)

# Set environment variables to support metaflow.integrations.ArgoEvent
job.environment_variable(
"METAFLOW_ARGO_EVENTS_WEBHOOK_URL", ARGO_EVENTS_INTERNAL_WEBHOOK_URL