Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Error Message No. 35 BUAA】lowered_func.cc & ir_operators.cc&vectorize_loops.cc #66830

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 68 additions & 21 deletions paddle/cinn/ir/lowered_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,18 @@ void _LoweredFunc_::PrepareCudaAxisInfoFromBody() {
}

void _LoweredFunc_::PrepareAllocOutputBufferExprs() {
CHECK(alloc_output_buffer_exprs.empty())
<< "duplicate prepare the allocate buffer for outputs";

PADDLE_ENFORCE_EQ(alloc_output_buffer_exprs.empty(),
true,
phi::errors::InvalidArgument(
"Duplicate prepare the allocate buffer for outputs."));
std::set<std::string> buffer_names;
for (auto& arg : args) {
if (arg.is_output()) {
CHECK(arg.type().valid())
<< "argument [" << arg.name() << "]'s type should be set";
PADDLE_ENFORCE_EQ(
arg.type().valid(),
true,
phi::errors::InvalidArgument("Argument ['%s']'s type should be set.",
arg.name()));
if (arg.is_buffer() &&
!buffer_names.count(arg.name())) { // only buffer need allocation.
buffer_names.insert(arg.name()); // Avoid duplicate
Expand Down Expand Up @@ -200,14 +204,19 @@ std::vector<Expr> _LoweredFunc_::CudaPrepareAllocTempBufferExprs() const {
}

void _LoweredFunc_::PrepareDeallocOutputBufferExprs() {
CHECK(dealloc_output_buffer_exprs.empty())
<< "duplicate prepare the allocate buffer for outputs";
PADDLE_ENFORCE_EQ(dealloc_output_buffer_exprs.empty(),
true,
phi::errors::InvalidArgument(
"Duplicate prepare the allocate buffer for outputs."));

std::set<std::string> buffer_names;
for (auto& arg : args) {
if (arg.is_output()) {
CHECK(arg.type().valid())
<< "argument [" << arg.name() << "]'s type should be set";
PADDLE_ENFORCE_EQ(
arg.type().valid(),
true,
phi::errors::InvalidArgument("Argument ['%s']'s type should be set.",
arg.name()));
if (arg.is_buffer() &&
!buffer_names.count(arg.name())) { // only buffer need allocation.
buffer_names.insert(arg.name()); // Avoid duplicate
Expand All @@ -232,7 +241,11 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
VLOG(3) << "Function used " << tensors.size() << " buffers";
for (auto& tensor : tensors) {
auto* node = tensor.As<ir::_Tensor_>();
CHECK(node);
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument(
"Failed to convert tensor to ir::_Tensor_. The tensor might be "
"invalid or of an incorrect type."));
if (!tensor->buffer.defined()) continue;

Type value_type = tensor->type().ElementOf();
Expand Down Expand Up @@ -271,7 +284,11 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {

for (auto& tensor : tensors) {
auto* node = tensor.As<ir::_Tensor_>();
CHECK(node);
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument(
"Failed to convert tensor to ir::_Tensor_. The tensor might be "
"invalid or of an incorrect type."));
if (!tensor->buffer.defined()) {
continue;
}
Expand Down Expand Up @@ -309,8 +326,10 @@ void _LoweredFunc_::PrepareArgumentExprs() {
.set_cpp_handle();
// type of `const cinn_buffer_t*`
auto const_buffer_ptr_type = buffer_ptr_type.with_cpp_const();
CHECK(!buffer_ptr_type.is_cpp_const());

PADDLE_ENFORCE_NE(buffer_ptr_type.is_cpp_const(),
true,
phi::errors::InvalidArgument(
"The buffer pointer type should not be const."));
Var args_passed_in("_args", type_of<void*>());
auto pod_value_ptr =
cinn::common::CastIfNeeded(args_passed_in, type_of<cinn_pod_value_t*>());
Expand Down Expand Up @@ -357,7 +376,10 @@ void _LoweredFunc_::PrepareArgumentExprs() {
CINN_NOT_IMPLEMENTED
}

CHECK(_arg->type().valid());
PADDLE_ENFORCE_EQ(
_arg->type().valid(),
true,
phi::errors::InvalidArgument("Argument's type should be set."));

Expr pod_cast_expr;

Expand Down Expand Up @@ -425,7 +447,10 @@ void _LoweredFunc_::PrepareArgumentExprs() {

VLOG(6) << "args " << i << "convert";
Expr let_expr = Let::Make(_arg, pod_cast_expr);
CHECK(let_expr.type().valid());
PADDLE_ENFORCE_EQ(let_expr.type().valid(),
true,
phi::errors::InvalidArgument(
"The let expression's type should be set."));
argument_prepare_exprs.push_back(let_expr);
}
}
Expand Down Expand Up @@ -456,22 +481,36 @@ std::vector<Tensor> _LoweredFunc_::CollectAllTensorReference(
}

ir::Buffer Argument::buffer_arg() const {
  // Accessor for the buffer payload; only valid when this argument was
  // registered as a buffer (see set_buffer).
  PADDLE_ENFORCE_EQ(is_buffer(), true,
                    phi::errors::InvalidArgument(
                        "The argument is not a buffer. Unable to return "
                        "buffer_arg_."));
  return buffer_arg_;
}

ir::Var Argument::var_arg() const {
  // Accessor for the variable payload; only valid when this argument was
  // registered as a variable (see set_var).
  PADDLE_ENFORCE_EQ(is_var(), true,
                    phi::errors::InvalidArgument(
                        "The argument is not a variable. Unable to return "
                        "var_arg_."));
  return var_arg_;
}

void Argument::set_buffer(const ir::Buffer& x) {
  // Bind a buffer payload to this argument. An argument is either a buffer
  // or a variable, never both, so reject the call if a variable is already
  // bound. (Fixes the previous message, whose subject was wrong: it is the
  // argument, not the buffer, that is already a variable.)
  PADDLE_ENFORCE_EQ(
      !is_var(),
      true,
      phi::errors::InvalidArgument(
          "The argument is already set as a variable; cannot set a buffer."));
  buffer_arg_ = x;
}

void Argument::set_var(const ir::Var& x) {
  // Bind a variable payload to this argument. An argument is either a buffer
  // or a variable, never both, so reject the call if a buffer is already
  // bound. (Fixes the previous copy-pasted, self-contradictory message
  // "The buffer is already a buffer.")
  PADDLE_ENFORCE_EQ(
      !is_buffer(),
      true,
      phi::errors::InvalidArgument(
          "The argument is already set as a buffer; cannot set a variable."));
  var_arg_ = x;
}

Expand Down Expand Up @@ -557,7 +596,11 @@ void CudaAxisInfo::set_block_dim(int offset, ir::Expr x) {
}

ir::Expr CudaAxisInfo::grid_dim(int offset) const {
CHECK(valid_);
PADDLE_ENFORCE_EQ(
valid_,
true,
phi::errors::InvalidArgument("CudaAxisInfo is not valid. This check "
"failed in grid_dim() method."));
PADDLE_ENFORCE_LT(
offset,
3,
Expand All @@ -566,7 +609,11 @@ ir::Expr CudaAxisInfo::grid_dim(int offset) const {
}

ir::Expr CudaAxisInfo::block_dim(int offset) const {
CHECK(valid_);
PADDLE_ENFORCE_EQ(
valid_,
true,
phi::errors::InvalidArgument("CudaAxisInfo is not valid. This check "
"failed in block_dim() method."));
PADDLE_ENFORCE_LT(
offset,
3,
Expand Down
75 changes: 58 additions & 17 deletions paddle/cinn/ir/op/ir_operators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,27 @@ namespace ir {
using attr_t = absl::variant<int, float, bool, std::string>;

Expr operator<<(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
PADDLE_ENFORCE_EQ(
a.type().is_int() || a.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
PADDLE_ENFORCE_EQ(
b.type().is_int() || b.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
auto int_a = a.As<IntImm>();
auto int_b = b.As<IntImm>();
Type t_a = a.type();
Type t_b = b.type();
if (t_a.is_index_type() && t_b.is_index_type()) {
if (int_b) {
CHECK(int_b->value >= 0 && int_b->value < t_a.bits())
<< "Shift amount must be non-negative and less than " << t_a.bits()
<< " for type " << t_a << std::endl;
PADDLE_ENFORCE_EQ(
int_b->value >= 0 && int_b->value < t_a.bits(),
true,
phi::errors::InvalidArgument(
"Shift amount must be non-negative and less than %d for type %s.",
t_a.bits(),
t_a));
if (int_b->value == 0) return a;
}
if (int_a && int_b) {
Expand All @@ -49,17 +59,27 @@ Expr operator<<(Expr a, Expr b) {
}

Expr operator>>(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
PADDLE_ENFORCE_EQ(
a.type().is_int() || a.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
PADDLE_ENFORCE_EQ(
b.type().is_int() || b.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
auto int_a = a.As<IntImm>();
auto int_b = b.As<IntImm>();
Type t_a = a.type();
Type t_b = b.type();
if (t_a.is_index_type() && t_b.is_index_type()) {
if (int_b) {
CHECK(int_b->value >= 0 && int_b->value < t_a.bits())
<< "Shift amount must be non-negative and less than " << t_a.bits()
<< " for type " << t_a << std::endl;
PADDLE_ENFORCE_EQ(
int_b->value >= 0 && int_b->value < t_a.bits(),
true,
phi::errors::InvalidArgument(
"Shift amount must be non-negative and less than %d for type %s.",
t_a.bits(),
t_a));
if (int_b->value == 0) return a;
}
if (int_a && int_b) {
Expand Down Expand Up @@ -113,8 +133,14 @@ Expr BitwiseOrCall(const Target& target, Expr a, Expr b) {
}

Expr operator|(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
PADDLE_ENFORCE_EQ(
a.type().is_int() || a.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
PADDLE_ENFORCE_EQ(
b.type().is_int() || b.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
auto int_a = a.As<IntImm>();
auto int_b = b.As<IntImm>();
Type t_a = a.type();
Expand Down Expand Up @@ -172,8 +198,14 @@ Expr BitwiseAndCall(const Target& target, Expr a, Expr b) {
}

Expr operator&(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
PADDLE_ENFORCE_EQ(
a.type().is_int() || a.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
PADDLE_ENFORCE_EQ(
b.type().is_int() || b.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
auto int_a = a.As<IntImm>();
auto int_b = b.As<IntImm>();
Type t_a = a.type();
Expand Down Expand Up @@ -231,8 +263,14 @@ Expr BitwiseXorCall(const Target& target, Expr a, Expr b) {
}

Expr operator^(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
PADDLE_ENFORCE_EQ(
a.type().is_int() || a.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
PADDLE_ENFORCE_EQ(
b.type().is_int() || b.type().is_uint(),
true,
phi::errors::InvalidArgument("The input's type should be int or uint."));
auto int_a = a.As<IntImm>();
auto int_b = b.As<IntImm>();
Type t_a = a.type();
Expand Down Expand Up @@ -279,7 +317,10 @@ Expr BitwiseNotCall(const Target& target, Expr a) {
}

Expr operator~(Expr a) {
  // Bitwise NOT is only meaningful for integral expressions; reject
  // anything that is neither signed nor unsigned int.
  const bool is_integral = a.type().is_int() || a.type().is_uint();
  PADDLE_ENFORCE_EQ(is_integral, true,
                    phi::errors::InvalidArgument(
                        "The input's type should be int or uint."));
  // Lower to the target-specific bitwise-not call for the current target.
  return BitwiseNotCall(cinn::runtime::CurrentTarget::GetCurrentTarget(), a);
}
Expand Down
Loading