Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Error Message No. 35 【BUAA】 】reciprocal.cc & ast_gen.cc& poly_scheduler.cc #66847

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 86 additions & 44 deletions paddle/cinn/hlir/op/contrib/reciprocal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,40 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(

framework::CINNCompute reciprocal_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.\n";
PADDLE_ENFORCE_NE(
args.empty(),
true,
phi::errors::InvalidArgument(
"The input argument of %s compute is empty! Please check.",
op_name));
CINNValuePack pack_args = args[0];
CHECK(!pack_args.empty())
<< "at least one input tensor for " << op_name << " compute\n";
PADDLE_ENFORCE_NE(
pack_args.empty(),
true,
phi::errors::InvalidArgument(
"At least one input tensor for %s compute.", op_name));
PADDLE_ENFORCE_EQ(pack_args.size(),
2,
phi::errors::InvalidArgument(
"The input argument's size of reciprocal op "
"should be 2."));
CHECK(pack_args[1].is_string());
PADDLE_ENFORCE_EQ(
pack_args[1].is_string(),
true,
phi::errors::InvalidArgument(
"Required pack_args[1] must be a string. Please check."));
std::string tensor_name = pack_args[1].operator std::string();

Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
PADDLE_ENFORCE_NOT_NULL(
A.as_tensor(),
phi::errors::InvalidArgument(
"Required Input must be a tensor. Please check."));
PADDLE_ENFORCE_NE(
output_shapes.empty(),
true,
phi::errors::InvalidArgument(
"The output shape of reciprocal is empty! Please check."));
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
Expand All @@ -115,8 +133,11 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
ir::Tensor out = Reciprocal(tensor_A, tensor_name);
std::vector<CINNValue> res;
res.push_back(CINNValue(out));
CHECK(!out_type.empty())
<< "Output type of Reciprocal is empty! Please check.\n";
PADDLE_ENFORCE_NE(
out_type.empty(),
true,
phi::errors::InvalidArgument(
"The output type of Reciprocal is empty! Please check."));
*ret = CINNValuePack{res};
});

Expand All @@ -136,41 +157,62 @@ std::shared_ptr<OpStrategy> StrategyForReciprocalSymbolic(
const Target &target) {
std::string op_name("reciprocal");

framework::CINNCompute reciprocal_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK(!pack_args.empty())
<< "at least one input tensor for " << op_name << " compute\n";
PADDLE_ENFORCE_EQ(pack_args.size(),
2,
phi::errors::InvalidArgument(
"The input argument's size of reciprocal op "
"should be 2."));
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();

Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
PADDLE_ENFORCE_EQ(pack_args.size(),
2U,
phi::errors::InvalidArgument(
"The input argument's size of reciprocal op "
"should be 2."));
tensor_name = pack_args[1].operator std::string();

ir::Tensor out = Reciprocal(tensor_A, tensor_name);
std::vector<CINNValue> res;
res.push_back(CINNValue(out));
CHECK(!out_type.empty())
<< "Output type of Reciprocal is empty! Please check.\n";
*ret = CINNValuePack{res};
});
framework::CINNCompute reciprocal_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_NE(
args.empty(),
true,
phi::errors::InvalidArgument(
"The input argument of %s compute is empty! Please check.",
op_name));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_NE(
pack_args.empty(),
true,
phi::errors::InvalidArgument(
"At least one input tensor for %s compute.", op_name));
PADDLE_ENFORCE_EQ(pack_args.size(),
2,
phi::errors::InvalidArgument(
"The input argument's size of reciprocal op "
"should be 2."));
PADDLE_ENFORCE_EQ(
pack_args[1].is_string(),
true,
phi::errors::InvalidArgument(
"Required pack_args[1] must be a string. Please check."));
std::string tensor_name = pack_args[1].operator std::string();

Expr A = pack_args[0];
PADDLE_ENFORCE_NOT_NULL(
A.as_tensor(),
phi::errors::InvalidArgument(
"Required Input must be a tensor. Please check."));
PADDLE_ENFORCE_NE(
output_shapes.empty(),
true,
phi::errors::InvalidArgument(
"The output shape of reciprocal_compute is empty! Please check."));
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
PADDLE_ENFORCE_EQ(pack_args.size(),
2U,
phi::errors::InvalidArgument(
"The input argument's size of reciprocal op "
"should be 2."));
tensor_name = pack_args[1].operator std::string();

