-
Notifications
You must be signed in to change notification settings - Fork 81
/
transformer.py
284 lines (227 loc) · 11 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality for the Transformer class,
which represents the execution workflow for handling inference
requests.
"""
from __future__ import absolute_import
import importlib
import traceback
try:
from inspect import signature # pylint: disable=ungrouped-imports
except ImportError:
# for Python2.7
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "inspect2"])
from inspect2 import signature
try:
from importlib.util import find_spec # pylint: disable=ungrouped-imports
except ImportError:
import imp # noqa: F401
def find_spec(module_name):
"""Function that searches for a module.
Args:
module_name: The name of the module to search for.
Returns:
bool: Whether the module was found.
"""
try:
imp.find_module(module_name)
return True
except ImportError:
return None
from six.moves import http_client
from sagemaker_inference import content_types, environment, utils
from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
class Transformer(object):
"""Represents the execution workflow for handling inference requests
sent to the model server.
"""
def __init__(self, default_inference_handler=None):
"""Initialize a ``Transformer``.
Args:
default_inference_handler (DefaultInferenceHandler): default implementation of
inference handlers to use in absence of expected serving functions within
the user module. Defaults to ``DefaultInferenceHandler``.
"""
self._default_inference_handler = default_inference_handler or DefaultInferenceHandler()
self._initialized = False
self._environment = None
self._model = None
self._pre_model_fn = None
self._model_warmup_fn = None
self._model_fn = None
self._transform_fn = None
self._input_fn = None
self._predict_fn = None
self._output_fn = None
self._context = None
@staticmethod
def handle_error(context, inference_exception, trace):
"""Set context appropriately for error response.
Args:
context (mms.context.Context): The inference context.
inference_exception (sagemaker_inference.errors.BaseInferenceToolkitError): An exception
raised during inference, with information for the error response.
trace (traceback): The stacktrace of the error.
Returns:
str: The error message and stacktrace from the exception.
"""
context.set_response_status(
code=inference_exception.status_code,
phrase=utils.remove_crlf(inference_exception.phrase),
)
return ["{}\n{}".format(inference_exception.message, trace)]
def transform(self, data, context):
"""Take a request with input data, deserialize it, make a prediction, and return a
serialized response.
Args:
data (obj): the request data.
context (obj): metadata on the incoming request data.
Returns:
list[obj]: The serialized prediction result wrapped in a list if
inference is successful. Otherwise returns an error message
with the context set appropriately.
"""
try:
properties = context.system_properties
model_dir = properties.get("model_dir")
self.validate_and_initialize(model_dir=model_dir, context=context)
response_list = []
for i in range(len(data)):
input_data = data[i].get("body")
request_processor = context.request_processor[0]
request_property = request_processor.get_request_properties()
content_type = utils.retrieve_content_type_header(request_property)
accept = request_property.get("Accept") or request_property.get("accept")
if not accept or accept == content_types.ANY:
accept = self._environment.default_accept
if content_type in content_types.UTF8_TYPES:
input_data = input_data.decode("utf-8")
result = self._run_handler_function(
self._transform_fn, *(self._model, input_data, content_type, accept)
)
response = result
response_content_type = accept
if isinstance(result, tuple):
# handles tuple for backwards compatibility
response = result[0]
response_content_type = result[1]
context.set_response_content_type(0, response_content_type)
response_list.append(response)
return response_list
except Exception as e: # pylint: disable=broad-except
trace = traceback.format_exc()
if isinstance(e, BaseInferenceToolkitError):
return self.handle_error(context, e, trace)
else:
return self.handle_error(
context,
GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)),
trace,
)
def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
"""Validates the user module against the SageMaker inference contract.
Load the model as defined by the ``model_fn`` to prepare handling predictions.
"""
if not self._initialized:
self._context = context
self._environment = environment.Environment()
self._validate_user_module_and_set_functions()
if self._pre_model_fn is not None:
self._run_handler_function(self._pre_model_fn, *(model_dir,))
self._model = self._run_handler_function(self._model_fn, *(model_dir,))
if self._model_warmup_fn is not None:
self._run_handler_function(self._model_warmup_fn, *(model_dir, self._model))
self._initialized = True
def _validate_user_module_and_set_functions(self):
"""Retrieves and validates the inference handlers provided within the user module.
Default implementations of the inference handlers are utilized in
place of missing functions defined in the user module.
"""
user_module_name = self._environment.module_name
self._pre_model_fn = getattr(self._default_inference_handler, "default_pre_model_fn", None)
self._model_warmup_fn = getattr(
self._default_inference_handler, "default_model_warmup_fn", None
)
if find_spec(user_module_name) is not None:
user_module = importlib.import_module(user_module_name)
self._model_fn = getattr(
user_module, "model_fn", self._default_inference_handler.default_model_fn
)
transform_fn = getattr(user_module, "transform_fn", None)
input_fn = getattr(user_module, "input_fn", None)
predict_fn = getattr(user_module, "predict_fn", None)
output_fn = getattr(user_module, "output_fn", None)
pre_model_fn = getattr(user_module, "pre_model_fn", None)
model_warmup_fn = getattr(user_module, "model_warmup_fn", None)
if transform_fn and (input_fn or predict_fn or output_fn):
raise ValueError(
"Cannot use transform_fn implementation in conjunction with "
"input_fn, predict_fn, and/or output_fn implementation"
)
self._transform_fn = transform_fn or self._default_transform_fn
self._input_fn = input_fn or self._default_inference_handler.default_input_fn
self._predict_fn = predict_fn or self._default_inference_handler.default_predict_fn
self._output_fn = output_fn or self._default_inference_handler.default_output_fn
if pre_model_fn is not None:
self._pre_model_fn = pre_model_fn
if model_warmup_fn is not None:
self._model_warmup_fn = model_warmup_fn
else:
self._model_fn = self._default_inference_handler.default_model_fn
self._input_fn = self._default_inference_handler.default_input_fn
self._predict_fn = self._default_inference_handler.default_predict_fn
self._output_fn = self._default_inference_handler.default_output_fn
self._transform_fn = self._default_transform_fn
def _default_transform_fn(self, model, input_data, content_type, accept, context=None):
# pylint: disable=unused-argument
"""Make predictions against the model and return a serialized response.
This serves as the default implementation of transform_fn, used when the
user has not provided an implementation.
Args:
model (obj): model loaded by model_fn.
input_data (obj): the request data.
content_type (str): the request content type.
accept (str): accept header expected by the client.
context (obj): the request context (default: None).
Returns:
obj: the serialized prediction result or a tuple of the form
(response_data, content_type)
"""
data = self._run_handler_function(self._input_fn, *(input_data, content_type))
prediction = self._run_handler_function(self._predict_fn, *(data, model))
result = self._run_handler_function(self._output_fn, *(prediction, accept))
return result
def _run_handler_function(self, func, *argv):
"""Helper to call the handler function which covers 2 cases:
1. the handle function takes context
2. the handle function does not take context
"""
num_func_input = len(signature(func).parameters)
if num_func_input == len(argv):
# function does not take context
result = func(*argv)
elif num_func_input == len(argv) + 1:
# function takes context
argv_context = argv + (self._context,)
result = func(*argv_context)
else:
raise TypeError(
"{} takes {} arguments but {} were given.".format(
func.__name__, num_func_input, len(argv)
)
)
return result