-
Notifications
You must be signed in to change notification settings - Fork 836
/
Copy pathwrapper.py
172 lines (138 loc) · 6.39 KB
/
wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import grpc
from concurrent import futures
from flask import jsonify, Flask, send_from_directory, request
from flask_cors import CORS
import logging
from seldon_core.utils import seldon_message_to_json, json_to_feedback
from seldon_core.flask_utils import get_request
import seldon_core.seldon_methods
from seldon_core.flask_utils import (
SeldonMicroserviceException,
ANNOTATION_GRPC_MAX_MSG_SIZE,
)
from seldon_core.proto import prediction_pb2_grpc
import os
# Module-level logger, named after this module per stdlib logging convention.
logger = logging.getLogger(__name__)
# Identifier of this predictive unit, read from the environment; falls back
# to "0" when unset. NOTE(review): presumably injected by the Seldon
# deployment — confirm against the operator/deployment config.
PRED_UNIT_ID = os.environ.get("PREDICTIVE_UNIT_ID", "0")
def get_rest_microservice(user_model):
    """Build and return a Flask app exposing the Seldon REST API.

    Every route delegates to the function of the same purpose in
    ``seldon_core.seldon_methods``, passing ``user_model`` as the handler.
    Handler function names are kept stable because Flask derives endpoint
    names from them.
    """
    app = Flask(__name__, static_url_path="")
    CORS(app)

    # Allow the user model to contribute its own error handlers.
    if hasattr(user_model, "model_error_handler"):
        logger.info("Registering the custom error handler...")
        app.register_blueprint(user_model.model_error_handler)

    @app.errorhandler(SeldonMicroserviceException)
    def handle_invalid_usage(error):
        # Log the structured error, then mirror it back as JSON with the
        # exception's own status code.
        logger.error("%s", error.to_dict())
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @app.route("/seldon.json", methods=["GET"])
    def openAPI():
        # Serve the bundled OpenAPI specification.
        return send_from_directory("", "openapi/seldon.json")

    @app.route("/predict", methods=["GET", "POST"])
    @app.route("/api/v0.1/predictions", methods=["POST"])
    def Predict():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.predict(user_model, request_json)
        json_response = jsonify(response)
        # Propagate a status code embedded in the response payload, if any.
        if "status" in response:
            status_block = response["status"]
            if "code" in status_block:
                json_response.status_code = status_block["code"]
        logger.debug("REST Response: %s", response)
        return json_response

    @app.route("/send-feedback", methods=["GET", "POST"])
    @app.route("/api/v0.1/feedback", methods=["POST"])
    def SendFeedback():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        feedback_proto = json_to_feedback(request_json)
        logger.debug("Proto Request: %s", feedback_proto)
        response_proto = seldon_core.seldon_methods.send_feedback(
            user_model, feedback_proto, PRED_UNIT_ID
        )
        return jsonify(seldon_message_to_json(response_proto))

    @app.route("/transform-input", methods=["GET", "POST"])
    def TransformInput():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.transform_input(user_model, request_json)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/transform-output", methods=["GET", "POST"])
    def TransformOutput():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.transform_output(user_model, request_json)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/route", methods=["GET", "POST"])
    def Route():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.route(user_model, request_json)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/aggregate", methods=["GET", "POST"])
    def Aggregate():
        request_json = get_request()
        logger.debug("REST Request: %s", request)
        response = seldon_core.seldon_methods.aggregate(user_model, request_json)
        logger.debug("REST Response: %s", response)
        return jsonify(response)

    @app.route("/health/ping", methods=["GET"])
    def HealthPing():
        """
        Lightweight endpoint to check the liveness of the REST endpoint
        """
        return "pong"

    @app.route("/health/status", methods=["GET"])
    def HealthStatus():
        logger.debug("REST Health Status Request")
        response = seldon_core.seldon_methods.health_status(user_model)
        logger.debug("REST Health Status Response: %s", response)
        return jsonify(response)

    return app
# ----------------------------
# GRPC
# ----------------------------
class SeldonModelGRPC(object):
    """gRPC servicer that forwards each RPC to ``seldon_core.seldon_methods``.

    Every method passes the wrapped user model as the first argument; the
    method and parameter names match the generated servicer interface and
    must not change.
    """

    def __init__(self, user_model):
        # The user's model object; all RPCs delegate to it.
        self.user_model = user_model

    def Predict(self, request_grpc, context):
        """Run a prediction for the incoming request."""
        delegate = seldon_core.seldon_methods.predict
        return delegate(self.user_model, request_grpc)

    def SendFeedback(self, feedback_grpc, context):
        """Forward feedback, tagged with this predictive unit's id."""
        delegate = seldon_core.seldon_methods.send_feedback
        return delegate(self.user_model, feedback_grpc, PRED_UNIT_ID)

    def TransformInput(self, request_grpc, context):
        """Apply the model's input transformation."""
        delegate = seldon_core.seldon_methods.transform_input
        return delegate(self.user_model, request_grpc)

    def TransformOutput(self, request_grpc, context):
        """Apply the model's output transformation."""
        delegate = seldon_core.seldon_methods.transform_output
        return delegate(self.user_model, request_grpc)

    def Route(self, request_grpc, context):
        """Choose a child to route the request to."""
        delegate = seldon_core.seldon_methods.route
        return delegate(self.user_model, request_grpc)

    def Aggregate(self, request_grpc, context):
        """Combine responses from multiple children."""
        delegate = seldon_core.seldon_methods.aggregate
        return delegate(self.user_model, request_grpc)
def get_grpc_server(user_model, annotations=None, trace_interceptor=None):
    """Create (but do not start) a gRPC server serving ``user_model``.

    Parameters
    ----------
    user_model
        The user's model object, wrapped in :class:`SeldonModelGRPC`.
    annotations
        Optional dict of deployment annotations. When it contains
        ``ANNOTATION_GRPC_MAX_MSG_SIZE`` the gRPC message-size limits are
        raised to that value. Defaults to an empty dict.
    trace_interceptor
        Optional opentracing server interceptor to wrap around the server.

    Returns
    -------
    grpc.Server
        Server with the model registered under every Seldon servicer type.
    """
    # Fix: the previous signature used a mutable default (annotations={}),
    # which is shared across calls; use a None sentinel instead.
    if annotations is None:
        annotations = {}
    seldon_model = SeldonModelGRPC(user_model)
    options = []
    if ANNOTATION_GRPC_MAX_MSG_SIZE in annotations:
        max_msg = int(annotations[ANNOTATION_GRPC_MAX_MSG_SIZE])
        logger.info("Setting grpc max message and receive length to %d", max_msg)
        # grpc.max_message_length is deprecated in newer grpc releases; kept
        # alongside the explicit send/receive options for compatibility.
        options.append(("grpc.max_message_length", max_msg))
        options.append(("grpc.max_send_message_length", max_msg))
        options.append(("grpc.max_receive_message_length", max_msg))
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=options)
    if trace_interceptor:
        # Imported lazily so grpc_opentracing is only needed when tracing is on.
        from grpc_opentracing.grpcext import intercept_server

        server = intercept_server(server, trace_interceptor)
    # Register the same servicer under every Seldon role so one microservice
    # can act as model, transformer, router, combiner, etc.
    prediction_pb2_grpc.add_GenericServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_ModelServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_TransformerServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_OutputTransformerServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_CombinerServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_RouterServicer_to_server(seldon_model, server)
    prediction_pb2_grpc.add_SeldonServicer_to_server(seldon_model, server)
    return server