diff --git a/src/core/tests/frontend/paddle/op_fuzzy.cpp b/src/core/tests/frontend/paddle/op_fuzzy.cpp index 9ae8d39cd9cdef..b8a226e2cc21a9 100644 --- a/src/core/tests/frontend/paddle/op_fuzzy.cpp +++ b/src/core/tests/frontend/paddle/op_fuzzy.cpp @@ -442,6 +442,12 @@ static const std::vector models{ std::string("where_1"), std::string("where_2"), std::string("where_3"), + std::string("where_index_1"), + std::string("where_index_2"), + std::string("where_index_3"), + std::string("where_index_4"), + std::string("where_index_5"), + std::string("where_index_6"), // Temporily disable them until root caused to secure CI stable. // CVS-66703 to track this. // std::string("yolo_box_clip_box"), diff --git a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where_index.py b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where_index.py new file mode 100644 index 00000000000000..ea2d5d7c995cbc --- /dev/null +++ b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where_index.py @@ -0,0 +1,70 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# +# where paddle model generator +# +import numpy as np +from save_model import saveModel +import sys +import paddle + +paddle.enable_static() + + +def where_index(name: str, x, force_boolean=False): + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): + node_x = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype) + if force_boolean: + node_x_bl = paddle.fluid.layers.cast(node_x, "bool") + out = paddle.nonzero(node_x_bl) + else: + out = paddle.nonzero(node_x) + + cpu = paddle.static.cpu_places(1) + exe = paddle.static.Executor(cpu[0]) + + # startup program will call initializer to initialize the parameters. + exe.run(paddle.static.default_startup_program()) + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[ + x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + + +def main(): + # case of int32 + datatype = "int32" + condition = np.random.randint(0, 5, size=[5, 8, 2], dtype=datatype) + paddle_out = where_index("where_index_1", condition) + + # case of float32 + datatype = "float32" + condition = (np.random.randint( + 0, 5, size=[8, 3, 2]) * 1.1).astype(datatype) + paddle_out = where_index("where_index_2", condition) + + # case of dimension 4 + condition = (np.random.randint( + 0, 5, size=[8, 3, 2, 6]) * 1.1).astype(datatype) + paddle_out = where_index("where_index_3", condition) + + # case of dimension 5 + condition = (np.random.randint( + 0, 5, size=[4, 6, 8, 2, 5]) * 1.1).astype(datatype) + paddle_out = where_index("where_index_4", condition) + + # case of rank 1 + condition = np.ones(10).astype(datatype) + paddle_out = where_index("where_index_5", condition, force_boolean=True) + + # case of rank 1 and boolean zeros + condition = np.array([1, 0, 1]).astype(datatype) + paddle_out = where_index("where_index_6", condition, force_boolean=True) + + +if __name__ == "__main__": + main() diff --git a/src/frontends/paddle/src/op/where_index.cpp b/src/frontends/paddle/src/op/where_index.cpp new file mode 100644 index 00000000000000..afe248afca6060 --- /dev/null +++ b/src/frontends/paddle/src/op/where_index.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "default_opset.hpp" +#include "openvino/frontend/paddle/node_context.hpp" + +namespace ov { +namespace frontend { +namespace paddle { +namespace op { +NamedOutputs where_index(const NodeContext& node) { + const auto condition = node.get_input("Condition"); + const auto perm = default_opset::Constant::create(element::i64, Shape{2}, {1, 0}); + const auto out = std::make_shared(condition, element::i64); + return node.default_single_output_mapping({std::make_shared(out, perm)}, {"Out"}); +} + +} // namespace op +} // namespace paddle +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/paddle/src/op_table.cpp b/src/frontends/paddle/src/op_table.cpp index 1aedc80f90aed5..97603c1ddcf91c 100644 --- a/src/frontends/paddle/src/op_table.cpp +++ b/src/frontends/paddle/src/op_table.cpp @@ -97,6 +97,7 @@ OP_CONVERTER(transpose2); OP_CONVERTER(trilinear_interp_v2); OP_CONVERTER(unsqueeze); OP_CONVERTER(where); +OP_CONVERTER(where_index); OP_CONVERTER(yolo_box); OP_CONVERTER(generate_proposals_v2); } // namespace op @@ -198,6 +199,7 @@ std::map get_supported_ops() { {"trilinear_interp_v2", op::trilinear_interp_v2}, {"unsqueeze2", op::unsqueeze}, {"where", op::where}, + {"where_index", op::where_index}, {"yolo_box", op::yolo_box}, {"generate_proposals_v2", op::generate_proposals_v2}}; };