Skip to content

Commit

Permalink
Merge pull request #21 from Microsoft/duli/shape_inference
Browse files Browse the repository at this point in the history
Adding shape inference for Op expand_dims
  • Loading branch information
duli2012 authored Nov 27, 2018
2 parents 408fd21 + 725a262 commit 9c40206
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
34 changes: 32 additions & 2 deletions onnxruntime/contrib_ops/contrib_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::OPTIONAL;
using ::ONNX_NAMESPACE::OpSchema;
using ::ONNX_NAMESPACE::OPTIONAL;

void RegisterContribSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
Expand Down Expand Up @@ -41,7 +41,37 @@ Sample echo operator.)DOC");
"T",
ONNX_NAMESPACE::OpSchema::all_tensor_types(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);

// Shape inference
if (!hasInputShape(ctx, 0))
return;

auto& input_shape = getInputShape(ctx, 0);
const int rank = input_shape.dim_size();
const ONNX_NAMESPACE::TensorProto* axis_initializer = ctx.getInputData(1);
if (!axis_initializer)
return;
const int axis = axis_initializer->int32_data()[0];
if (axis > rank || axis < -rank - 1) {
fail_shape_inference("Input axis is invalid: ", axis);
}
int pos = axis >= 0 ? axis : rank + axis - 1;
ONNX_NAMESPACE::TensorShapeProto output_shape;
for (int i = 0; i < pos; ++i) {
output_shape.add_dim();
*(output_shape.mutable_dim(i)) = input_shape.dim(i);
}
output_shape.add_dim();
output_shape.mutable_dim(pos)->set_dim_value(1);
for (int i = pos + 1; i < rank + 1; ++i) {
output_shape.add_dim();
*(output_shape.mutable_dim(i)) = input_shape.dim(i - 1);
}
updateOutputShape(ctx, 0, output_shape);
})
.SetDoc(R"DOC(ExpandDims echo operator.)DOC");

ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/test/contrib_ops/expand_dims_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace test {

TEST(ContribOpTest, ExpandDims_0) {
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
test.AddInput<int32_t>("axis", {}, {-1});
test.AddOutput<float>("Y", {2, 3, 1}, std::vector<float>(6, 1.0f));
Expand All @@ -18,7 +17,6 @@ TEST(ContribOpTest, ExpandDims_0) {

TEST(ContribOpTest, ExpandDims_1) {
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
test.AddInput<int32_t>("axis", {}, {1});
test.AddOutput<float>("Y", {2, 1, 3}, std::vector<float>(6, 1.0f));
Expand Down

0 comments on commit 9c40206

Please sign in to comment.