Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-anchor-fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi committed Aug 1, 2024
2 parents 0c7c524 + b3d6d26 commit e6cd91e
Show file tree
Hide file tree
Showing 368 changed files with 8,915 additions and 4,483 deletions.
11 changes: 8 additions & 3 deletions paddle/cinn/adt/igroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@ std::shared_ptr<IndexExprInferContext> MakeIndexExprInferContext(
const auto& anchor_iterators = igroup.GetAnchorIterators();

for (std::size_t i = 0; i < anchor_iterators->size(); ++i) {
CHECK(anchor_iterator2value
.emplace(anchor_iterators->at(i), anchor_iterators->at(i))
.second);
PADDLE_ENFORCE_EQ(
anchor_iterator2value
.emplace(anchor_iterators->at(i), anchor_iterators->at(i))
.second,
true,
phi::errors::InvalidArgument(
"The element in anchor iterators failed to insert in anchor "
"iterator2value! Please check."));
}

return std::make_shared<IndexExprInferContext>(anchor_iterator2value);
Expand Down
33 changes: 28 additions & 5 deletions paddle/cinn/adt/schedule_mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,22 +385,39 @@ ScheduleMesh MeshReshape(const ScheduleMesh& sched_mesh,
const auto& origin_shape = GetOutputDimValues(sched_mesh);
std::int64_t origin_numel = 1;
for (const auto& dim : *origin_shape) {
CHECK(dim.Has<std::int64_t>());
PADDLE_ENFORCE_EQ(
dim.Has<std::int64_t>(),
true,
phi::errors::InvalidArgument(
"Each dimension in 'origin_shape' must have an int64_t value."));
origin_numel *= dim.Get<std::int64_t>();
}

std::int64_t numel = 1;
bool dynamic_shape = false;
for (const auto& dim : shape) {
if (dim < 0) {
CHECK(dim == -1 && !dynamic_shape);
PADDLE_ENFORCE_EQ(
dim == -1 && !dynamic_shape,
true,
phi::errors::InvalidArgument("Negative dimension in 'shape' must be "
"-1 to represent dynamic shape. "
"But received: %d",
dim));
dynamic_shape = true;
} else {
numel *= dim;
}
}

CHECK(dynamic_shape || numel == origin_numel);
PADDLE_ENFORCE_EQ(dynamic_shape || numel == origin_numel,
true,
phi::errors::InvalidArgument(
"The total number of elements must match between "
"'shape' and 'origin_shape' "
"unless there is a dynamic shape. "
"But received: numel = %d, origin_numel = %d",
numel,
origin_numel));
List<LoopSize> reshape_to{};
for (const auto& dim : shape) {
if (dim < 0) {
Expand Down Expand Up @@ -465,7 +482,13 @@ ScheduleMesh MeshPaddingRoundUp(
continue;
}
std::int64_t align_size = align_sizes.at(i).value();
CHECK(shape->at(i).Has<std::int64_t>());
PADDLE_ENFORCE_EQ(
shape->at(i).Has<std::int64_t>(),
true,
phi::errors::InvalidArgument(
"Each dimension in 'shape' must have an int64_t value. "
"But the dimension at index %d does not.",
i));
std::int64_t dim = shape->at(i).Get<std::int64_t>();
std::int64_t padding_size =
(dim + align_size - 1) / align_size * align_size;
Expand Down
21 changes: 15 additions & 6 deletions paddle/cinn/backends/function_prototype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,18 @@ void FunctionProto::AssertMatch(const ir::Call *op) const {

// Validates the invariant between the prototype's return type and its
// mutable (output) arguments:
//   - a void function must expose at least one mutable argument, since
//     that is its only way to produce output;
//   - a function with a non-void return must not have mutable arguments.
// Reports InvalidArgument via PADDLE_ENFORCE_EQ when violated.
// NOTE(review): the flattened diff left both the old CHECK(...) lines and
// their PADDLE_ENFORCE_EQ replacements in this body; only the post-commit
// PADDLE_ENFORCE_EQ form is kept here.
void FunctionProto::CheckValid() {
  if (ret_type.is_void()) {
    PADDLE_ENFORCE_EQ(
        !mutable_arg_types.empty(),
        true,
        phi::errors::InvalidArgument(
            "A void function should have at least one mutable argument to "
            "output something."));
  } else {
    PADDLE_ENFORCE_EQ(
        mutable_arg_types.empty(),
        true,
        phi::errors::InvalidArgument(
            "A function with return should not have mutable argument."));
  }
}

Expand All @@ -107,7 +113,10 @@ FunctionProto::shape_inference_t FunctionProto::ShapeFollowNthArgument(int n) {
::common::errors::InvalidArgument(
"The argument index is out of range"));
auto x = args[n].as_tensor();
CHECK(x);
PADDLE_ENFORCE_NOT_NULL(
x,
phi::errors::InvalidArgument(
"The argument at index (%d) must be a tensor.", n));
return x->shape;
};
}
Expand Down
126 changes: 102 additions & 24 deletions paddle/cinn/backends/llvm/llvm_intrin_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,34 @@ namespace codegen {

template <int id, int arg_nums, bool add_float_suffix = true>
inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
PADDLE_ENFORCE_GE(args.size(),
1U,
::common::errors::InvalidArgument(
"The number of args should be greater than 1."));
PADDLE_ENFORCE_GE(
args.size(),
1U,
phi::errors::InvalidArgument(
"The number of arguments should be at least 1. Received: %d",
args.size()));

Expr arg = args[0];
ir::Call *node = arg->as<ir::Call>();
CHECK(node);

PADDLE_ENFORCE_NOT_NULL(node,
phi::errors::InvalidArgument(
"The argument must be a valid call expression."));

PADDLE_ENFORCE_GE(
node->read_args.size(),
arg_nums,
::common::errors::InvalidArgument(
"The number of read args should be greater than arg_nums."));
phi::errors::InvalidArgument(
"The number of read arguments should be at least %d. Received: %d",
arg_nums,
node->read_args.size()));

if (add_float_suffix) {
CHECK(node->type().is_float());
PADDLE_ENFORCE_EQ(node->type().is_float(),
true,
phi::errors::InvalidArgument(
"The node type should be float. Received: %s",
node->type().to_string().c_str()));
*rv = ir::intrinsics::BuiltinIntrin::Make(
node->name + "f", node->read_args, id, arg_nums, node->type());
} else {
Expand Down Expand Up @@ -98,8 +112,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument(
"The argument must be a valid call expression."));
PADDLE_ENFORCE_EQ(
!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty."));

Expr arg = node->read_args[0];
*rv = !(lang::IsInf(arg)) && !(lang::IsNan(arg));
});
Expand All @@ -112,8 +134,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"The provided read_args are empty."));

