Skip to content

Commit

Permalink
【BUAA】【Error Message No.16,18】Update error message (#66988)
Browse files Browse the repository at this point in the history
* Add lerp CINN.

* update

* Refine message errors.

* Fix non-pointer error.
  • Loading branch information
Marcusryz authored Aug 5, 2024
1 parent b940ee6 commit 327c5f7
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 84 deletions.
121 changes: 92 additions & 29 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ Expr CasSimplifyMutator::SimplifyRationalNumber(Expr u) {
auto* ni = n.As<IntImm>();
auto* di = d.As<IntImm>();

CHECK(ni && di);
PADDLE_ENFORCE_EQ(
ni && di,
true,
::common::errors::InvalidArgument("Ni and Di should not be null."));
int nv = ni->value;
int dv = di->value;

Expand Down Expand Up @@ -611,7 +614,8 @@ Expr CasSimplifyMutator::SimplifySum(Expr u) {
u = SumOrProductGetSingleElementsRec(u);

auto* sum = u.As<Sum>();
CHECK(sum);
PADDLE_ENFORCE_NOT_NULL(
sum, ::common::errors::InvalidArgument("Sum should not be null."));

auto& operands = sum->operands();

Expand Down Expand Up @@ -815,7 +819,10 @@ std::vector<Expr> CasSimplifyMutator::SimplifySumRec(
VLOG(7) << "SimplifySumRec operands: " << ss.str();
}
#endif
CHECK(!operands.empty());
PADDLE_ENFORCE_EQ(
!operands.empty(),
true,
::common::errors::InvalidArgument("Operands should not be empty."));
if (operands.size() < 2)
return {CasSimplify(operands.front(), var_intervals)};
auto mid_it = operands.begin() + operands.size() / 2;
Expand All @@ -837,10 +844,15 @@ void CasSimplifyMutator::UnfoldBound(Expr* lower_bound,
Expr* upper_bound,
Expr var,
bool unfold_const_bound) {
CHECK(lower_bound);
CHECK(upper_bound);
PADDLE_ENFORCE_NOT_NULL(
lower_bound,
::common::errors::InvalidArgument("Lower bound should not be null."));
PADDLE_ENFORCE_NOT_NULL(
upper_bound,
::common::errors::InvalidArgument("Upper bound should not be null."));
auto v_var = var.As<_Var_>();
CHECK(v_var);
PADDLE_ENFORCE_NOT_NULL(
v_var, ::common::errors::InvalidArgument("Var should not be null."));
if (var_intervals.count(v_var->name)) {
auto& interval = var_intervals.at(v_var->name);
if (interval.e_l.defined() && interval.e_r.defined()) {
Expand Down Expand Up @@ -868,8 +880,12 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
Expr* upper_bound,
Expr var,
bool unfold_const_bound) {
CHECK(lower_bound);
CHECK(upper_bound);
PADDLE_ENFORCE_NOT_NULL(
lower_bound,
::common::errors::InvalidArgument("Lower bound should not be null."));
PADDLE_ENFORCE_NOT_NULL(
upper_bound,
::common::errors::InvalidArgument("Upper bound should not be null."));
auto v_var = var.As<_Var_>();
auto v_product = var.As<Product>();
auto v_frac = var.As<FracOp>();
Expand All @@ -887,7 +903,10 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
Expr v_lower, v_upper;
UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
auto const_v = const_oper.get_constant();
CHECK(v_lower.defined() && v_upper.defined());
PADDLE_ENFORCE_EQ(v_lower.defined() && v_upper.defined(),
true,
::common::errors::InvalidArgument(
"V lower and upper should be defined."));
if (const_v > 0) {
p_lower_bound = Product::Make({const_oper, v_lower});
p_upper_bound = Product::Make({const_oper, v_upper});
Expand All @@ -910,7 +929,10 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
Expr v_lower, v_upper;
UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
auto const_v = const_oper.get_constant();
CHECK(v_lower.defined() && v_upper.defined());
PADDLE_ENFORCE_EQ(v_lower.defined() && v_upper.defined(),
true,
::common::errors::InvalidArgument(
"V lower and upper should be defined."));
if (const_v > 0) {
p_lower_bound = FracOp::Make(v_lower, const_oper);
p_upper_bound = FracOp::Make(v_upper, const_oper);
Expand All @@ -931,8 +953,12 @@ bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound,
Expr v,
bool unfold_const_bound) {
// only support simple operand of int, var and var's product with int
CHECK(lower_bound);
CHECK(upper_bound);
PADDLE_ENFORCE_NOT_NULL(
lower_bound,
::common::errors::InvalidArgument("Lower bound should not be null."));
PADDLE_ENFORCE_NOT_NULL(
upper_bound,
::common::errors::InvalidArgument("Upper bound should not be null."));
auto* v_int = v.As<IntImm>();
if (v_int) {
AddBaseAndSimplify(lower_bound, v);
Expand All @@ -949,8 +975,12 @@ bool CasSimplifyMutator::GetSumBound(Expr* lower_bound,
Expr sum,
bool unfold_const_bound) {
// only support sum of int, var and var's product with int
CHECK(lower_bound);
CHECK(upper_bound);
PADDLE_ENFORCE_NOT_NULL(
lower_bound,
::common::errors::InvalidArgument("Lower bound should not be null."));
PADDLE_ENFORCE_NOT_NULL(
upper_bound,
::common::errors::InvalidArgument("Upper bound should not be null."));
auto bound_sum = sum.As<Sum>();
// CHECK(bound_sum);
bool get_bound = true;
Expand Down Expand Up @@ -1002,7 +1032,9 @@ bool CasSimplifyMutator::GetMinBound(Expr* lower_bound,
// only support min's operands as sum, int or var or var's product with int or
// min/max
auto bound_min = min.As<Min>();
CHECK(bound_min);
PADDLE_ENFORCE_NOT_NULL(
bound_min,
::common::errors::InvalidArgument("Bound min should not be null."));
bool get_bound = true;
Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound;
get_bound =
Expand All @@ -1025,7 +1057,9 @@ bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound,
Expr max,
bool unfold_const_bound) {
auto bound_max = max.As<Max>();
CHECK(bound_max);
PADDLE_ENFORCE_NOT_NULL(
bound_max,
::common::errors::InvalidArgument("Bound max should not be null."));
bool get_bound = true;
Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound;
get_bound =
Expand Down Expand Up @@ -1191,7 +1225,8 @@ inline bool IsVarAllNonnegative(
Expr CasSimplifyMutator::SimplifyMod(Expr u) {
VLOG(6) << "SimplifyMod:" << u;
auto* node = u.As<Mod>();
CHECK(node);
PADDLE_ENFORCE_NOT_NULL(
node, ::common::errors::InvalidArgument("Node should not be null."));

auto a = CasSimplify(node->a(), var_intervals);
auto b = CasSimplify(node->b(), var_intervals);
Expand Down Expand Up @@ -1684,7 +1719,10 @@ Expr ConvertCinnToCAS(Expr expr) {
Visit(&a);
Visit(&b);

CHECK(!is_zero(b)) << "Dividend should not be zero";
PADDLE_ENFORCE_EQ(
!is_zero(b),
true,
::common::errors::InvalidArgument("Dividend should not be zero."));

if (a.is_constant() && a.get_constant() == 0) {
*expr = make_const(a->type(), 0);
Expand Down Expand Up @@ -1744,10 +1782,16 @@ Expr ReplaceMinToConstant(Expr expr) {
auto min_a = op->a();
auto min_b = op->b();
if (min_a.is_constant() && !min_b.is_constant()) {
CHECK(min_a->type().is_integer());
PADDLE_ENFORCE_EQ(
min_a->type().is_integer(),
true,
::common::errors::InvalidArgument("Min a should be an integer."));
*expr = ir::ir_utils::IRCopy(min_a);
} else if (min_b.is_constant() && !min_a.is_constant()) {
CHECK(min_b->type().is_integer());
PADDLE_ENFORCE_EQ(
min_b->type().is_integer(),
true,
::common::errors::InvalidArgument("Min b should be an integer."));
*expr = ir::ir_utils::IRCopy(min_b);
}
}
Expand Down Expand Up @@ -1777,10 +1821,16 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto max_a = op->a();
auto max_b = op->b();
if (max_a.is_constant() && !max_b.is_constant()) {
CHECK(max_a->type().is_integer());
PADDLE_ENFORCE_EQ(
max_a->type().is_integer(),
true,
::common::errors::InvalidArgument("Max a should be an integer."));
*expr = ir::ir_utils::IRCopy(max_a);
} else if (max_b.is_constant() && !max_a.is_constant()) {
CHECK(max_b->type().is_integer());
PADDLE_ENFORCE_EQ(
max_b->type().is_integer(),
true,
::common::errors::InvalidArgument("Max b should be an integer."));
*expr = ir::ir_utils::IRCopy(max_b);
}
}
Expand All @@ -1807,7 +1857,10 @@ Expr ConvertCasToCinn(Expr expr) {
operands.push_back(c);
}

CHECK(!operands.empty());
PADDLE_ENFORCE_EQ(
!operands.empty(),
true,
::common::errors::InvalidArgument("Operands should not be empty."));
if (operands.size() == 1) {
*expr = operands[0];
} else if (operands.size() == 2) {
Expand All @@ -1832,7 +1885,10 @@ Expr ConvertCasToCinn(Expr expr) {
operands.push_back(c);
}

CHECK(!operands.empty());
PADDLE_ENFORCE_EQ(
!operands.empty(),
true,
::common::errors::InvalidArgument("Operands should not be empty."));
if (operands.size() == 1) {
*expr = operands[0];
} else if (operands.size() == 2) {
Expand All @@ -1855,7 +1911,10 @@ Expr ConvertCasToCinn(Expr expr) {
Visit(&a);
Visit(&b);

CHECK(!is_zero(b)) << "Dividend should not be zero";
PADDLE_ENFORCE_EQ(
!is_zero(b),
true,
::common::errors::InvalidArgument("Dividend should not be zero."));
*expr = Div::Make(a, b);
Visit(expr);
}
Expand Down Expand Up @@ -1978,19 +2037,22 @@ Expr SimplifyConstantFrac(FracOp* node) {

if (ai) {
auto* bi = node->b().As<ir::IntImm>();
CHECK(bi);
PADDLE_ENFORCE_NOT_NULL(
bi, ::common::errors::InvalidArgument("Bi should not be null."));
return make_const(ai->type(), ai->value / bi->value);
}

if (au) {
auto* bu = node->b().As<ir::UIntImm>();
CHECK(bu);
PADDLE_ENFORCE_NOT_NULL(
bu, ::common::errors::InvalidArgument("Bu should not be null."));
return make_const(au->type(), au->value / bu->value);
}

if (af) {
auto* bf = node->b().As<ir::FloatImm>();
CHECK(af);
PADDLE_ENFORCE_NOT_NULL(
af, ::common::errors::InvalidArgument("Af should not be null."));
return make_const(af->type(), af->value / bf->value);
}
CINN_NOT_IMPLEMENTED
Expand Down Expand Up @@ -2164,7 +2226,8 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
auto* af = a.As<FloatImm>();
auto* bf = b.As<FloatImm>();
if (ai) {
CHECK(bi);
PADDLE_ENFORCE_NOT_NULL(
bi, ::common::errors::InvalidArgument("Bi should not be null."));
int g = gcd(ai->value, bi->value);
int a_d = ai->value / g;
int b_d = bi->value / g;
Expand Down
Loading

0 comments on commit 327c5f7

Please sign in to comment.