forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 2
/
inference.py
44 lines (39 loc) · 1.64 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import json
import mxnet as mx
# Please make sure to import neomxnet
import neomxnet # noqa: F401
import io
import os
import logging
# Device context shared by model_fn and transform_fn: the module is created on
# it and every input batch is moved onto it before prediction.
# Change the context to mx.gpu() if deploying to a GPU endpoint
ctx = mx.cpu()
def model_fn(model_dir):
    """Load the compiled MXNet checkpoint and return a Module ready to predict.

    Args:
        model_dir: Directory holding the model artifacts. The checkpoint is
            loaded from files with the prefix ``compiled`` at epoch 0.

    Returns:
        An ``mx.mod.Module`` created on ``ctx``, bound for inference with a
        single input named ``data`` of shape (1, 3, 224, 224), with the
        checkpoint parameters set and one warm-up prediction already run.
    """
    logging.info('Invoking user-defined model_fn')

    # Single definition of the input shape used for both binding and warm-up.
    input_shape = (1, 3, 224, 224)

    # The compiled model artifacts are saved with the prefix 'compiled'
    checkpoint_prefix = os.path.join(model_dir, 'compiled')
    sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, 0)

    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # The result of bind() is not needed, so it is not stored. The module is
    # created with label_names=None, so no label shapes are passed explicitly
    # instead of reading the module's private _label_shapes attribute.
    mod.bind(
        for_training=False,
        data_shapes=[('data', input_shape)],
        label_shapes=None,
    )
    mod.set_params(arg_params, aux_params, allow_missing=True)

    # Run warm-up inference on empty data during model load (required for GPU)
    warmup_batch = mx.nd.empty(input_shape, ctx=ctx)
    mod.predict(warmup_batch)
    return mod
def transform_fn(mod, data, input_content_type, output_content_type):
    """Pre-process a JSON request, run inference, and return a JSON response.

    Args:
        mod: The module returned by ``model_fn``; only ``mod.predict`` is used.
        data: Request payload, parsed with ``json.loads`` into a nested list
            that is converted to an NDArray. Presumably a height x width x
            channel image array -- confirm against the client code.
        input_content_type: Content type of the request. Not inspected here.
        output_content_type: Requested response content type. Only
            ``'application/json'`` is supported.

    Returns:
        A ``(body, content_type)`` tuple where ``body`` is the JSON encoding
        of the first row of the prediction output and ``content_type`` is
        ``output_content_type``.

    Raises:
        ValueError: If ``output_content_type`` is not ``'application/json'``.
    """
    logging.info('Invoking user-defined transform_fn')

    # Fail loudly on unsupported response types instead of falling through
    # without a usable result.
    if output_content_type != 'application/json':
        raise ValueError(
            'Unsupported output content type: {!r} '
            '(only application/json is supported)'.format(output_content_type)
        )

    # pre-processing
    decoded = json.loads(data)
    mx_ndarray = mx.nd.array(decoded)
    resized = mx.image.imresize(mx_ndarray, 224, 224)
    # Move the last axis to the front (assumed HWC -> CHW), then add a
    # leading batch axis so the input matches the bound (1, 3, 224, 224) shape.
    transposed = resized.transpose((2, 0, 1))
    batchified = transposed.expand_dims(axis=0)
    processed_input = batchified.as_in_context(ctx)

    # prediction/inference
    prediction_result = mod.predict(processed_input)

    # post-processing: the batch holds one item, so return only its row
    prediction = prediction_result.asnumpy().tolist()
    prediction_json = json.dumps(prediction[0])
    return prediction_json, output_content_type