Skip to content

Commit

Permalink
add packedfunc test (PaddlePaddle#139)
Browse files Browse the repository at this point in the history
* add C++ packedfunc test
  • Loading branch information
fc500110 authored Jul 31, 2020
1 parent 5238ff1 commit 397c9d6
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 49 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
option(WITH_TESTING "Compile with Unittests" ON)
option(WITH_TESTS "Compile with tests" ON)

if (WITH_TESTING)
add_definitions(-DCINN_WITH_TEST)
endif()

# include the customized configures
include(${CMAKE_BINARY_DIR}/config.cmake)

Expand Down
44 changes: 0 additions & 44 deletions cinn/pybind/bind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,6 @@ void DefineExprNode(py::module *m, std::string_view node_name) {
.def("node_type", &ExprNodeT::node_type);
}

template <typename T>
void DefineExprNode1(py::module *m, T *node, std::string_view node_name) {
using ExprNodeT = ExprNode<typename std::decay_t<decltype(*node)>::type>;
std::string prefix{"ExprNode"};
std::string name = prefix + std::string(node_name);
py::class_<ExprNodeT, ir::IrNode> expr_node(*m, name.c_str());

expr_node.def(py::init<>())
.def(py::init<Type>())
.def(py::init<int>())
.def("accept", &ExprNodeT::Accept)
.def("operands_mutable", py::overload_cast<>(&ExprNodeT::operands))
.def("operands_const", py::overload_cast<>(&ExprNodeT::operands, py::const_))
.def("operand_mutable", py::overload_cast<int>(&ExprNodeT::operand), py::return_value_policy::reference)
.def("operand_const", py::overload_cast<int>(&ExprNodeT::operand, py::const_), py::return_value_policy::reference)
.def("copy", &ExprNodeT::Copy)
.def("node_type", &ExprNodeT::node_type);
}

template <typename NodeType>
void DefineBinaryOpNode(py::module *m, std::string_view node_name) {
DefineExprNode<NodeType>(m, node_name);
Expand All @@ -148,31 +129,6 @@ void DefineBinaryOpNode(py::module *m, std::string_view node_name) {
.def("expr_fields_const", py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_));
}

template <typename T>
void DefineBinaryOpNode1(py::module *m, T *node, std::string_view node_name) {
using NodeType = typename std::decay_t<decltype(*node)>::type;
using BinaryOpNodeT = ir::BinaryOpNode<NodeType>;

// DefineExprNode<NodeType>(m, node_name);

if constexpr (std::is_same_v<T, ir::FracOp>) {
node->def("is_constant", &ir::FracOp::is_constant).def("get_constant", &ir::FracOp::get_constant);
}

std::string prefix{"BinaryOpNode"};
std::string name = prefix + std::string(node_name);
py::class_<BinaryOpNodeT, ir::ExprNode<NodeType>> binary_op_node(*m, name.c_str());
binary_op_node.def(py::init<>())
.def(py::init<Type, Expr, Expr>())
.def("a_mutable", py::overload_cast<>(&BinaryOpNodeT::a), py::return_value_policy::reference)
.def("a_const", py::overload_cast<>(&BinaryOpNodeT::a, py::const_), py::return_value_policy::reference)
.def("b_mutable", py::overload_cast<>(&BinaryOpNodeT::b), py::return_value_policy::reference)
.def("b_const", py::overload_cast<>(&BinaryOpNodeT::b, py::const_), py::return_value_policy::reference)
.def("type", &BinaryOpNodeT::type)
.def("expr_fields_mutable", py::overload_cast<>(&BinaryOpNodeT::expr_fields))
.def("expr_fields_const", py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_));
}

template <typename NodeType>
void DefineUnaryOpNode(py::module *m, std::string_view node_name) {
using UnaryOpNodeT = ir::UnaryOpNode<NodeType>;
Expand Down
14 changes: 9 additions & 5 deletions cinn/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,9 @@ void BindRegistry(py::module *m) {
//.def("set_body",
// py::overload_cast<ir::PackedFunc::body_t>(&ir::Registry::SetBody),
// py::return_value_policy::reference);
ir::Registry::Register("test_add").SetBody([](ir::Args args, ir::RetValue *rv) {

#ifdef CINN_WITH_TEST
ir::Registry::Register("test_add_int64").SetBody([](ir::Args args, ir::RetValue *rv) {
int64_t x = args[0];
int64_t y = args[1];
*rv = x + y;
Expand All @@ -604,10 +606,12 @@ void BindRegistry(py::module *m) {
*rv = x + y;
});

// ir::Registry::Register("test_callback").SetBody([](ir::Args args, ir::RetValue *rv) {
// ir::PackedFunc f = args[0];
// f("hello, cinn");
// });
ir::Registry::Register("test_mul_float").SetBody([](ir::Args args, ir::RetValue *rv) {
float x = args[0];
float y = args[1];
*rv = x * y;
});
#endif
}
} // namespace

Expand Down
14 changes: 14 additions & 0 deletions python/tests/test_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ def __call__(self, *args):
accumulate = ir.register_packed_func("accumulate_float")(Accumulator(1.0))
self.assertTrue(isclose(accumulate(1., 2., 3., 4.), 11.))

def test_cxx_register(self):
add_int = ir.Registry.get("test_add_int64")
self.assertEqual(add_int(2, 3), 5)

add_expr = ir.Registry.get("test_add_expr")
x = ir.Expr(1)
y = ir.Expr(2)
z = x + y
r = add_expr(x, y)
self.assertEqual(r.node_type(), z.node_type())

mul_float = ir.Registry.get("test_mul_float")
self.assertTrue(isclose(mul_float(2.4, 2.5), 6.0, abs_tol=1e-5))


if __name__ == "__main__":
unittest.main()

0 comments on commit 397c9d6

Please sign in to comment.