diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b3a8ebad13ed..8a0a00fcc6edc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,10 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(WITH_TESTING "Compile with Unittests" ON) option(WITH_TESTS "Compile with tests" ON) +if (WITH_TESTING) + add_definitions(-DCINN_WITH_TEST) +endif() + # include the customized configures include(${CMAKE_BINARY_DIR}/config.cmake) diff --git a/cinn/pybind/bind_utils.h b/cinn/pybind/bind_utils.h index 88b2ca38de1c3..2ab2c51b3dd42 100644 --- a/cinn/pybind/bind_utils.h +++ b/cinn/pybind/bind_utils.h @@ -111,25 +111,6 @@ void DefineExprNode(py::module *m, std::string_view node_name) { .def("node_type", &ExprNodeT::node_type); } -template -void DefineExprNode1(py::module *m, T *node, std::string_view node_name) { - using ExprNodeT = ExprNode::type>; - std::string prefix{"ExprNode"}; - std::string name = prefix + std::string(node_name); - py::class_ expr_node(*m, name.c_str()); - - expr_node.def(py::init<>()) - .def(py::init()) - .def(py::init()) - .def("accept", &ExprNodeT::Accept) - .def("operands_mutable", py::overload_cast<>(&ExprNodeT::operands)) - .def("operands_const", py::overload_cast<>(&ExprNodeT::operands, py::const_)) - .def("operand_mutable", py::overload_cast(&ExprNodeT::operand), py::return_value_policy::reference) - .def("operand_const", py::overload_cast(&ExprNodeT::operand, py::const_), py::return_value_policy::reference) - .def("copy", &ExprNodeT::Copy) - .def("node_type", &ExprNodeT::node_type); -} - template void DefineBinaryOpNode(py::module *m, std::string_view node_name) { DefineExprNode(m, node_name); @@ -148,31 +129,6 @@ void DefineBinaryOpNode(py::module *m, std::string_view node_name) { .def("expr_fields_const", py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_)); } -template -void DefineBinaryOpNode1(py::module *m, T *node, std::string_view node_name) { - using NodeType = typename std::decay_t::type; - using BinaryOpNodeT = ir::BinaryOpNode; - - // DefineExprNode(m, node_name); - - if constexpr (std::is_same_v) { - node->def("is_constant", &ir::FracOp::is_constant).def("get_constant", &ir::FracOp::get_constant); - } - - std::string prefix{"BinaryOpNode"}; - std::string name = prefix + std::string(node_name); - py::class_> binary_op_node(*m, name.c_str()); - binary_op_node.def(py::init<>()) - .def(py::init()) - .def("a_mutable", py::overload_cast<>(&BinaryOpNodeT::a), py::return_value_policy::reference) - .def("a_const", py::overload_cast<>(&BinaryOpNodeT::a, py::const_), py::return_value_policy::reference) - .def("b_mutable", py::overload_cast<>(&BinaryOpNodeT::b), py::return_value_policy::reference) - .def("b_const", py::overload_cast<>(&BinaryOpNodeT::b, py::const_), py::return_value_policy::reference) - .def("type", &BinaryOpNodeT::type) - .def("expr_fields_mutable", py::overload_cast<>(&BinaryOpNodeT::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_)); -} - template void DefineUnaryOpNode(py::module *m, std::string_view node_name) { using UnaryOpNodeT = ir::UnaryOpNode; diff --git a/cinn/pybind/ir.cc b/cinn/pybind/ir.cc index a80354a98cf72..4eb619f0dece5 100644 --- a/cinn/pybind/ir.cc +++ b/cinn/pybind/ir.cc @@ -592,7 +592,9 @@ void BindRegistry(py::module *m) { //.def("set_body", // py::overload_cast(&ir::Registry::SetBody), // py::return_value_policy::reference); - ir::Registry::Register("test_add").SetBody([](ir::Args args, ir::RetValue *rv) { + +#ifdef CINN_WITH_TEST + ir::Registry::Register("test_add_int64").SetBody([](ir::Args args, ir::RetValue *rv) { int64_t x = args[0]; int64_t y = args[1]; *rv = x + y; @@ -604,10 +606,12 @@ void BindRegistry(py::module *m) { *rv = x + y; }); - // ir::Registry::Register("test_callback").SetBody([](ir::Args args, ir::RetValue *rv) { - // ir::PackedFunc f = args[0]; - // f("hello, cinn"); - // }); + ir::Registry::Register("test_mul_float").SetBody([](ir::Args args, ir::RetValue *rv) { + float x = args[0]; + float y = args[1]; + *rv = x * y; + }); +#endif } } // namespace diff --git a/python/tests/test_packed_func.py b/python/tests/test_packed_func.py index 8a6312f9d3b62..f48db5542b0ff 100755 --- a/python/tests/test_packed_func.py +++ b/python/tests/test_packed_func.py @@ -40,6 +40,20 @@ def __call__(self, *args): accumulate = ir.register_packed_func("accumulate_float")(Accumulator(1.0)) self.assertTrue(isclose(accumulate(1., 2., 3., 4.), 11.)) + def test_cxx_register(self): + add_int = ir.Registry.get("test_add_int64") + self.assertEqual(add_int(2, 3), 5) + + add_expr = ir.Registry.get("test_add_expr") + x = ir.Expr(1) + y = ir.Expr(2) + z = x + y + r = add_expr(x, y) + self.assertEqual(r.node_type(), z.node_type()) + + mul_float = ir.Registry.get("test_mul_float") + self.assertTrue(isclose(mul_float(2.4, 2.5), 6.0, abs_tol=1e-5)) + if __name__ == "__main__": unittest.main()