diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index d698cdcbd3d5..f6151690e7d7 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -12,8 +12,8 @@ namespace onnxruntime { namespace contrib { using ::ONNX_NAMESPACE::AttributeProto; -using ::ONNX_NAMESPACE::OPTIONAL; using ::ONNX_NAMESPACE::OpSchema; +using ::ONNX_NAMESPACE::OPTIONAL; void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp) @@ -41,7 +41,37 @@ Sample echo operator.)DOC"); "T", ONNX_NAMESPACE::OpSchema::all_tensor_types(), "Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference + if (!hasInputShape(ctx, 0)) + return; + + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + const ONNX_NAMESPACE::TensorProto* axis_initializer = ctx.getInputData(1); + if (!axis_initializer) + return; + const int axis = axis_initializer->int32_data()[0]; + if (axis > rank || axis < -rank - 1) { + fail_shape_inference("Input axis is invalid: ", axis); + } + int pos = axis >= 0 ? axis : rank + axis - 1; + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int i = 0; i < pos; ++i) { + output_shape.add_dim(); + *(output_shape.mutable_dim(i)) = input_shape.dim(i); + } + output_shape.add_dim(); + output_shape.mutable_dim(pos)->set_dim_value(1); + for (int i = pos + 1; i < rank + 1; ++i) { + output_shape.add_dim(); + *(output_shape.mutable_dim(i)) = input_shape.dim(i - 1); + } + updateOutputShape(ctx, 0, output_shape); + }) .SetDoc(R"DOC(ExpandDims echo operator.)DOC"); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); diff --git a/onnxruntime/test/contrib_ops/expand_dims_test.cc b/onnxruntime/test/contrib_ops/expand_dims_test.cc index df34a5eb8d30..d1239ad33292 100644 --- a/onnxruntime/test/contrib_ops/expand_dims_test.cc +++ b/onnxruntime/test/contrib_ops/expand_dims_test.cc @@ -9,7 +9,6 @@ namespace test { TEST(ContribOpTest, ExpandDims_0) { OpTester test("ExpandDims", 1, onnxruntime::kMSDomain); - test.AddShapeToTensorData(false); // TODO: re-enable shape inference test test.AddInput("X", {2, 3}, std::vector(6, 1.0f)); test.AddInput("axis", {}, {-1}); test.AddOutput("Y", {2, 3, 1}, std::vector(6, 1.0f)); @@ -18,7 +17,6 @@ TEST(ContribOpTest, ExpandDims_0) { TEST(ContribOpTest, ExpandDims_1) { OpTester test("ExpandDims", 1, onnxruntime::kMSDomain); - test.AddShapeToTensorData(false); // TODO: re-enable shape inference test test.AddInput("X", {2, 3}, std::vector(6, 1.0f)); test.AddInput("axis", {}, {1}); test.AddOutput("Y", {2, 1, 3}, std::vector(6, 1.0f));