Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug in SSL Seldon Client and added functionality in GRPC #946

Merged
merged 1 commit into from
Oct 17, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions python/seldon_core/seldon_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ class SeldonChannelCredentials(object):
"""
Channel credentials
Presently just denotes an SSL connection.
For GRPC in order to be properly implemented, you need to provide *either*
the root_certificate_files, *or* all the file paths.
The verify attribute currently is used to avoid SSL verification in REST
however for GRPC it is recommended that you provide a path at least
for the root_certificates_file otherwise it may not work as expected.
"""

def __init__(self, verify:bool = True, root_certificates_file: str = None,
def __init__(self, verify: bool = True, root_certificates_file: str = None,
private_key_file: str = None, certificate_chain_file: str = None):
self.verify = verify
self.root_certificates_file = root_certificates_file
Expand All @@ -41,7 +46,9 @@ def __init__(self, verify:bool = True, root_certificates_file: str = None,

class SeldonCallCredentials(object):
"""
Credentials for each call
Credentials for each call, currently implements the ability to provide
an OAuth token which is currently made available through REST via
the X-Auth-Token header, and via GRPC via the metadata call creds.
"""

def __init__(self,token:str = None):
Expand Down Expand Up @@ -1167,7 +1174,7 @@ def rest_predict_gateway(deployment_name: str, namespace: str = None, gateway_en
req_headers = headers.copy()
else:
req_headers = {}
if channel_credentials is None:
if call_credentials is None:
scheme = "http"
else:
scheme = "https"
Expand Down Expand Up @@ -1359,6 +1366,11 @@ def grpc_predict_gateway(deployment_name: str, namespace: str = None, gateway_en
Max grpc receive message size in bytes
names
Column names
call_credentials
Call credentials - see SeldonCallCredentials
channel_credentials
Channel credentials - see SeldonChannelCredentials


Returns
-------
Expand All @@ -1380,18 +1392,34 @@ def grpc_predict_gateway(deployment_name: str, namespace: str = None, gateway_en
if channel_credentials is None:
channel = grpc.insecure_channel(gateway_endpoint, options)
else:
grpc_channel_credentials = grpc.ssl_channel_credentials(root_certificates=open(channel_credentials.certificate_chain_file, 'rb').read(),
private_key=open(channel_credentials.private_key_file, 'rb').read(),
certificate_chain=open(channel_credentials.root_certificates_file, 'rb').read())
# If one of root cert & cert chain are provided, both must be provided
# otherwise there is a null pointer exception in the Go underlying impl
if (channel_credentials.private_key_file
and channel_credentials.root_certificates_file
and channel_credentials.certificate_chain_file):
grpc_channel_credentials = grpc.ssl_channel_credentials(
root_certificates=open(channel_credentials.root_certificates_file, 'rb').read(),
private_key=open(channel_credentials.private_key_file, 'rb').read(),
certificate_chain=open(channel_credentials.certificate_chain_file, 'rb').read())
# For most usecases only providing the root cert file is enough
elif channel_credentials.root_certificates_file:
grpc_channel_credentials = grpc.ssl_channel_credentials(
root_certificates=open(channel_credentials.root_certificates_file, 'rb').read())
# This piece also allows for blank SSL Channel credentials in case this is required
else:
grpc_channel_credentials = grpc.ssl_channel_credentials()
Copy link
Contributor

@adriangonz adriangonz Oct 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to build the dict of arguments as args = {"root_certificates: ..., "private_key": ..., ...} and then just pass it down as grpc.ssl_channel_credentials(**args) instead of the three if-else cases? This is mostly a personal preference though, so please ignore it if you don't agree.

On the other hand, and this is something that was already present on the codebase, does an inline call to open().read() also close the file after it's finished?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason was due to the open statement that is embedded (And yeah it was already in that format). I do agree, I think we'll have to review the whole GRPC / REST SSL at some point soon so not sure how much worth it would be to have it here.

In regards ot the second question, open files get closed and garbage collected when they lose the last reference pointer, so in this case they would be getting closed either immediately or when the function returns.

if channel_credentials.verify == False:
# If Verify is set to false then we add the SSL Target Name Override option
options += [('grpc.ssl_target_name_override', gateway_endpoint.split(":")[0])]

if not call_credentials is None:
#grpc_call_credentials = grpc.access_token_call_credentials(call_credentials.token)
grpc_call_credentials = grpc.metadata_call_credentials(
lambda context, callback: callback((("X-Auth-Token", call_credentials.token),), None))
credentials = grpc.composite_channel_credentials(grpc_channel_credentials,grpc_call_credentials)
lambda context, callback: callback((("x-auth-token", call_credentials.token),), None))
credentials = grpc.composite_channel_credentials(grpc_channel_credentials, grpc_call_credentials)
else:
credentials = grpc_channel_credentials
logger.debug(gateway_endpoint)
channel = grpc.secure_channel(gateway_endpoint,credentials, options)
logger.debug(f"Sending GRPC Request to endpoint: {gateway_endpoint}")
channel = grpc.secure_channel(gateway_endpoint, credentials, options)
stub = prediction_pb2_grpc.SeldonStub(channel)
if namespace is None:
metadata = [('seldon', deployment_name)]
Expand All @@ -1400,11 +1428,8 @@ def grpc_predict_gateway(deployment_name: str, namespace: str = None, gateway_en
if not headers is None:
for k in headers:
metadata.append((k, headers[k]))
#try:
response = stub.Predict(request=request, metadata=metadata)
return SeldonClientPrediction(request, response, True, "")
#except Exception as e:
# return SeldonClientPrediction(request, None, False, str(e))


def rest_feedback_seldon_oauth(prediction_request: prediction_pb2.SeldonMessage = None,
Expand Down