Skip to content

Commit

Permalink
[Unittest] Test ncnn gather op and fix gather.cpp (open-mmlab#114)
Browse files Browse the repository at this point in the history
* add shape constantofshape unittest for ncnn

* fix lint

* standardize import

* fix lint

* reply for code review

* reply for code review

* fix lint

* remove some hardcode

* fix lint

* reply for code review

* test gather and fix gather cpp code

* fix yapf

* fix clang-format

* reply for code review

* reply for code review

* fix lint
  • Loading branch information
hanrui1sensetime authored Oct 9, 2021
1 parent f8a70f1 commit 21f2b04
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 105 deletions.
4 changes: 4 additions & 0 deletions backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3736,6 +3736,7 @@ int main(int argc, char** argv) {
float value = 0.f;
value = get_node_attr_f(node, "value", 0.f);
fprintf(pp, " 0=%f", value);

} else if (op == "Conv") {
const onnx::TensorProto& W = weights[node.input(1)];

Expand Down Expand Up @@ -3989,6 +3990,9 @@ int main(int argc, char** argv) {
int op_type = 2;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Gather") {
if (weights[node.input(1)].dims_size() > 1) {
fprintf(stderr, "Unsupported indice dims > 1");
}
int axis = get_node_attr_i(node, "axis", 1) - 1;
if (axis < 0) {
fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1);
Expand Down
123 changes: 19 additions & 104 deletions backend_ops/ncnn/ops/gather/gather.cpp
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "gather.h"

#include "../ncnn_ops_definer.h"
#include "assert.h"

namespace mmlab {
using namespace ncnn;
Expand All @@ -17,6 +18,11 @@ int Gather::load_param(const ParamDict &pd) {
return 0;
}

// Gather only supports 1-dim indices. Both the data and the indices carry an
// implicit batch dimension in ncnn, so with higher-dim indices the output
// shape would not match the onnx result. When the indices are 1-dim they are
// still 1-dim after the implicit batch is eliminated, leaving only the single
// implicit batch of the data, and the output shape matches the onnx result.
int Gather::forward(const std::vector<Mat> &bottom_blobs,
std::vector<Mat> &top_blobs, const Option &opt) const {
const Mat &bottom_blob = bottom_blobs[0];
Expand All @@ -26,7 +32,7 @@ int Gather::forward(const std::vector<Mat> &bottom_blobs,
size_t elemsize = bottom_blob.elemsize;
int positive_axis = axis < 0 ? dims + axis : axis;
Mat &top_blob = top_blobs[0];

assert(indices.dims == 1);
const float *indices_ptr = indices;

if (dims == 1 && indices_dims == 1) // positive_axis == 0
Expand All @@ -46,49 +52,6 @@ int Gather::forward(const std::vector<Mat> &bottom_blobs,
return 0;
}

if (dims == 1 && indices_dims == 2) // positive_axis == 0
{
int w = indices.w;
int h = indices.h;
top_blob.create(w, h, elemsize, opt.blob_allocator);
if (top_blob.empty()) {
return -100;
}
const float *ptr = bottom_blob;
float *outptr = top_blob;
for (int j = 0; j < h; j++) {
for (int i = 0; i < w; i++) {
int indice = (int)(indices_ptr[j * w + i] + 0.5);
outptr[j * w + i] = ptr[indice];
}
}
return 0;
}
if (dims == 1 && indices_dims == 3) // positive_axis == 0
{
int c = indices.c;
int w = indices.w;
int h = indices.h;
top_blob.create(c, w, h, elemsize, opt.blob_allocator);
if (top_blob.empty()) {
return -100;
}
const float *ptr = bottom_blob;

for (int page = 0; page < c; page++) {
indices_ptr = indices.channel(page);
float *outptr = top_blob.channel(page);
for (int j = 0; j < h; j++) {
for (int i = 0; i < w; i++) {
int indice = (int)(indices_ptr[j * w + i] + 0.5);
outptr[j * w + i] = ptr[indice];
}
}
}

return 0;
}

if (dims == 2 && positive_axis == 0 && indices_dims == 1) {
int w = bottom_blob.w;
int h = bottom_blob.h;
Expand Down Expand Up @@ -130,51 +93,6 @@ int Gather::forward(const std::vector<Mat> &bottom_blobs,
return 0;
}

if (dims == 2 && positive_axis == 0 && indices_dims == 2) {
int w = bottom_blob.w;
int h = bottom_blob.h;
top_blob.create(w, indices.w, indices.h, elemsize, opt.blob_allocator);

if (top_blob.empty()) {
return -100;
}
const float *ptr = bottom_blob;

for (int k = 0; k < indices.h; k++) {
float *outptr = top_blob.channel(k);
for (int i = 0; i < indices.w; i++) {
for (int j = 0; j < w; j++) {
int selected = (float)(indices_ptr[k * indices.w + i] + 0.5);
outptr[i * w + j] = ptr[selected * w + j];
}
}
}

return 0;
}

if (dims == 2 && positive_axis == 1 && indices_dims == 2) {
int w = bottom_blob.w;
int h = bottom_blob.h;
top_blob.create(h, indices.w, indices.h, elemsize, opt.blob_allocator);

if (top_blob.empty()) {
return -100;
}
const float *ptr = bottom_blob;
for (int k = 0; k < indices.h; k++) {
float *outptr = top_blob.channel(k);
for (int i = 0; i < indices.w; i++) {
for (int j = 0; j < h; j++) {
int selected = (int)(indices_ptr[k * indices.w + i] + 0.5);
outptr[i * h + j] = ptr[j * w + selected];
}
}
}

return 0;
}

if (dims == 3 && positive_axis == 0 && indices_dims == 1) {
int w = bottom_blob.w;
int h = bottom_blob.h;
Expand All @@ -198,14 +116,14 @@ int Gather::forward(const std::vector<Mat> &bottom_blobs,
int w = bottom_blob.w;
int h = bottom_blob.h;
int channels = bottom_blob.c;
top_blob.create(w, channels, indices.w, elemsize, opt.blob_allocator);
top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator);
#pragma omp parallel for num_threads(opt.num_threads)
// use parallel programming
for (int i = 0; i < indices.w; i++) {
int selected = (int)(indices_ptr[i] + 0.5);
for (int i = 0; i < channels; i++) {
float *outptr = top_blob.channel(i);
for (int j = 0; j < channels; j++) {
const float *ptr = bottom_blob.channel(j);
const float *ptr = bottom_blob.channel(i);
for (int j = 0; j < indices.w; j++) {
int selected = (int)(indices_ptr[j] + 0.5);
for (int k = 0; k < w; k++) {
outptr[j * w + k] = ptr[selected * w + k];
}
Expand All @@ -216,25 +134,22 @@ int Gather::forward(const std::vector<Mat> &bottom_blobs,
}

if (dims == 3 && positive_axis == 2 && indices_dims == 1) {
fprintf(stderr, "gather: dim = 3\n");
int w = bottom_blob.w;
int h = bottom_blob.h;
int channels = bottom_blob.c;
top_blob.create(h, channels, indices.w, elemsize, opt.blob_allocator);
top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator);
#pragma omp parallel for num_threads(opt.num_threads)
// use parallel programming
for (int i = 0; i < indices.w; i++) {
int selected = (int)(indices_ptr[i] + 0.5);
for (int i = 0; i < channels; i++) {
float *outptr = top_blob.channel(i);
for (int j = 0; j < channels; j++) {
const float *ptr = bottom_blob.channel(j);
for (int k = 0; k < h; k++) {
outptr[j * h + k] = ptr[k * w + selected];
const float *ptr = bottom_blob.channel(i);
for (int j = 0; j < h; j++) {
for (int k = 0; k < indices.w; k++) {
int selected = (int)(indices_ptr[k] + 0.5);
outptr[j * indices.w + k] = ptr[j * w + selected];
}
}
}
fprintf(stderr, "top_blob.size: (%d %d %d)\n", top_blob.c, top_blob.h,
top_blob.w);
return 0;
}

Expand Down
64 changes: 64 additions & 0 deletions tests/test_ops/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,70 @@ def test_constantofshape(backend,
assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('axis, data_dims, indice_dims', [(0, 1, 1), (0, 2, 1),
                                                          (1, 2, 1), (0, 3, 1),
                                                          (1, 3, 1),
                                                          (2, 3, 1)])
def test_gather(backend,
                axis,
                data_dims,
                indice_dims,
                input_names=['input', 'indices'],
                output_names=['output'],
                tolerate_small_mismatch=False,
                input_list=None,
                save_dir=None):
    """Compare the ncnn Gather op against onnxruntime on the same ONNX graph.

    A single-node ONNX ``Gather`` model is built, converted through
    ``backend.onnx2ncnn`` and run with both onnxruntime and the returned ncnn
    model; the two sets of outputs are compared with ``assert_allclose``.

    Args:
        backend: Test backend; must provide ``check_env``, ``backend_name``
            and ``onnx2ncnn``.
        axis: Gather axis expressed without the leading batch dimension
            (the ONNX node is created with ``axis + 1``).
        data_dims: Number of data dimensions, not counting the batch.
        indice_dims: Number of indices dimensions, not counting the batch.
        input_names: Names of the data and indices graph inputs. The default
            list is never mutated here.
        output_names: Names of the graph outputs. The default list is never
            mutated here.
        tolerate_small_mismatch: Forwarded to ``assert_allclose``.
        input_list: Optional ``[data, indice]`` pair, each with a leading
            batch dimension of size 1. Random inputs are generated when None.
        save_dir: Forwarded to ``backend.onnx2ncnn``.
    """
    backend.check_env()

    if input_list is None:
        # the real data dims is data_dims + 1
        data = torch.rand((8, 12, 17)[-data_dims:]).unsqueeze(0)
        indice = torch.randint(0, 8, (3, 4, 5)[-indice_dims:]).unsqueeze(0)
    else:
        data = input_list[0]
        indice = input_list[1]
    # Single-line messages: a backslash continuation inside the literal would
    # embed the source indentation into the assertion text.
    assert data.shape[0] == 1, \
        f'ncnn batch must be 1, but got {data.shape[0]}'
    assert indice.shape[0] == 1, \
        f'ncnn batch must be 1, but got {indice.shape[0]}'
    cfg = dict()
    register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)

    # data carries an explicit batch dimension at position 0, so the ONNX
    # axis is shifted by one relative to the batch-less ``axis``.
    gather_node = make_node('Gather', input_names, output_names, axis=axis + 1)
    gather_graph = make_graph([gather_node], 'gather_graph', [
        make_tensor_value_info(input_names[0], onnx.TensorProto.FLOAT, None),
        make_tensor_value_info(input_names[1], onnx.TensorProto.INT64, None)
    ], [make_tensor_value_info(output_names[0], onnx.TensorProto.FLOAT, None)])
    gather_model = make_model(gather_graph)

    ncnn_model = backend.onnx2ncnn(gather_model, 'gather', output_names,
                                   save_dir)

    # ncnn mat has implicit batch for mat, the ncnn_output is a mat,
    # so the ncnn_outputs has 2 dimensions, not 1.
    # Import the submodule explicitly (``import importlib`` alone does not
    # guarantee ``importlib.util`` is loaded) and check availability BEFORE
    # importing onnxruntime, otherwise the assertion could never fire.
    import importlib.util
    assert importlib.util.find_spec('onnxruntime') is not None, \
        'onnxruntime not installed.'
    import onnxruntime

    import numpy as np
    session = onnxruntime.InferenceSession(gather_model.SerializeToString())
    # onnxruntime receives the indices without the batch dimension
    # (``indice[0]``), while ncnn below receives them with it.
    model_outputs = session.run(
        output_names,
        dict(
            zip(input_names, [
                np.array(data, dtype=np.float32),
                np.array(indice[0], dtype=np.int64)
            ])))
    model_outputs = list(model_outputs)

    ncnn_outputs = ncnn_model(
        dict(zip(input_names, [data.float(), indice.float()])))
    ncnn_outputs = [ncnn_outputs[name] for name in output_names]
    assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('dim', [1, 2, 3])
def test_tensorslice(backend, dim, input_list=None, save_dir=None):
Expand Down
24 changes: 23 additions & 1 deletion tests/test_ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def run_and_validate(self,
output_names=None,
input_names=None,
save_dir=None):
if not save_dir:
if save_dir is None:
onnx_file_path = tempfile.NamedTemporaryFile().name
ncnn_param_path = tempfile.NamedTemporaryFile().name
ncnn_bin_path = tempfile.NamedTemporaryFile().name
Expand Down Expand Up @@ -233,3 +233,25 @@ def run_and_validate(self,
else:
assert_allclose(model_outputs, ncnn_outputs,
tolerate_small_mismatch)

def onnx2ncnn(self, model, model_name, output_names, save_dir=None):
    """Convert an ONNX model to ncnn and return a wrapper around the result.

    The model is saved as an ``.onnx`` file, the external onnx2ncnn
    executable is invoked on it to produce ``.param`` and ``.bin`` files,
    and those two files are loaded into an ``NCNNWrapper``.

    Args:
        model: ONNX model object passed to ``onnx.save_model``.
        model_name: Base file name used when ``save_dir`` is given.
        output_names: Output names forwarded to ``NCNNWrapper``.
        save_dir: Directory for the three files. When None, temporary
            file names are used instead.

    Returns:
        The ``NCNNWrapper`` built from the generated param/bin files.
    """
    extensions = ('.onnx', '.param', '.bin')
    if save_dir is not None:
        targets = [
            os.path.join(save_dir, model_name + ext) for ext in extensions
        ]
    else:
        # Only the generated names are kept; the temporary file objects
        # themselves are not retained.
        targets = [
            tempfile.NamedTemporaryFile(suffix=ext).name for ext in extensions
        ]
    onnx_file_path, ncnn_param_path, ncnn_bin_path = targets

    onnx.save_model(model, onnx_file_path)

    import mmdeploy.apis.ncnn as ncnn_apis
    converter = ncnn_apis.get_onnx2ncnn_path()
    command = [converter, onnx_file_path, ncnn_param_path, ncnn_bin_path]
    # NOTE(review): the converter's exit status is not checked here.
    subprocess.call(command)

    return ncnn_apis.NCNNWrapper(ncnn_param_path, ncnn_bin_path, output_names)

0 comments on commit 21f2b04

Please sign in to comment.