ir::Tensor out = Reciprocal(tensor_A, tensor_name);
std::vector<CINNValue> res;
res.push_back(CINNValue(out));
PADDLE_ENFORCE_NE(
out_type.empty(),
true,
phi::errors::InvalidArgument(
"The output type of Reciprocal is empty! Please check."));
*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
Expand Down
72 changes: 53 additions & 19 deletions paddle/cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ struct AstGen::Impl {
isl::union_set AstGen::domain() const { return impl_->domain(); }

isl::union_set AstGen::Impl::domain() const {
CHECK(!stages_.empty());
PADDLE_ENFORCE_NE(stages_.empty(),
true,
phi::errors::InvalidArgument(
"Stages vector is empty in AstGen::Impl::domain()."));
auto sets = utils::Map<std::vector<Shared<Stage>>, isl::set>(
stages_, [](const Shared<Stage>& e) { return e->domain(); });
return isl_sets_to_union_set(sets);
Expand All @@ -79,7 +82,10 @@ isl::union_set AstGen::Impl::domain() const {
isl::ctx AstGen::ctx() const { return impl_->ctx(); }

isl::ctx AstGen::Impl::ctx() const {
CHECK(!stages_.empty());
PADDLE_ENFORCE_NE(stages_.empty(),
true,
phi::errors::InvalidArgument(
"Stages vector is empty in AstGen::Impl::ctx()."));
return stages_.front()->domain().ctx();
}

Expand Down Expand Up @@ -153,8 +159,10 @@ isl::ast_node AstGen::Build() {
std::vector<isl::map> maps;
for (auto& stage : impl_->stages_) {
auto it = schedule_map.find(stage->id());
CHECK(it != std::end(schedule_map))
<< "stage " << stage->id() << " not found in the map";
PADDLE_ENFORCE_EQ(it != std::end(schedule_map),
true,
phi::errors::InvalidArgument(
"Stage %s not found in the map.", stage->id()));
maps.push_back(it->second);
}
auto schedule = isl_maps_to_union_map(maps);
Expand Down Expand Up @@ -188,7 +196,11 @@ isl::ast_node AstGen::Build() {
impl_->stages_.begin(),
impl_->stages_.end(),
[&name](const Shared<Stage>& ele) { return ele->id() == name; });
CHECK(ele_it != std::end(impl_->stages_));
PADDLE_ENFORCE_EQ(
ele_it != std::end(impl_->stages_),
true,
phi::errors::InvalidArgument(
"Stage with name %s not found in the stages vector.", name));
return (*ele_it)->domain();
};

Expand Down Expand Up @@ -255,7 +267,10 @@ AstGen::Impl::ExtractIslTransformedIndiceMap(const isl::set& iterator_domain,
const std::map<std::string, isl::ast_expr>& AstGen::axis2ast(
const std::string& tuple_name) const {
auto it = impl_->transformed_indice_map_.find(tuple_name);
CHECK(it != impl_->transformed_indice_map_.end()) << "no id " << tuple_name;
PADDLE_ENFORCE_EQ(it != impl_->transformed_indice_map_.end(),
true,
phi::errors::InvalidArgument(
"No id named %s, please check.", tuple_name));
return it->second;
}

Expand All @@ -273,7 +288,10 @@ const std::map<std::string, Expr> AstGen::axis2expr(

isl::ast_expr CreateIslAstIndexExpression(isl_ast_build* build,
const isl::map& access) {
CHECK(build);
PADDLE_ENFORCE_NOT_NULL(
build,
phi::errors::InvalidArgument(
"The isl_ast_build pointer is null in CreateIslAstIndexExpression."));
isl::map schedule =
isl::manage(isl_map_from_union_map(isl_ast_build_get_schedule(build)));

Expand Down Expand Up @@ -335,8 +353,14 @@ void EatIf(const isl::ast_node& node, ir::Expr* expr);
void EatMark(const isl::ast_node& node, ir::Expr* expr);

void IslAstNodeToCinnExpr(const isl::ast_node& node, ir::Expr* expr) {
CHECK(!node.is_null());
CHECK(expr);
PADDLE_ENFORCE_EQ(!node.is_null(),
true,
phi::errors::InvalidArgument(
"The isl::ast_node is null in IslAstNodeToCinnExpr."));
PADDLE_ENFORCE_NOT_NULL(
expr,
phi::errors::InvalidArgument(
"The ir::Expr pointer is null in IslAstNodeToCinnExpr."));

switch (isl_ast_node_get_type(node.get())) {
case isl_ast_node_block: {
Expand All @@ -362,19 +386,24 @@ void IslAstNodeToCinnExpr(const isl::ast_node& node, ir::Expr* expr) {
default:
std::stringstream ss;
ss << "Unexpected ISL node type " << isl_ast_node_get_type(node.get());
PADDLE_THROW(::common::errors::InvalidArgument(ss.str()));
PADDLE_THROW(::phi::errors::InvalidArgument(ss.str()));
break;
}
}

// Eat an isl block node.
void EatBlock(const isl::ast_node& node, ir::Expr* expr) {
VLOG(2) << "get isl ast body node";
CHECK(!node.is_null());
CHECK(expr);
PADDLE_ENFORCE_EQ(
!node.is_null(),
true,
phi::errors::InvalidArgument("The isl::ast_node is null in EatBlock."));
PADDLE_ENFORCE_NOT_NULL(expr,
phi::errors::InvalidArgument(
"The ir::Expr pointer is null in EatBlock."));
PADDLE_ENFORCE_EQ(isl_ast_node_get_type(node.get()),
isl_ast_node_block,
::common::errors::InvalidArgument(
::phi::errors::InvalidArgument(
"The node type should be isl_ast_node_block"));
isl::ast_node_list list =
isl::manage(isl_ast_node_block_get_children(node.get()));
Expand All @@ -393,7 +422,7 @@ void EatBlock(const isl::ast_node& node, ir::Expr* expr) {
void EatUser(const isl::ast_node& node, ir::Expr* expr) {
PADDLE_ENFORCE_EQ(isl_ast_node_get_type(node.get()),
isl_ast_node_user,
::common::errors::InvalidArgument(
::phi::errors::InvalidArgument(
"The node type should be isl_ast_node_user"));
isl::ast_expr isl_expr = isl::manage(isl_ast_node_user_get_expr(node.get()));
IslAstExprToCinnExpr(isl_expr, expr);
Expand All @@ -402,7 +431,7 @@ void EatUser(const isl::ast_node& node, ir::Expr* expr) {
void EatFor(const isl::ast_node& node, ir::Expr* expr) {
PADDLE_ENFORCE_EQ(isl_ast_node_get_type(node.get()),
isl_ast_node_for,
::common::errors::InvalidArgument(
::phi::errors::InvalidArgument(
"The node type should be isl_ast_node_for"));

// iter name
Expand Down Expand Up @@ -448,7 +477,7 @@ void EatFor(const isl::ast_node& node, ir::Expr* expr) {
void EatIf(const isl::ast_node& node, ir::Expr* expr) {
PADDLE_ENFORCE_EQ(isl_ast_node_get_type(node.get()),
isl_ast_node_if,
::common::errors::InvalidArgument(
::phi::errors::InvalidArgument(
"The node type should be isl_ast_node_if."));
isl::ast_node then_body = isl::manage(isl_ast_node_if_get_then(node.get()));
isl::ast_expr condition = isl::manage(isl_ast_node_if_get_cond(node.get()));
Expand Down Expand Up @@ -558,7 +587,12 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) {
case isl_ast_op_call: {
ir::Expr caller_expr = ops.front();
// TODO(Superjomn) make it an string
CHECK(caller_expr.node_type() == ir::IrNodeTy::_Var_);
PADDLE_ENFORCE_EQ(
caller_expr.node_type() == ir::IrNodeTy::_Var_,
true,
phi::errors::InvalidArgument(
"Expected caller_expr to be of type _Var_, but got %s.",
caller_expr.node_type()));
std::string caller = caller_expr.As<ir::_Var_>()->name;
ops.erase(ops.begin());
// NOTE the type here is not important.
Expand All @@ -576,15 +610,15 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) {
case isl_ast_op_select:
PADDLE_ENFORCE_EQ(ops.size(),
3UL,
::common::errors::InvalidArgument(
::phi::errors::InvalidArgument(
"In ir::Select, the ops size should be 3"));
ops[0]->set_type(Bool());
*expr = ir::Select::Make(ops[0], ops[1], ops[2]);
break;
default:
std::stringstream ss;
ss << "unsupported op " << op_type;
PADDLE_THROW(::common::errors::InvalidArgument(ss.str()));
PADDLE_THROW(::phi::errors::InvalidArgument(ss.str()));
}
} break;
default:
Expand Down
Loading