diff --git a/paddle/cinn/common/cas.cc b/paddle/cinn/common/cas.cc
index ad1b4fdf9d2b9..475d5ea9364ed 100644
--- a/paddle/cinn/common/cas.cc
+++ b/paddle/cinn/common/cas.cc
@@ -260,7 +260,10 @@ Expr CasSimplifyMutator::SimplifyRationalNumber(Expr u) {
 
   auto* ni = n.As<IntImm>();
   auto* di = d.As<IntImm>();
 
-  CHECK(ni && di);
+  PADDLE_ENFORCE_EQ(
+      ni && di,
+      true,
+      ::common::errors::InvalidArgument("Ni and Di should not be null."));
   int nv = ni->value;
   int dv = di->value;
@@ -611,7 +614,8 @@ Expr CasSimplifyMutator::SimplifySum(Expr u) {
   u = SumOrProductGetSingleElementsRec(u);
 
   auto* sum = u.As<Sum>();
-  CHECK(sum);
+  PADDLE_ENFORCE_NOT_NULL(
+      sum, ::common::errors::InvalidArgument("Sum should not be null."));
 
   auto& operands = sum->operands();
 
@@ -815,7 +819,10 @@ std::vector<Expr> CasSimplifyMutator::SimplifySumRec(
     VLOG(7) << "SimplifySumRec operands: " << ss.str();
   }
 #endif
-  CHECK(!operands.empty());
+  PADDLE_ENFORCE_EQ(
+      !operands.empty(),
+      true,
+      ::common::errors::InvalidArgument("Operands should not be empty."));
   if (operands.size() < 2)
     return {CasSimplify(operands.front(), var_intervals)};
   auto mid_it = operands.begin() + operands.size() / 2;
@@ -837,10 +844,15 @@ void CasSimplifyMutator::UnfoldBound(Expr* lower_bound,
                                      Expr* upper_bound,
                                      Expr var,
                                      bool unfold_const_bound) {
-  CHECK(lower_bound);
-  CHECK(upper_bound);
+  PADDLE_ENFORCE_NOT_NULL(
+      lower_bound,
+      ::common::errors::InvalidArgument("Lower bound should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      upper_bound,
+      ::common::errors::InvalidArgument("Upper bound should not be null."));
   auto v_var = var.As<_Var_>();
-  CHECK(v_var);
+  PADDLE_ENFORCE_NOT_NULL(
+      v_var, ::common::errors::InvalidArgument("Var should not be null."));
   if (var_intervals.count(v_var->name)) {
     auto& interval = var_intervals.at(v_var->name);
     if (interval.e_l.defined() && interval.e_r.defined()) {
@@ -868,8 +880,12 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
                                      Expr* upper_bound,
                                      Expr var,
                                      bool unfold_const_bound) {
-  CHECK(lower_bound);
-  CHECK(upper_bound);
+  PADDLE_ENFORCE_NOT_NULL(
+      lower_bound,
+      ::common::errors::InvalidArgument("Lower bound should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      upper_bound,
+      ::common::errors::InvalidArgument("Upper bound should not be null."));
   auto v_var = var.As<_Var_>();
   auto v_product = var.As<Product>();
   auto v_frac = var.As<FracOp>();
@@ -887,7 +903,10 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
     Expr v_lower, v_upper;
     UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
     auto const_v = const_oper.get_constant();
-    CHECK(v_lower.defined() && v_upper.defined());
+    PADDLE_ENFORCE_EQ(v_lower.defined() && v_upper.defined(),
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "V lower and upper should be defined."));
     if (const_v > 0) {
       p_lower_bound = Product::Make({const_oper, v_lower});
       p_upper_bound = Product::Make({const_oper, v_upper});
@@ -910,7 +929,10 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
     Expr v_lower, v_upper;
     UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
     auto const_v = const_oper.get_constant();
-    CHECK(v_lower.defined() && v_upper.defined());
+    PADDLE_ENFORCE_EQ(v_lower.defined() && v_upper.defined(),
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "V lower and upper should be defined."));
     if (const_v > 0) {
       p_lower_bound = FracOp::Make(v_lower, const_oper);
       p_upper_bound = FracOp::Make(v_upper, const_oper);
@@ -931,8 +953,12 @@ bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound,
                                          Expr v,
                                          bool unfold_const_bound) {
   // only support simple operand of int, var and var's product with int
-  CHECK(lower_bound);
-  CHECK(upper_bound);
+  PADDLE_ENFORCE_NOT_NULL(
+      lower_bound,
+      ::common::errors::InvalidArgument("Lower bound should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      upper_bound,
+      ::common::errors::InvalidArgument("Upper bound should not be null."));
   auto* v_int = v.As<IntImm>();
   if (v_int) {
     AddBaseAndSimplify(lower_bound, v);
@@ -949,8 +975,12 @@ bool CasSimplifyMutator::GetSumBound(Expr* lower_bound,
                                      Expr sum,
                                      bool unfold_const_bound) {
   // only support sum of int, var and var's product with int
-  CHECK(lower_bound);
-  CHECK(upper_bound);
+  PADDLE_ENFORCE_NOT_NULL(
+      lower_bound,
+      ::common::errors::InvalidArgument("Lower bound should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      upper_bound,
+      ::common::errors::InvalidArgument("Upper bound should not be null."));
   auto bound_sum = sum.As<Sum>();
   // CHECK(bound_sum);
   bool get_bound = true;
@@ -1002,7 +1032,9 @@ bool CasSimplifyMutator::GetMinBound(Expr* lower_bound,
   // only support min's operands as sum, int or var or var's product with int or
   // min/max
   auto bound_min = min.As<Min>();
-  CHECK(bound_min);
+  PADDLE_ENFORCE_NOT_NULL(
+      bound_min,
+      ::common::errors::InvalidArgument("Bound min should not be null."));
   bool get_bound = true;
   Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound;
   get_bound =
@@ -1025,7 +1057,9 @@ bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound,
                                      Expr max,
                                      bool unfold_const_bound) {
   auto bound_max = max.As<Max>();
-  CHECK(bound_max);
+  PADDLE_ENFORCE_NOT_NULL(
+      bound_max,
+      ::common::errors::InvalidArgument("Bound max should not be null."));
   bool get_bound = true;
   Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound;
   get_bound =
@@ -1191,7 +1225,8 @@ inline bool IsVarAllNonnegative(
 Expr CasSimplifyMutator::SimplifyMod(Expr u) {
   VLOG(6) << "SimplifyMod:" << u;
   auto* node = u.As<Mod>();
-  CHECK(node);
+  PADDLE_ENFORCE_NOT_NULL(
+      node, ::common::errors::InvalidArgument("Node should not be null."));
 
   auto a = CasSimplify(node->a(), var_intervals);
   auto b = CasSimplify(node->b(), var_intervals);
@@ -1684,7 +1719,10 @@ Expr ConvertCinnToCAS(Expr expr) {
     Visit(&a);
     Visit(&b);
 
-    CHECK(!is_zero(b)) << "Dividend should not be zero";
+    PADDLE_ENFORCE_EQ(
+        !is_zero(b),
+        true,
+        ::common::errors::InvalidArgument("Dividend should not be zero."));
 
     if (a.is_constant() && a.get_constant() == 0) {
       *expr = make_const(a->type(), 0);
@@ -1744,10 +1782,16 @@ Expr ReplaceMinToConstant(Expr expr) {
     auto min_a = op->a();
     auto min_b = op->b();
     if (min_a.is_constant() && !min_b.is_constant()) {
-      CHECK(min_a->type().is_integer());
+      PADDLE_ENFORCE_EQ(
+          min_a->type().is_integer(),
+          true,
+          ::common::errors::InvalidArgument("Min a should be an integer."));
       *expr = ir::ir_utils::IRCopy(min_a);
     } else if (min_b.is_constant() && !min_a.is_constant()) {
-      CHECK(min_b->type().is_integer());
+      PADDLE_ENFORCE_EQ(
+          min_b->type().is_integer(),
+          true,
+          ::common::errors::InvalidArgument("Min b should be an integer."));
       *expr = ir::ir_utils::IRCopy(min_b);
     }
   }
@@ -1777,10 +1821,16 @@ Expr ReplaceMaxToConstant(Expr expr) {
     auto max_a = op->a();
     auto max_b = op->b();
     if (max_a.is_constant() && !max_b.is_constant()) {
-      CHECK(max_a->type().is_integer());
+      PADDLE_ENFORCE_EQ(
+          max_a->type().is_integer(),
+          true,
+          ::common::errors::InvalidArgument("Max a should be an integer."));
      *expr = ir::ir_utils::IRCopy(max_a);
     } else if (max_b.is_constant() && !max_a.is_constant()) {
-      CHECK(max_b->type().is_integer());
+      PADDLE_ENFORCE_EQ(
+          max_b->type().is_integer(),
+          true,
+          ::common::errors::InvalidArgument("Max b should be an integer."));
       *expr = ir::ir_utils::IRCopy(max_b);
     }
   }
@@ -1807,7 +1857,10 @@ Expr ConvertCasToCinn(Expr expr) {
       operands.push_back(c);
     }
 
-    CHECK(!operands.empty());
+    PADDLE_ENFORCE_EQ(
+        !operands.empty(),
+        true,
+        ::common::errors::InvalidArgument("Operands should not be empty."));
     if (operands.size() == 1) {
       *expr = operands[0];
     } else if (operands.size() == 2) {
@@ -1832,7 +1885,10 @@ Expr ConvertCasToCinn(Expr expr) {
      operands.push_back(c);
    }
 
-    CHECK(!operands.empty());
+    PADDLE_ENFORCE_EQ(
+        !operands.empty(),
+        true,
+        ::common::errors::InvalidArgument("Operands should not be empty."));
     if (operands.size() == 1) {
       *expr = operands[0];
     } else if (operands.size() == 2) {
@@ -1855,7 +1911,10 @@ Expr ConvertCasToCinn(Expr expr) {
     Visit(&a);
     Visit(&b);
 
-    CHECK(!is_zero(b)) << "Dividend should not be zero";
+    PADDLE_ENFORCE_EQ(
+        !is_zero(b),
+        true,
+        ::common::errors::InvalidArgument("Dividend should not be zero."));
     *expr = Div::Make(a, b);
     Visit(expr);
   }
@@ -1978,19 +2037,22 @@ Expr SimplifyConstantFrac(FracOp* node) {
 
   if (ai) {
     auto* bi = node->b().As<IntImm>();
-    CHECK(bi);
+    PADDLE_ENFORCE_NOT_NULL(
+        bi, ::common::errors::InvalidArgument("Bi should not be null."));
     return make_const(ai->type(), ai->value / bi->value);
   }
 
   if (au) {
     auto* bu = node->b().As<UIntImm>();
-    CHECK(bu);
+    PADDLE_ENFORCE_NOT_NULL(
+        bu, ::common::errors::InvalidArgument("Bu should not be null."));
     return make_const(au->type(), au->value / bu->value);
   }
 
   if (af) {
     auto* bf = node->b().As<FloatImm>();
-    CHECK(af);
+    PADDLE_ENFORCE_NOT_NULL(
+        af, ::common::errors::InvalidArgument("Af should not be null."));
     return make_const(af->type(), af->value / bf->value);
   }
   CINN_NOT_IMPLEMENTED
@@ -2164,7 +2226,8 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
   auto* af = a.As<FloatImm>();
   auto* bf = b.As<FloatImm>();
   if (ai) {
-    CHECK(bi);
+    PADDLE_ENFORCE_NOT_NULL(
+        bi, ::common::errors::InvalidArgument("Bi should not be null."));
     int g = gcd(ai->value, bi->value);
     int a_d = ai->value / g;
     int b_d = bi->value / g;
diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc
index 76fdb4e7da182..ebfe0076f31ee 100644
--- a/paddle/cinn/hlir/pe/transform.cc
+++ b/paddle/cinn/hlir/pe/transform.cc
@@ -45,8 +45,14 @@ std::vector<std::vector<int>> GetMatmulNewShapes(
                         ::common::errors::InvalidArgument(
                             "The matmul should only have two inputs."));
   const auto &x_shape = inputs_shape[0], &y_shape = inputs_shape[1];
-  CHECK(!x_shape.empty()) << "The shape of matmul input 'x' should not empty.";
-  CHECK(!y_shape.empty()) << "The shape of matmul input 'y' should not empty.";
+  PADDLE_ENFORCE_EQ(!x_shape.empty(),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "The shape of matmul input 'x' should not empty."));
+  PADDLE_ENFORCE_EQ(!y_shape.empty(),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "The shape of matmul input 'y' should not empty."));
 
   auto matmul_info = [&]() {
     std::stringstream ss;
@@ -85,10 +91,12 @@ std::vector<std::vector<int>> GetMatmulNewShapes(
 
   if (max_dim == 1) {
     // vector * vector
-    CHECK(x_shape[0] == y_shape[0])
-        << "The matmul input X's numbers must be equal to Y's numbers,when "
-           "X/Y's dims =1. But here "
-        << matmul_info();
+    PADDLE_ENFORCE_EQ(x_shape[0] == y_shape[0],
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "The matmul input X's numbers must be equal to Y's "
+                          "numbers,when X/Y's dims =1. But here %s.",
+                          matmul_info()));
 
     new_x_shape = trans_x ? std::vector<int>{x_shape[0], 1}
                           : std::vector<int>{1, x_shape[0]};
@@ -181,10 +189,13 @@ std::vector<std::vector<int>> GetMatmulNewShapes(
     // get the batch dimension after broadcast
     int x_pos = x_dim - 3, y_pos = y_dim - 3, out_pos = max_dim - 3;
     while (x_pos >= 0 && y_pos >= 0) {
-      CHECK(x_shape[x_pos] == y_shape[y_pos] || x_shape[x_pos] == 1 ||
-            y_shape[y_pos] == 1)
-          << "Input X and Y's batch dimension should be same or 1. But here "
-          << matmul_info();
+      PADDLE_ENFORCE_EQ(
+          x_shape[x_pos] == y_shape[y_pos] || x_shape[x_pos] == 1 ||
+              y_shape[y_pos] == 1,
+          true,
+          ::common::errors::InvalidArgument("Input X and Y's batch dimension "
+                                            "should be same or 1. But here %s.",
+                                            matmul_info()));
       out_shape[out_pos] =
           (x_shape[x_pos] == 1) ? y_shape[y_pos] : x_shape[x_pos];
@@ -215,8 +226,14 @@ std::vector<std::vector<int>> GetMulNewShapes(
                         ::common::errors::InvalidArgument(
                             "The mul should only have two inputs."));
   const auto &x_shape = inputs_shape[0], &y_shape = inputs_shape[1];
-  CHECK(!x_shape.empty()) << "The shape of mul input 'x' should not empty.";
-  CHECK(!y_shape.empty()) << "The shape of mul input 'y' should not empty.";
+  PADDLE_ENFORCE_EQ(!x_shape.empty(),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "The shape of matmul input 'x' should not empty."));
+  PADDLE_ENFORCE_EQ(!y_shape.empty(),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "The shape of matmul input 'y' should not empty."));
 
   auto mul_info = [&]() {
     std::stringstream ss;
@@ -294,10 +311,16 @@ std::vector<ir::Tensor> Matmul(const Tensor& A,
   std::vector<Expr> shape_B = B->shape;
   int a_dim = shape_A.size();
   int b_dim = shape_B.size();
-  CHECK(a_dim == 3U || a_dim == 2U)
-      << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim;
-  CHECK(b_dim == 3U || b_dim == 2U)
-      << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim;
+  PADDLE_ENFORCE_EQ(
+      a_dim == 3U || a_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_A's dim should be 2 or 3 while current dim is %d.", a_dim));
+  PADDLE_ENFORCE_EQ(
+      b_dim == 3U || b_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_B's dim should be 2 or 3 while current dim is %d.", b_dim));
   PADDLE_ENFORCE_EQ(a_dim,
                     b_dim,
                     ::common::errors::InvalidArgument(
@@ -307,8 +330,11 @@ std::vector<ir::Tensor> Matmul(const Tensor& A,
   Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2];
   Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2];
   Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back();
-  CHECK(is_zero(x_width - y_height))
-      << "matrix multiplication requires x_width to be same with y_height";
+  PADDLE_ENFORCE_EQ(
+      is_zero(x_width - y_height),
+      true,
+      ::common::errors::InvalidArgument(
+          "Matrix multiplication requires x_width to be same with y_height."));
   std::vector<Expr> output_shape;
   std::vector<ir::Tensor> out;
   if (a_dim == 3) {
@@ -324,8 +350,12 @@ std::vector<ir::Tensor> Matmul(const Tensor& A,
         int out_dim = indice.size();
         std::vector<Expr> A_indice;
         std::vector<Expr> B_indice;
-        CHECK(out_dim == 3U || out_dim == 2U)
-            << "indice size should be 2 or 3 while current dim is " << out_dim;
+        PADDLE_ENFORCE_EQ(
+            out_dim == 3U || out_dim == 2U,
+            true,
+            ::common::errors::InvalidArgument(
+                "Indice size should be 2 or 3 while current dim is %d.",
+                out_dim));
         if (out_dim == 3U) {
           // batch
           A_indice.push_back(indice[0]);
@@ -435,9 +465,13 @@ ir::Tensor Concat(const std::vector<ir::Tensor>& input_tensors,
                         "Concat should have at least 1 input tensors"));
   std::vector<Expr> output_shape = input_tensors[0]->shape;
   int input_dim = output_shape.size();
-  CHECK(axis >= -input_dim && axis < input_dim)
-      << "Concat's axis should be in [-R, R)"
-      << ", but get axis: " << axis << ", R: " << input_dim;
+  PADDLE_ENFORCE_EQ(
+      axis >= -input_dim && axis < input_dim,
+      true,
+      ::common::errors::InvalidArgument(
+          "Concat's axis should be in [-R, R), but get axis: %d, R: %d.",
+          axis,
+          input_dim));
   if (axis < 0) axis += output_shape.size();
 
   for (int i = 1; i < input_size; i++) {
@@ -482,10 +516,16 @@ std::vector<ir::Tensor> MatmulV2(const Tensor& A,
   std::vector<Expr> shape_B = B->shape;
   int a_dim = shape_A.size();
   int b_dim = shape_B.size();
-  CHECK(a_dim == 3U || a_dim == 2U)
-      << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim;
-  CHECK(b_dim == 3U || b_dim == 2U)
-      << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim;
+  PADDLE_ENFORCE_EQ(
+      a_dim == 3U || a_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_A's dim should be 2 or 3 while current dim is %d.", a_dim));
+  PADDLE_ENFORCE_EQ(
+      b_dim == 3U || b_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_B's dim should be 2 or 3 while current dim is %d.", b_dim));
   PADDLE_ENFORCE_EQ(a_dim,
                     b_dim,
                     ::common::errors::InvalidArgument(
@@ -495,8 +535,11 @@ std::vector<ir::Tensor> MatmulV2(const Tensor& A,
   Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2];
   Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2];
   Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back();
-  CHECK(is_zero(x_width - y_height))
-      << "matrix multiplication requires x_width to be same with y_height";
+  PADDLE_ENFORCE_EQ(
+      is_zero(x_width - y_height),
+      true,
+      ::common::errors::InvalidArgument(
+          "Matrix multiplication requires x_width to be same with y_height."));
   Var reduce_k(x_width, UniqName("reduce_k"));
   std::vector<Expr> output_shape;
   std::vector<ir::Tensor> out;
@@ -544,8 +587,12 @@ std::vector<ir::Tensor> MatmulV2(const Tensor& A,
     std::vector<Expr> indice_a;
     std::vector<Expr> indice_b;
     int out_dim = indice.size();
-    CHECK(out_dim == 3U || out_dim == 2U)
-        << "indice size should be 2 or 3 while current dim is " << out_dim;
+    PADDLE_ENFORCE_EQ(
+        out_dim == 3U || out_dim == 2U,
+        true,
+        ::common::errors::InvalidArgument(
+            "Indice size should be 2 or 3 while current dim is %d.",
+            out_dim));
     if (out_dim == 3) {
       // batch
      indice_a.push_back(indice[0]);
@@ -578,16 +625,24 @@ std::vector<ir::Tensor> MatmulMKL(const Tensor& A,
                                   float alpha,
                                   const std::string& name,
                                   const cinn::common::Target& target) {
-  CHECK(std::holds_alternative<common::X86Arch>(target.arch))
-      << "mkl should be used in the cpu environment";
+  PADDLE_ENFORCE_EQ(std::holds_alternative<common::X86Arch>(target.arch),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "Mkl should be used in the cpu environment."));
   std::vector<Expr> shape_A = A->shape;
   std::vector<Expr> shape_B = B->shape;
   int a_dim = shape_A.size();
   int b_dim = shape_B.size();
-  CHECK(a_dim == 3U || a_dim == 2U)
-      << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim;
-  CHECK(b_dim == 3U || b_dim == 2U)
-      << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim;
+  PADDLE_ENFORCE_EQ(
+      a_dim == 3U || a_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_A's dim should be 2 or 3 while current dim is %d.", a_dim));
+  PADDLE_ENFORCE_EQ(
+      b_dim == 3U || b_dim == 2U,
+      true,
+      ::common::errors::InvalidArgument(
+          "Tensor_B's dim should be 2 or 3 while current dim is %d.", b_dim));
   PADDLE_ENFORCE_EQ(a_dim,
                     b_dim,
                     ::common::errors::InvalidArgument(
@@ -603,8 +658,11 @@ std::vector<ir::Tensor> MatmulMKL(const Tensor& A,
   Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2];
   Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2];
   Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back();
-  CHECK(is_zero(x_width - y_height))
-      << "matrix multiplication requires x_width to be same with y_height";
+  PADDLE_ENFORCE_EQ(
+      is_zero(x_width - y_height),
+      true,
+      ::common::errors::InvalidArgument(
+          "Matrix multiplication requires x_width to be same with y_height."));
 
   ir::Tensor call;
   if (a_dim == 2U) {
@@ -854,8 +912,10 @@ std::vector<ir::Tensor> MulMKL(const Tensor& A,
                                const Tensor& B,
                                const std::string& name,
                                const cinn::common::Target& target) {
-  CHECK(std::holds_alternative<common::X86Arch>(target.arch))
-      << "mkl should be used in the cpu environment";
+  PADDLE_ENFORCE_EQ(std::holds_alternative<common::X86Arch>(target.arch),
+                    true,
+                    ::common::errors::InvalidArgument(
+                        "Mkl should be used in the cpu environment."));
   std::vector<Expr> shape_A = A->shape;
   std::vector<Expr> shape_B = B->shape;
   int a_dim = shape_A.size();
@@ -877,8 +937,11 @@ std::vector<ir::Tensor> MulMKL(const Tensor& A,
   Expr y_height = shape_B[1];
   Expr M = shape_A[0];
   Expr N = shape_B[0];
-  CHECK(is_zero(x_width - y_height))
-      << "matrix multiplication requires x_width to be same with y_height";
+  PADDLE_ENFORCE_EQ(
+      is_zero(x_width - y_height),
+      true,
+      ::common::errors::InvalidArgument(
+          "Matrix multiplication requires x_width to be same with y_height."));
   PADDLE_ENFORCE_EQ(A->shape[1],
                     B->shape[1],
                     ::common::errors::InvalidArgument(
@@ -937,14 +1000,26 @@ void GetLayoutTransformInfo(
                             "sub-axis factor should be larger than 0"));
       int src_primal_index = src_layout.axis_names().find(prim_axis_name);
       int dst_primal_index = dst_layout.axis_names().find(prim_axis_name);
-      CHECK(src_primal_index != src_layout.axis_names().npos);
-      CHECK(dst_primal_index != dst_layout.axis_names().npos);
+      PADDLE_ENFORCE_EQ(
+          src_primal_index != src_layout.axis_names().npos,
+          true,
+          ::common::errors::InvalidArgument(
+              "Src primal index should not be equal to src layout npos."));
+      PADDLE_ENFORCE_EQ(
+          dst_primal_index != dst_layout.axis_names().npos,
+          true,
+          ::common::errors::InvalidArgument(
+              "Dst primal index should not be equal to dst layout npos."));
       (*split_index_map)[src_primal_index] = {dst_primal_index, i, factor};
     } else {
       int src_primal_index = src_layout.axis_names().find(prim_axis_name);
       if (split_index_map->find(src_primal_index) != split_index_map->end())
         continue;
-      CHECK(src_primal_index != src_layout.axis_names().npos);
+      PADDLE_ENFORCE_EQ(
+          src_primal_index != src_layout.axis_names().npos,
+          true,
+          ::common::errors::InvalidArgument(
+              "Src primal index should not be equal to src layout npos."));
       (*split_index_map)[src_primal_index] = {i};
     }
   }
@@ -972,7 +1047,11 @@ std::vector<Expr> InferShapeLayoutTransform(
   } else if (src_dim < dst_dim) {
     GetLayoutTransformInfo(old_layout, new_layout, split_index_map);
     for (int i = 0; i < src_dim; i++) {
-      CHECK(split_index_map->find(i) != split_index_map->end());
+      PADDLE_ENFORCE_EQ(
+          split_index_map->find(i) != split_index_map->end(),
+          true,
+          ::common::errors::InvalidArgument(
+              "Spilt index map found should not be equal to end."));
       if ((*split_index_map)[i].size() == 3) {
         int dst_prim_index = (*split_index_map)[i][0];
         int dst_sub_index = (*split_index_map)[i][1];
@@ -989,7 +1068,11 @@ std::vector<Expr> InferShapeLayoutTransform(
   } else {
     GetLayoutTransformInfo(new_layout, old_layout, split_index_map);
     for (int i = 0; i < dst_dim; i++) {
-      CHECK(split_index_map->find(i) != split_index_map->end());
+      PADDLE_ENFORCE_EQ(
+          split_index_map->find(i) != split_index_map->end(),
+          true,
+          ::common::errors::InvalidArgument(
+              "Spilt index map found should not be equal to end."));
       if ((*split_index_map)[i].size() == 3) {
         int src_prim_index = (*split_index_map)[i][0];
         int src_sub_index = (*split_index_map)[i][1];
@@ -1020,8 +1103,11 @@ ir::Tensor LayoutTransform(const Tensor& input,
                            const std::string& src_layout,
                            const std::string& dst_layout,
                            const std::string& name) {
-  CHECK(src_layout != dst_layout)
-      << "dst_layout is same with src_layout, should not do layout transform";
+  PADDLE_ENFORCE_EQ(
+      src_layout != dst_layout,
+      true,
+      ::common::errors::InvalidArgument("Dst layout is same with src_layout, "
+                                        "should not do layout transform."));
   // NCHW -> NCHWxc
   // NCHWxc -> NCHW
   // OIHW -> OIHWxixo
@@ -1055,7 +1141,11 @@ ir::Tensor LayoutTransform(const Tensor& input,
         std::vector<Expr> new_indice(src_dim);
         int min_dim = std::min(src_dim, dst_dim);
         for (int i = 0; i < min_dim; i++) {
-          CHECK(split_index_map.find(i) != split_index_map.end());
+          PADDLE_ENFORCE_EQ(
+              split_index_map.find(i) != split_index_map.end(),
+              true,
+              ::common::errors::InvalidArgument(
+                  "Spilt index map found should not be equal to end."));
           std::vector<int> split_infos = split_index_map.at(i);
           if (split_infos.size() == 3) {
             int prim_index = split_infos[0];
@@ -1092,8 +1182,10 @@ ir::Tensor Reverse(const ir::Tensor& input,
                    const std::vector<int>& axis,
                    const std::string& output_name) {
   for (auto& val : axis) {
-    CHECK(val >= 0 && val < static_cast<int>(input->shape.size()))
-        << "axis should be [0,n_dim)";
+    PADDLE_ENFORCE_EQ(
+        val >= 0 && val < static_cast<int>(input->shape.size()),
+        true,
+        ::common::errors::InvalidArgument("Axis should be [0,n_dim)."));
   }
   std::vector<Expr> shape = input->shape;
   return lang::Compute(
@@ -1116,8 +1208,10 @@ ir::Tensor Transpose(const ir::Tensor& input,
                     ::common::errors::InvalidArgument(
                         "input shape size and axis size is not equal!"));
   for (int idx = 0; idx < axis.size(); ++idx) {
-    CHECK(axis[idx] >= 0 && axis[idx] < axis.size())
-        << "axis value should be among [0,axis.size())";
+    PADDLE_ENFORCE_EQ(axis[idx] >= 0 && axis[idx] < axis.size(),
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "Axis value should be among [0,axis.size())."));
     for (int idy = idx + 1; idy < axis.size(); ++idy) {
       PADDLE_ENFORCE_NE(
          axis[idx],
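Note on the pattern: every hunk above applies the same mechanical rewrite. A pointer check `CHECK(ptr);` becomes `PADDLE_ENFORCE_NOT_NULL(ptr, ::common::errors::InvalidArgument("..."));`, a boolean check `CHECK(cond) << msg;` becomes `PADDLE_ENFORCE_EQ(cond, true, ::common::errors::InvalidArgument("..."));`, and values that the old messages streamed through `operator<<` move into printf-style placeholders (`%d`, `%s`). The sketch below is a minimal standalone approximation of that behavior for illustration only; the `EnforceEq`/`EnforceNotNull` helpers and the `errors::InvalidArgument` function here are hypothetical stand-ins written for this note, not Paddle's actual macro definitions.

```cpp
// Standalone sketch of the CHECK -> PADDLE_ENFORCE_* rewrite pattern.
// EnforceNotNull / EnforceEq are illustrative stand-ins; the real
// PADDLE_ENFORCE_* macros additionally capture the failing expression,
// file, and line, and produce typed error objects.
#include <cstdio>
#include <stdexcept>
#include <string>

namespace errors {
// Stand-in for ::common::errors::InvalidArgument: printf-style formatting,
// matching the %d / %s placeholders used in the diff.
template <typename... Args>
std::string InvalidArgument(const char* fmt, Args... args) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), fmt, args...);
  return std::string("InvalidArgument: ") + buf;
}
}  // namespace errors

// CHECK(ptr);            -->  PADDLE_ENFORCE_NOT_NULL(ptr, error)
template <typename T>
void EnforceNotNull(T* ptr, const std::string& err) {
  if (ptr == nullptr) throw std::invalid_argument(err);
}

// CHECK(cond) << "msg";  -->  PADDLE_ENFORCE_EQ(cond, true, error)
inline void EnforceEq(bool actual, bool expected, const std::string& err) {
  if (actual != expected) throw std::invalid_argument(err);
}

int main() {
  int a_dim = 4;  // deliberately invalid: Matmul expects dim 2 or 3
  try {
    // Mirrors the Matmul hunk: the old CHECK streamed a_dim via operator<<;
    // the new form passes it through a %d placeholder instead.
    EnforceEq(a_dim == 3 || a_dim == 2,
              true,
              errors::InvalidArgument(
                  "Tensor_A's dim should be 2 or 3 while current dim is %d.",
                  a_dim));
  } catch (const std::invalid_argument& e) {
    // Prints: InvalidArgument: Tensor_A's dim should be 2 or 3 while
    // current dim is 4.
    std::printf("%s\n", e.what());
  }
  return 0;
}
```

The practical gain reflected in the diff is that failures surface as uniformly formatted, typed error objects (`InvalidArgument`) carrying the offending values, instead of bare `CHECK` aborts whose message text lived in scattered `operator<<` chains.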