Skip to content

Commit

Permalink
enable flip and fix top_k_v2 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
xczhai committed May 16, 2023
1 parent 9b743a5 commit 7614cbb
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ paddlepaddle >= 2.1
fill_constant
fill_constant_batch_size_like
flatten_contiguous_range
flip
floor
gather
gather_nd
Expand Down
17 changes: 17 additions & 0 deletions src/frontends/paddle/src/op/flip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "reverse_op.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {
// Converter for the PaddlePaddle "flip" operator.
// flip has the same semantics as Paddle's legacy "reverse" operator
// (reverse tensor X along the axes given in the "axis" attribute), so the
// conversion is delegated to the shared reverse_op() helper from
// reverse_op.hpp.
NamedOutputs flip(const NodeContext& node) {
    return reverse_op(node);
}
}  // namespace op
}  // namespace paddle
}  // namespace frontend
}  // namespace ov
22 changes: 2 additions & 20 deletions src/frontends/paddle/src/op/reverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
#include "openvino/opsets/opset1.hpp"
#include "reverse_op.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {

using namespace default_opset;

// Converter for the PaddlePaddle "reverse" operator.
// The actual lowering (normalising negative axes and emitting an
// opset1::Reverse node) lives in the shared reverse_op() helper
// (reverse_op.hpp) so it can also be reused by the "flip" converter.
// The previous inline implementation is dead code superseded by the
// helper and is removed here.
NamedOutputs reverse(const NodeContext& node) {
    return reverse_op(node);
}
} // namespace op
} // namespace paddle
Expand Down
34 changes: 34 additions & 0 deletions src/frontends/paddle/src/op/reverse_op.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>
#include <cstdint>
#include <vector>

