-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6d751ba
commit 963c758
Showing
5 changed files
with
252 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import logging | ||
import argparse | ||
import os | ||
import mxnet as mx | ||
from mxnet import gluon | ||
from mxnet.gluon.model_zoo import vision | ||
from gluoncv.data import imagenet | ||
|
||
# Two functions for reading data from record file or raw images | ||
def get_val_data(rec_val, | ||
batch_size, | ||
num_workers=4): | ||
rec_val = os.path.expanduser(rec_val) | ||
mean_rgb = [123.68, 116.779, 103.939] | ||
std_rgb = [58.393, 57.12, 57.375] | ||
def batch_fn(batch, ctx): | ||
data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) | ||
label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) | ||
return data, label | ||
|
||
val_data = mx.io.ImageRecordIter( | ||
path_imgrec = rec_val, | ||
preprocess_threads = num_workers, | ||
shuffle = True, | ||
batch_size = batch_size, | ||
resize = 256, | ||
data_shape = (3, 224, 224), | ||
mean_r = mean_rgb[0], | ||
mean_g = mean_rgb[1], | ||
mean_b = mean_rgb[2], | ||
std_r = std_rgb[0], | ||
std_g = std_rgb[1], | ||
std_b = std_rgb[2], | ||
) | ||
return val_data, batch_fn | ||
|
||
|
||
def evaluate(args, graph, lib, params, ctx): | ||
"""Evaluate on the validation set.""" | ||
import tvm | ||
from tvm.contrib import graph_runtime | ||
|
||
# tetup dataset. | ||
batch_size = args.batch_size | ||
val_data, batch_fn = get_val_data(args.rec_val, batch_size) | ||
# create runtime module | ||
m = graph_runtime.create(graph, lib, ctx) | ||
m.set_input(**params) | ||
oshape = (batch_size, args.num_classes) | ||
out_arr = tvm.nd.empty(oshape, "float32") | ||
# setup evaluaiton metric | ||
acc_top1 = mx.metric.Accuracy() | ||
acc_top5 = mx.metric.TopKAccuracy(5) | ||
val_data.reset() | ||
acc_top1.reset() | ||
acc_top5.reset() | ||
# Execute | ||
for i, batch in enumerate(val_data): | ||
data, label = batch_fn(batch, [mx.cpu(0)]) | ||
m.run(data=data[0].asnumpy()) | ||
m.get_output(0, out_arr) | ||
acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())]) | ||
acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())]) | ||
|
||
if args.log_interval and not (i + 1) % args.log_interval: | ||
_, top1 = acc_top1.get() | ||
_, top5 = acc_top5.get() | ||
nsamples = (i + 1) * batch_size | ||
logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5) | ||
logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5) | ||
|
||
|
||
def build_nnvm(args, gluon_model): | ||
"""Build with nnvm path""" | ||
import tvm | ||
import nnvm | ||
import nnvm.compiler | ||
net, params = nnvm.frontend.from_mxnet(gluon_model) | ||
data_shape = (args.batch_size, 3, 224, 224) | ||
shape_dict = {'data': data_shape} | ||
shape_dict.update({k: v.shape for k, v in params.items()}) | ||
dtype_dict = {"data": "float32"} | ||
target = args.target | ||
|
||
with nnvm.compiler.build_config(opt_level=3): | ||
graph, lib, params = nnvm.compiler.build( | ||
net, target, shape_dict, dtype_dict, params=params) | ||
ctx = tvm.nd.context(target, 0) | ||
return graph, lib,params, ctx | ||
|
||
|
||
def build_relay(args, gluon_model): | ||
"""Build with relay.""" | ||
import tvm | ||
from tvm import relay | ||
data_shape = (args.batch_size, 3, 224, 224) | ||
net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) | ||
target = args.target | ||
with relay.build_config(opt_level=3): | ||
graph, lib, params = relay.build( | ||
net, target, params=params) | ||
ctx = tvm.nd.context(target, 0) | ||
return graph, lib, params, ctx | ||
|
||
|
||
def build_quantize(args, gluon_model): | ||
print('build quantize') | ||
"""Build with relay.""" | ||
import tvm | ||
from tvm import relay | ||
from tvm.relay import quantize as qtz | ||
from tvm.relay import ir_pass | ||
data_shape = (args.batch_size, 3, 224, 224) | ||
net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) | ||
target = args.target | ||
|
||
with qtz.qconfig(skip_k_conv=2, global_scale=2.0): | ||
# print(net.astext()) | ||
|
||
graph = net | ||
# graph = ir_pass.infer_type(graph) | ||
# graph = ir_pass.simplify_inference(graph) | ||
# var_map = {arg.name_hint: arg for arg in graph.params} | ||
# const_map = {var_map[key]: tvm.relay.const(params[key]) for key in params} | ||
# graph = tvm.relay.bind(graph, const_map) | ||
# graph = ir_pass.fold_constant(graph) | ||
# print('after const folding\n') | ||
# print(graph.astext()) | ||
qgraph = qtz.annotate(graph, params) | ||
# print('after annotate\n') | ||
# print(qgraph.astext()) | ||
qgraph = qtz.calibrate(qgraph) | ||
# print('after calibrate\n') | ||
# print(qgraph.astext()) | ||
# qgraph = qtz.realize(qgraph) | ||
# print('after realize\n') | ||
# print(qgraph.astext()) | ||
|
||
with relay.build_config(opt_level=3): | ||
graph, lib, params = relay.build( | ||
qgraph, target, params=params) | ||
ctx = tvm.nd.context(target, 0) | ||
return graph, lib, params, ctx | ||
|
||
|
||
def main(args): | ||
gluon_model = vision.get_model(args.model, pretrained=True) | ||
if args.use_nnvm: | ||
graph, lib, params, ctx = build_nnvm(args, gluon_model) | ||
else: | ||
# graph, lib, params, ctx = build_relay(args, gluon_model) | ||
graph, lib, params, ctx = build_quantize(args, gluon_model) | ||
|
||
logging.info("Finish building model %s...", args.model) | ||
evaluate(args, graph, lib, params, ctx) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy") | ||
parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec", | ||
help="the validation data") | ||
parser.add_argument("--num-classes", type=int, default=1000, | ||
help="batch size") | ||
parser.add_argument("--model", type=str, default="resnet18_v1", | ||
help="Name of the model") | ||
parser.add_argument("--log-interval", type=int, default=100, | ||
help="log interval") | ||
parser.add_argument("--batch-size", type=int, default=1, | ||
help="batch size") | ||
parser.add_argument("--target", type=str, default="cuda", | ||
help="target option") | ||
parser.add_argument("--use-nnvm", action="store_true", | ||
help="Use legacy nnvm compiler") | ||
args = parser.parse_args() | ||
logging.basicConfig(level=logging.INFO) | ||
logging.info(args) | ||
main(args) |
Oops, something went wrong.