Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

SSD performance optimization and benchmark script #10483

Merged
merged 7 commits into from
May 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions example/ssd/benchmark_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import os
import sys
import argparse
import importlib
import mxnet as mx
import time
import logging

from symbol.symbol_factory import get_symbol
from symbol.symbol_factory import get_symbol_train
from symbol import symbol_builder


parser = argparse.ArgumentParser(description='MxNet SSD benchmark')
parser.add_argument('--network', '-n', type=str, default='vgg16_reduced')
parser.add_argument('--batch_size', '-b', type=int, default=0)
parser.add_argument('--shape', '-w', type=int, default=300)
parser.add_argument('--class_num', '-class', type=int, default=20)


def get_data_shapes(batch_size):
image_shape = (3, 300, 300)
return [('data', (batch_size,)+image_shape)]

def get_data(batch_size):
data_shapes = get_data_shapes(batch_size)
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.cpu()) for _, shape in data_shapes]
batch = mx.io.DataBatch(data, [])
return batch


if __name__ == '__main__':
args = parser.parse_args()
network = args.network
image_shape = args.shape
num_classes = args.class_num
b = args.batch_size
supported_image_shapes = [300, 512]
supported_networks = ['vgg16_reduced', 'inceptionv3', 'resnet50']

if network not in supported_networks:
raise Exception(network + " is not supported")

if image_shape not in supported_image_shapes:
raise Exception("Image shape should be either 300*300 or 512*512!")

if b == 0:
batch_sizes = [1, 2, 4, 8, 16, 32]
else:
batch_sizes = [b]

data_shape = (3, image_shape, image_shape)
net = get_symbol(network, data_shape[1], num_classes=num_classes,
nms_thresh=0.4, force_suppress=True)

num_batches = 100
dry_run = 5 # use 5 iterations to warm up

for bs in batch_sizes:
batch = get_data(bs)
mod = mx.mod.Module(net, label_names=None, context=mx.cpu())
mod.bind(for_training = False,
inputs_need_grad = False,
data_shapes = get_data_shapes(bs))
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try load some pre-trained models to test the real perf

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Do you know where I can find pre-trained parameters for SSD? I didn't find them in data.mxnet.io.


# get data
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.cpu()) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])

for i in range(dry_run + num_batches):
if i == dry_run:
tic = time.time()
mod.forward(batch, is_train=False)
for output in mod.get_outputs():
output.wait_to_read()

avg_time = (time.time() - tic) / num_batches
fps = bs / avg_time
print("SSD-" + network + " with " + str(num_classes) + " classes and shape " + str(data_shape))
print("batchsize=" + str(bs) + " " + str(1000*avg_time) + " ms")
print("batchsize=" + str(bs) + " " + str(fps) + " imgs/s")
41 changes: 31 additions & 10 deletions src/operator/contrib/multibox_detection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,16 @@ inline void MultiBoxDetectionForward(const Tensor<cpu, 3, DType> &out,
const int num_anchors = cls_prob.size(2);
const int num_batches = cls_prob.size(0);
const DType *p_anchor = anchors.dptr_;

const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
std::vector<DType> outputs;
outputs.reserve(num_anchors * 6);
for (int nbatch = 0; nbatch < num_batches; ++nbatch) {
const DType *p_cls_prob = cls_prob.dptr_ + nbatch * num_classes * num_anchors;
const DType *p_loc_pred = loc_pred.dptr_ + nbatch * num_anchors * 4;
DType *p_out = out.dptr_ + nbatch * num_anchors * 6;
int valid_count = 0;

#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < num_anchors; ++i) {
// find the predicted class id and probability
DType score = -1;
Expand All @@ -112,20 +117,33 @@ inline void MultiBoxDetectionForward(const Tensor<cpu, 3, DType> &out,
id = j;
}
}

if (id > 0 && score < threshold) {
id = 0;
}
if (id > 0) {
// [id, prob, xmin, ymin, xmax, ymax]
p_out[valid_count * 6] = id - 1; // remove background, restore original id
p_out[valid_count * 6 + 1] = (id == 0 ? DType(-1) : score);
int offset = i * 4;
TransformLocations(p_out + valid_count * 6 + 2, p_anchor + offset,
p_loc_pred + offset, clip, variances[0], variances[1],
variances[2], variances[3]);

// [id, prob, xmin, ymin, xmax, ymax]
outputs[i * 6] = id - 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do any tests exist currently for this op ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess no

outputs[i * 6 + 1] = score;
int offset = i * 4;
TransformLocations(outputs.data() + i * 6 + 2, p_anchor + offset, p_loc_pred + offset, clip,
variances[0], variances[1], variances[2], variances[3]);
}

int valid_count = 0;
for (int i = 0; i < num_anchors; ++i) {
int offset1 = valid_count * 6;
int offset2 = i * 6;
if (outputs[offset2] >= 0) {
p_out[offset1] = outputs[offset2];
p_out[offset1 + 1] = outputs[offset2 + 1];
p_out[offset1 + 2] = outputs[offset2 + 2];
p_out[offset1 + 3] = outputs[offset2 + 3];
p_out[offset1 + 4] = outputs[offset2 + 4];
p_out[offset1 + 5] = outputs[offset2 + 5];
++valid_count;
}
} // end iter num_anchors
}

if (valid_count < 1 || nms_threshold <= 0 || nms_threshold > 1) continue;

Expand All @@ -138,6 +156,7 @@ inline void MultiBoxDetectionForward(const Tensor<cpu, 3, DType> &out,
sorter.push_back(SortElemDescend<DType>(p_out[i * 6 + 1], i));
}
std::stable_sort(sorter.begin(), sorter.end());

// re-order output
DType *ptemp = temp_space.dptr_ + nbatch * num_anchors * 6;
int nkeep = static_cast<int>(sorter.size());
Expand All @@ -153,7 +172,9 @@ inline void MultiBoxDetectionForward(const Tensor<cpu, 3, DType> &out,
p_out[i * 6 + j] = ptemp[sorter[i].index * 6 + j];
}
}

// apply nms
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < nkeep; ++i) {
int offset_i = i * 6;
if (p_out[offset_i] < 0) continue; // skip eliminated
Expand Down