#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
#include "openvino/opsets/opset1.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {
namespace {
// Shared lowering for Paddle's "reverse" and "flip" operators, which have
// identical semantics: reverse tensor X along the axes listed in the
// "axis" attribute.
// NOTE(review): the anonymous namespace in a header gives every including
// TU its own internal copy of this function. That is tolerable for the two
// converter TUs that include it, but an `inline` function would avoid the
// duplication — confirm against the frontend's conventions.
NamedOutputs reverse_op(const NodeContext& node) {
    const auto x = node.get_input("X");
    // "axis" is a vector attribute and may contain negative indices.
    auto axis = node.get_attribute<std::vector<int32_t>>("axis");
    // opset1::Reverse (INDEX mode) does not accept negative axes, so
    // normalise them against the input rank; the rank must be static here.
    const auto dims = static_cast<int32_t>(x.get_partial_shape().rank().get_length());
    std::for_each(axis.begin(), axis.end(), [&dims](int32_t& value) {
        if (value < 0) {
            value += dims;
        }
    });

    // Use ov::element::i32 (not the deprecated ngraph:: alias) for
    // consistency with the surrounding ov:: usage.
    const auto axis_node = std::make_shared<default_opset::Constant>(ov::element::i32, Shape{axis.size()}, axis);
    const auto reverse_node = std::make_shared<ov::opset1::Reverse>(x, axis_node, ov::opset1::Reverse::Mode::INDEX);
    return node.default_single_output_mapping({reverse_node}, {"Out"});
}
}  // namespace
}  // namespace op
}  // namespace paddle
}  // namespace frontend
}  // namespace ov
2 changes: 1 addition & 1 deletion src/frontends/paddle/src/op/top_k_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ NamedOutputs top_k_v2(const NodeContext& node) {
std::string sort_type = sorted ? "value" : "none";
std::string mode = largest ? "max" : "min";

auto node_topk = std::make_shared<default_opset::TopK>(x, k_expected_node, axis, mode, sort_type);
auto node_topk = std::make_shared<default_opset::TopK>(x, k_expected_node, axis, mode, sort_type, element::i64);

NamedOutputs named_outputs;
named_outputs["Out"] = OutputVector{node_topk->output(0)};
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ OP_CONVERTER(fill_any_like);
OP_CONVERTER(fill_constant_batch_size_like);
OP_CONVERTER(fill_constant);
OP_CONVERTER(flatten_contiguous_range);
OP_CONVERTER(flip);
OP_CONVERTER(floor);
OP_CONVERTER(gather);
OP_CONVERTER(gather_nd);
Expand Down Expand Up @@ -160,6 +161,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"fill_any_like", op::fill_any_like},
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
{"fill_constant", op::fill_constant},
{"flip", op::flip},
{"flatten_contiguous_range", op::flatten_contiguous_range},
{"floor", op::floor},
{"gather", op::gather},
Expand Down
8 changes: 8 additions & 0 deletions src/frontends/paddle/tests/op_fuzzy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ static const std::vector<std::string> models{
std::string("fill_constant_shape_tensor"),
std::string("fill_constant_shape_tensor_list"),
std::string("flatten_contiguous_range_test1"),
std::string("flip_static_1"),
std::string("flip_static_2"),
std::string("flip_static_3"),
std::string("flip_static_4"),
std::string("flip_dynamic_1"),
std::string("flip_dynamic_2"),
std::string("flip_dynamic_3"),
std::string("flip_dynamic_4"),
std::string("floor_float32"),
std::string("floor_mod1"),
std::string("floor_mod2"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#
# flip paddle model generator
#
import numpy as np
from save_model import saveModel
import sys
import paddle

def flip(name: str, x, axis, use_static=True, dtype="float32"):
    """Build, run and save a single-op paddle model that flips `x`.

    Args:
        name: model name, forwarded to saveModel as the save target stem.
        x: numpy input tensor fed to the executor.
        axis: int or list of ints passed to paddle.flip (negatives allowed).
        use_static: if True the placeholder uses x's real static shape;
            otherwise a partially dynamic shape [1, 1, -1, -1] is used.
            NOTE(review): the dynamic branch therefore assumes a 4-D
            NCHW-like input with batch=channel=1 — confirm against callers.
        dtype: placeholder dtype; presumably must match x.dtype — verify.

    Returns:
        The first fetched output (numpy array) of the executor run.
    """
    paddle.enable_static()

    with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
        if use_static:
            # Static variant: placeholder shape taken directly from the input.
            node_x = paddle.static.data(name='x', shape=x.shape, dtype=dtype)
        else:
            # Dynamic variant: H/W dimensions left unknown (-1).
            node_x = paddle.fluid.data(name='x', shape=[1, 1, -1, -1], dtype=dtype)
        out = paddle.flip(node_x, axis)

        cpu = paddle.static.cpu_places(1)
        exe = paddle.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(paddle.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        # Persist model + reference input/output for the C++ fuzzy op tests.
        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]

def main():
    """Generate the flip test models consumed by the C++ op_fuzzy tests."""
    # Static-shape models: 1-D and 2-D inputs with single, multiple and
    # negative axes, covering both int32 and float32 dtypes.
    data1 = np.array([0, 1, 2], dtype='int32')
    flip("flip_static_1", data1, 0, True, 'int32')

    data2 = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype='float32')
    flip("flip_static_2", data2, 1, True, 'float32')

    data3 = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype='float32')
    flip("flip_static_3", data3, [0, 1], True, 'float32')

    data4 = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype='int32')
    flip("flip_static_4", data4, -1, True, 'int32')

    # Dynamic-shape (NCHW) models. For the integer cases use randint:
    # randn().astype('int32') truncates most samples to 0 (|randn| < 1
    # about 68% of the time), which makes the flip reference data nearly
    # degenerate.
    data5 = np.random.randint(-128, 128, size=(1, 1, 32, 32)).astype('int32')
    flip("flip_dynamic_1", data5, [2], False, dtype='int32')

    data6 = np.random.randn(1, 1, 64, 64).astype('float32')
    flip("flip_dynamic_2", data6, [3], False, dtype='float32')

    data7 = np.random.randn(1, 1, 112, 112).astype('float32')
    flip("flip_dynamic_3", data7, [2, 3], False, dtype='float32')

    data8 = np.random.randint(-128, 128, size=(1, 1, 224, 224)).astype('int32')
    flip("flip_dynamic_4", data8, [-2, -1], False, dtype='int32')

if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def reverse(name: str, x, axis, use_static=True, dtype="float32"):
return outs[0]

def main():
data1 = np.array([0,2], dtype='int64')
reverse("reverse_static_1", data1, 0, True, 'int64')
data1 = np.array([0,2], dtype='int32')
reverse("reverse_static_1", data1, 0, True, 'int32')

data2 = np.array([[0,1,2], [3,4,5], [6,7,8]], dtype='float32')
reverse("reverse_static_2", data2, 1, True, 'float32')
Expand Down

0 comments on commit 7614cbb

Please sign in to comment.