Expr arg = node->read_args[0];
Type type = arg->type();
if (type.is_int() || type.is_uint()) {
Expand All @@ -132,8 +162,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
*rv = make_const(arg->type(), 1) / lang::Sqrt(arg);
});
Expand All @@ -146,8 +184,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
Expr ln10 = make_const(arg->type(), 2.302585093);
*rv = lang::Exp(arg * ln10);
Expand All @@ -161,8 +207,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
*rv = lang::Sin(arg) / lang::Cos(arg);
});
Expand All @@ -175,8 +229,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
Expr zero = make_const(arg->type(), 0);
Expr one = make_const(arg->type(), 1);
Expand All @@ -199,8 +261,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
*rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) /
make_const(arg->type(), 2);
Expand All @@ -214,8 +284,16 @@ void RegisterCpuIntrinRule() {
"The number of args should be greater than 1."));
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
PADDLE_ENFORCE_NOT_NULL(
node,
phi::errors::InvalidArgument("The argument must be a valid call "
"expression. Received null."));
PADDLE_ENFORCE_EQ(!node->read_args.empty(),
true,
phi::errors::InvalidArgument(
"The read_args of the node should not be empty. "
"Received empty read_args."));

Expr arg = node->read_args[0];
*rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) /
make_const(arg->type(), 2);
Expand Down
14 changes: 10 additions & 4 deletions paddle/cinn/backends/llvm/simple_jit.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
LOG(INFO) << "fn:\n" << DumpToString(fn);
}
*/
CHECK(!llvm::verifyModule(*module, &llvm::errs()))
<< "Transformation resulted in an invalid module\n\nmodule:\n";
PADDLE_ENFORCE_EQ(
!llvm::verifyModule(*module, &llvm::errs()),
true,
phi::errors::InvalidArgument(
"Transformation resulted in an invalid module\n\nmodule:\n"));

bool debug = false;
if (optimize) {
Expand Down Expand Up @@ -99,7 +102,8 @@ SimpleJIT::SimpleJIT() : context_(std::make_unique<llvm::LLVMContext>()) {
llvm::InitializeAllAsmPrinters();

jit_ = llvm::cantFail(llvm::orc::LLJITBuilder().create());
CHECK(jit_) << "JIT create failed";
PADDLE_ENFORCE_NOT_NULL(jit_,
phi::errors::InvalidArgument("JIT creation failed."));

auto proc_symbols_generator = llvm::cantFail(
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
Expand Down Expand Up @@ -129,7 +133,9 @@ void SimpleJIT::Link(ir::Module module, bool optimize) {
auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
ir_emitter->Compile(module);

CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";
PADDLE_ENFORCE_EQ(!llvm::verifyModule(*m, &llvm::errs()),
true,
phi::errors::InvalidArgument("Invalid module found."));

AddModule(std::move(m), optimize);
}
Expand Down
Loading

0 comments on commit e6cd91e

Please sign in to comment.