Update Storage.py and initialiser image #1368

Merged: 7 commits, Feb 4, 2020
Changes from all commits
107 changes: 94 additions & 13 deletions python/seldon_core/storage.py
@@ -25,6 +25,7 @@
 from azure.storage.blob import BlockBlobService
 from minio import Minio
 from seldon_core.imports_helper import _GCS_PRESENT
+from seldon_core.utils import getenv

 if _GCS_PRESENT:
     from google.auth import exceptions
@@ -78,6 +79,7 @@ def _download_s3(uri, temp_dir: str):
         bucket_name = bucket_args[0]
         bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
         objects = client.list_objects(bucket_name, prefix=bucket_path, recursive=True)
+        count = 0
         for obj in objects:
             # Replace any prefix from the object key with temp_dir
             subdir_object_key = obj.object_name.replace(bucket_path, "", 1).strip("/")
@@ -90,6 +92,13 @@ def _download_s3(uri, temp_dir: str):
                 obj.object_name,
                 os.path.join(temp_dir, subdir_object_key),
             )
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )

     @staticmethod
     def _download_gcs(uri, temp_dir: str):
@@ -105,6 +114,7 @@ def _download_gcs(uri, temp_dir: str):
         if not prefix.endswith("/"):
             prefix = prefix + "/"
         blobs = bucket.list_blobs(prefix=prefix)
+        count = 0
         for blob in blobs:
             # Replace any prefix from the object key with temp_dir
             subdir_object_key = blob.name.replace(bucket_path, "", 1).strip("/")
@@ -120,38 +130,102 @@ def _download_gcs(uri, temp_dir: str):
             dest_path = os.path.join(temp_dir, subdir_object_key)
             logging.info("Downloading: %s", dest_path)
             blob.download_to_filename(dest_path)
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )

     @staticmethod
-    def _download_blob(uri, out_dir: str):
+    def _download_blob(uri, out_dir: str):  # pylint: disable=too-many-locals
         match = re.search(_BLOB_RE, uri)
         account_name = match.group(1)
         storage_url = match.group(2)
         container_name, prefix = storage_url.split("/", 1)

         logging.info(
-            "Connecting to BLOB account: %s, contianer: %s",
+            "Connecting to BLOB account: [%s], container: [%s], prefix: [%s]",
             account_name,
             container_name,
+            prefix,
         )
-        block_blob_service = BlockBlobService(account_name=account_name)
-        blobs = block_blob_service.list_blobs(container_name, prefix=prefix)

+        try:
+            block_blob_service = BlockBlobService(account_name=account_name)
+            blobs = block_blob_service.list_blobs(container_name, prefix=prefix)
+        except Exception:  # pylint: disable=broad-except
+            token = Storage._get_azure_storage_token()
+            if token is None:
+                logging.warning(
+                    "Azure credentials not found, retrying anonymous access"
+                )
+            block_blob_service = BlockBlobService(
+                account_name=account_name, token_credential=token
+            )
+            blobs = block_blob_service.list_blobs(container_name, prefix=prefix)
+        count = 0
         for blob in blobs:
+            dest_path = os.path.join(out_dir, blob.name)
             if "/" in blob.name:
-                head, _ = os.path.split(blob.name)
+                head, tail = os.path.split(blob.name)
+                if prefix is not None:
+                    head = head[len(prefix) :]
+                    if head.startswith("/"):
+                        head = head[1:]
                 dir_path = os.path.join(out_dir, head)
+                dest_path = os.path.join(dir_path, tail)
                 if not os.path.isdir(dir_path):
                     os.makedirs(dir_path)

-            dest_path = os.path.join(out_dir, blob.name)
-            logging.info("Downloading: %s", dest_path)
+            logging.info("Downloading: %s to %s", blob.name, dest_path)
             block_blob_service.get_blob_to_path(container_name, blob.name, dest_path)
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )
+
+    @staticmethod
+    def _get_azure_storage_token():
+        tenant_id = os.getenv("AZ_TENANT_ID", "")
+        client_id = os.getenv("AZ_CLIENT_ID", "")
+        client_secret = os.getenv("AZ_CLIENT_SECRET", "")
+        subscription_id = os.getenv("AZ_SUBSCRIPTION_ID", "")
+
+        if (
+            tenant_id == ""
+            or client_id == ""
+            or client_secret == ""
+            or subscription_id == ""
+        ):
+            return None
+
+        # note the SP must have "Storage Blob Data Owner" perms for this to work
+        import adal
+        from azure.storage.common import TokenCredential
+
+        authority_url = "https://login.microsoftonline.com/" + tenant_id
+
+        context = adal.AuthenticationContext(authority_url)
+
+        token = context.acquire_token_with_client_credentials(
+            "https://storage.azure.com/", client_id, client_secret
+        )
+
+        token_credential = TokenCredential(token["accessToken"])
+
+        logging.info("Retrieved SP token credential for client_id: %s", client_id)
+
+        return token_credential
+
     @staticmethod
     def _download_local(uri, out_dir=None):
         local_path = uri.replace(_LOCAL_PREFIX, "", 1)
         if not os.path.exists(local_path):
-            raise Exception("Local path %s does not exist." % (uri))
+            raise RuntimeError("Local path %s does not exist." % (uri))

         if out_dir is None:
             return local_path
@@ -171,14 +245,21 @@ def _download_local(uri, out_dir=None):
     @staticmethod
     def _create_minio_client():
         # Remove possible http scheme for Minio
-        url = urlparse(os.getenv("S3_ENDPOINT", ""))
+        url = urlparse(os.getenv("AWS_ENDPOINT_URL", "s3.amazonaws.com"))
         use_ssl = (
-            url.scheme == "https" if url.scheme else bool(os.getenv("USE_SSL", True))
+            url.scheme == "https"
+            if url.scheme
+            # KFServing uses S3_USE_HTTPS, whereas Seldon was already using
+            # USE_SSL.
+            # To keep compatibility with the storage init layer we support
+            # both, giving priority to USE_SSL.
+            # https://github.com/SeldonIO/seldon-core/pull/827
+            # https://github.com/kubeflow/kfserving/pull/362
+            else bool(getenv("USE_SSL", "S3_USE_HTTPS", "false"))
         )
-        minioClient = Minio(
+        return Minio(
             url.netloc,
             access_key=os.getenv("AWS_ACCESS_KEY_ID", ""),
             secret_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
             secure=use_ssl,
         )
-        return minioClient
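
For illustration, here is a minimal sketch of how the reworked `_create_minio_client` resolves its endpoint and SSL flag. The endpoint value below is a hypothetical example, not part of the PR. Note also that `"false"` is passed to `getenv` positionally in the diff above, so it lands in `*env_vars` as a third variable name rather than binding the keyword-only `default`:

```python
# Sketch of the endpoint/SSL resolution above (assumed example values).
import os
from urllib.parse import urlparse

os.environ["AWS_ENDPOINT_URL"] = "http://minio.example.local:9000"  # hypothetical

url = urlparse(os.getenv("AWS_ENDPOINT_URL", "s3.amazonaws.com"))
if url.scheme:
    # An explicit scheme on the endpoint decides the flag outright.
    use_ssl = url.scheme == "https"
else:
    # Fall back through USE_SSL, then S3_USE_HTTPS (KFServing's name).
    # Since "false" is consumed as a variable name rather than a default,
    # an unset environment yields None here, and bool(None) is False.
    raw = os.environ.get("USE_SSL", os.environ.get("S3_USE_HTTPS"))
    # bool() over a non-empty string is truthy, so any set value,
    # including the literal string "false", switches SSL on.
    use_ssl = bool(raw)

print(url.netloc, use_ssl)  # minio.example.local:9000 False
```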
29 changes: 25 additions & 4 deletions python/seldon_core/utils.py
@@ -1,3 +1,4 @@
+import os
 import json
 import sys
 import base64
@@ -520,21 +521,17 @@ def extract_request_parts_json(
     if not isinstance(request, dict):
         raise SeldonMicroserviceException(f"Invalid request data type: {request}")
     meta = request.get("meta", None)
-    datadef_type = None
     datadef = None

     if "data" in request:
         data_type = "data"
         datadef = request["data"]
         if "tensor" in datadef:
-            datadef_type = "tensor"
             tensor = datadef["tensor"]
             features = np.array(tensor["values"]).reshape(tensor["shape"])
         elif "ndarray" in datadef:
-            datadef_type = "ndarray"
             features = np.array(datadef["ndarray"])
         elif "tftensor" in datadef:
-            datadef_type = "tftensor"
             tf_proto = TensorProto()
             json_format.ParseDict(datadef["tftensor"], tf_proto)
             features = tf.make_ndarray(tf_proto)
@@ -597,3 +594,27 @@ def extract_feedback_request_parts(
     truth = grpc_datadef_to_array(request.truth.data)
     reward = request.reward
     return request.request.data, features, truth, reward
+
+
+def getenv(*env_vars, default=None):
+    """
+    Overload of os.getenv() to allow falling back through multiple environment
+    variables. The environment variables will be checked sequentially until one
+    of them is found.
+
+    Parameters
+    ----------
+    *env_vars
+        Variadic list of environment variable names to check.
+    default
+        Default value to return if none of the environment variables exist.
+
+    Returns
+    -------
+    Value of the first environment variable set, otherwise `default`.
+    """
+    for env_var in env_vars:
+        if env_var in os.environ:
+            return os.environ.get(env_var)
+
+    return default
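
A short usage sketch of the new helper follows; the environment values are assumed for illustration. Since `default` is keyword-only, it only takes effect when passed by name:

```python
import os
from seldon_core.utils import getenv

os.environ["S3_USE_HTTPS"] = "0"  # assumed for the example
os.environ.pop("USE_SSL", None)

getenv("USE_SSL", "S3_USE_HTTPS", default="false")  # -> "0", first variable found
getenv("MISSING_A", "MISSING_B", default="false")   # -> "false", keyword default applies
getenv("MISSING_A", "MISSING_B", "false")           # -> None: "false" is read as a
                                                    #    third variable name here
```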