Skip to content

Commit

Permalink
support paddle elementwise_floordiv
Browse files Browse the repository at this point in the history
  • Loading branch information
taixiurong committed Sep 15, 2022
1 parent 1b7237d commit 75d39a3
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/core/tests/frontend/paddle/op_fuzzy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,35 @@ static const std::vector<std::string> models{
std::string("dropout_upscale_in_train"),
std::string("elementwise_add1"),
std::string("elementwise_div1"),
std::string("elementwise_floordiv_int32_1"),
std::string("elementwise_floordiv_int64_1"),
std::string("elementwise_max1"),
std::string("elementwise_min1"),
std::string("elementwise_mul1"),
std::string("elementwise_pow1"),
std::string("elementwise_sub1"),
std::string("elementwise_add2"),
std::string("elementwise_div2"),
std::string("elementwise_floordiv_int32_2"),
std::string("elementwise_floordiv_int64_2"),
std::string("elementwise_max2"),
std::string("elementwise_min2"),
std::string("elementwise_mul2"),
std::string("elementwise_pow2"),
std::string("elementwise_sub2"),
std::string("elementwise_add3"),
std::string("elementwise_div3"),
std::string("elementwise_floordiv_int32_3"),
std::string("elementwise_floordiv_int64_3"),
std::string("elementwise_max3"),
std::string("elementwise_min3"),
std::string("elementwise_mul3"),
std::string("elementwise_pow3"),
std::string("elementwise_sub3"),
std::string("elementwise_add4"),
std::string("elementwise_div4"),
std::string("elementwise_floordiv_int32_4"),
std::string("elementwise_floordiv_int64_4"),
std::string("elementwise_max4"),
std::string("elementwise_min4"),
std::string("elementwise_mul4"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,27 @@ def elementwise_pow(name : str, x, y, axis, in_dtype):

return outs[0]

def elementwise_floordiv(name : str, x, y, axis, in_dtype):
import paddle
paddle.enable_static()

with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
node_x = paddle.static.data(name = 'x', shape = x.shape, dtype = in_dtype)
node_y = paddle.static.data(name = 'y', shape = y.shape, dtype = in_dtype)
out = paddle.fluid.layers.nn.elementwise_floordiv(node_x, node_y, axis=axis)

cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])

# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': x, 'y': y},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x', 'y'], fetchlist=[out], inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])

return outs[0]

def elementwise_ops(name : str, data_x, data_y, axis, in_dtype):
elementwise_add("elementwise_add" + name, data_x, data_y, axis, in_dtype)
elementwise_sub("elementwise_sub" + name, data_x, data_y, axis, in_dtype)
Expand Down Expand Up @@ -193,5 +214,33 @@ def main():
axis = 0
elementwise_ops("4", data_x, data_y, axis, in_dtype)

# test for elementwise_floordiv, support int and int64
floordiv_support_dtype = ['int64', 'int32']
in_dtype_int32 = 'int32'
data_x = np.array([-2, 0, 4])
data_y = np.array([1, 5, 2])
axis = -1
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv" + "_" + dtype + "_1",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

data_x = np.random.randint(1, 10, [2, 5, 3, 4])
data_y = np.random.randint(1, 5, [3, 4])
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv" + "_" + dtype + "_2",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

data_y = np.random.randint(1, 5, [5])
axis = 1
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv" + "_" + dtype + "_3",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

data_y = np.random.randint(1, 5, [2, 5, 3])
axis = 0
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv" + "_" + dtype + "_4",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions src/frontends/paddle/src/op/elementwise_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ NamedOutputs elementwise_greater_equal(const NodeContext& node_context) {
return elementwise_ops<default_opset::GreaterEqual>(node_context);
}

NamedOutputs elementwise_floordiv(const NodeContext& node_context) {
auto x = node_context.get_input("X");
auto y = node_context.get_input("Y");
auto axis = node_context.get_attribute<int>("axis");
return node_context.default_single_output_mapping({std::make_shared<default_opset::Divide>(x, y, ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, axis))}, {"Out"});
}

} // namespace op
} // namespace paddle
} // namespace frontend
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ OP_CONVERTER(dropout);
OP_CONVERTER(elementwise_add);
OP_CONVERTER(elementwise_div);
OP_CONVERTER(elementwise_equal);
OP_CONVERTER(elementwise_floordiv);
OP_CONVERTER(elementwise_greater_equal);
OP_CONVERTER(elementwise_max);
OP_CONVERTER(elementwise_min);
Expand Down Expand Up @@ -121,6 +122,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"dropout", op::dropout},
{"elementwise_add", op::elementwise_add},
{"elementwise_div", op::elementwise_div},
{"elementwise_floordiv", op::elementwise_floordiv},
{"elementwise_max", op::elementwise_max},
{"elementwise_min", op::elementwise_min},
{"elementwise_mul", op::elementwise_mul},
Expand Down

0 comments on commit 75d39a3

Please sign in to comment.