Merge branch 'develop' of github.com:gouzil/Paddle into error_msg/while_cond
gouzil committed Oct 15, 2024
2 parents 01fd3cc + 159aa58 commit 7885ce9
Showing 612 changed files with 17,330 additions and 5,580 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -30,7 +30,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
   set(XPU_XRE_BASE_VERSION "4.32.0.1")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "eb35/20240926")
+  set(XPU_XHPC_BASE_DATE "eb35/20240927")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.2.11e")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)

2 changes: 1 addition & 1 deletion cmake/inference_lib.cmake
@@ -304,7 +304,7 @@ endif()

 copy(
   inference_lib_dist
-  SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
+  SRCS ${CMAKE_BINARY_DIR}/paddle/phi/core/framework/framework.pb.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/internal)
 copy(
   inference_lib_dist

16 changes: 9 additions & 7 deletions paddle/cinn/adt/equation_solver.cc
@@ -37,8 +37,8 @@ std::unordered_map<Variable, Value> InferValuesImpl(
   PADDLE_ENFORCE_EQ(
       ctx->HasValue(in_variable),
       true,
-      phi::errors::NotFound("The param id's out_iter must contain "
-                            "its in_iter's value"));
+      ::common::errors::NotFound("The param id's out_iter must contain "
+                                 "its in_iter's value"));
   return {{out_iter.value(), ctx->GetValue(in_variable)}};
 }

@@ -49,8 +49,8 @@ std::unordered_map<Variable, Value> InferValuesImpl(
   PADDLE_ENFORCE_EQ(
       ctx->HasValue(in_variable),
       true,
-      phi::errors::NotFound("The param id's out_iter must contain "
-                            "its in_iter's value"));
+      ::common::errors::NotFound("The param id's out_iter must contain "
+                                 "its in_iter's value"));
   return {{out_index.value(), ctx->GetValue(in_variable)}};
 }

@@ -215,7 +215,7 @@ std::unordered_map<Variable, Value> InferValuesImpl(
     PADDLE_ENFORCE_EQ(
         ret.emplace(out_msg_in_indexes.value()->at(i), value).second,
         true,
-        phi::errors::AlreadyExists([&]() {
+        ::common::errors::AlreadyExists([&]() {
          std::ostringstream oss;
          oss << "Failed to insert the variable '"
              << "out_msg_in_indexes.value()->at(" << i
@@ -229,7 +229,7 @@ std::unordered_map<Variable, Value> InferValuesImpl(
   if (out_index.has_value()) {
     PADDLE_ENFORCE_EQ(ret.emplace(out_index.value(), value).second,
                       true,
-                      phi::errors::AlreadyExists([&]() {
+                      ::common::errors::AlreadyExists([&]() {
                         std::ostringstream oss;
                         oss << "Failed to insert the variable '"
                             << "out_index.value()"
@@ -306,7 +306,9 @@ void SolveEquations(
     tValueInferSuccess<bool> has_unique_value =
         MergeInferedValuesIntoCtx(function, ctx);
     PADDLE_ENFORCE_EQ(
-        has_unique_value.value(), true, phi::errors::InvalidArgument([&]() {
+        has_unique_value.value(),
+        true,
+        ::common::errors::InvalidArgument([&]() {
          std::ostringstream oss;
          oss << "Failed to merge inferred values into the context for "
                 "function '"

38 changes: 38 additions & 0 deletions paddle/cinn/backends/codegen_device_util.cc
@@ -262,6 +262,24 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
                         ir::CallType::Extern,
                         ir::FunctionRef(),
                         0);
+
+  // create memset calls for temp_spaces if needed
+  std::vector<ir::Expr> call_kernel_stmts;
+  for (auto &temp_space : func_node->temp_spaces) {
+    if (temp_space.need_zero_init()) {
+      ir::Expr size = common::cast(temp_space.size(), common::UInt(64));
+      ir::Expr call_get_arg =
+          lang::CallExtern(runtime::intrinsic::get_item_in_cuda_kernel_args,
+                           {kernel_args_, ir::Expr(temp_space.arg_idx())});
+      ir::Expr call_memset = lang::CallExtern(
+          runtime::intrinsic::call_cuda_memset,
+          {call_get_arg, ir::Expr(1), ir::Expr(0), size, kernel_stream_});
+      call_kernel_stmts.push_back(call_memset);
+    }
+  }
+  call_kernel_stmts.push_back(call_extern_api);
+  call_extern_api = ir::Block::Make(call_kernel_stmts);
+
   if (buckets_.empty()) {
     buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
   } else {
@@ -270,6 +288,26 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
     buckets_.emplace_back(
         ir::IfThenElse::Make(predicate, call_extern_api, false_expr));
   }
+
+  // create infer shape calls for temp_spaces
+  std::vector<ir::Expr> temp_space_infer_shape_stmts;
+  for (int i = 0; i < func_node->temp_spaces.size(); ++i) {
+    ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t **>());
+    ir::Expr size =
+        common::cast(func_node->temp_spaces[i].size(), common::Int(64));
+    ir::Expr call_set_value =
+        lang::CallExtern(runtime::intrinsic::infer_shape_set_value,
+                         {ir::Expr(func_node->num_output_tensors + i),
+                          ir::Expr(0),
+                          size,
+                          tensor_shape_args});
+    temp_space_infer_shape_stmts.push_back(call_set_value);
+  }
+  if (!temp_space_infer_shape_stmts.empty()) {
+    ir::Expr if_body = ir::Block::Make(temp_space_infer_shape_stmts);
+    temp_space_infer_shape_body_ =
+        ir::IfThenElse::Make(predicate, if_body, temp_space_infer_shape_body_);
+  }
 }

 void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(

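Reviewer note: the added block changes what the host stub does before a kernel launch — every temp_space flagged need_zero_init is looked up in the kernel-argument array and cleared on the kernel's stream, and the launch itself is wrapped into a block behind those memsets. A minimal sketch of the equivalent plain CUDA-runtime sequence (LaunchWithTempSpaces, args, sizes, and arg_idx are illustrative names, not CINN API):

#include <cstdint>
#include <cuda_runtime.h>

// Sketch: zero-initialize temp buffers pulled out of the kernel-argument
// array, then run the original kernel launch. Mirrors the emitted
// get_item_in_cuda_kernel_args + call_cuda_memset sequence above.
void LaunchWithTempSpaces(void** args, const int64_t* sizes,
                          const int* arg_idx, int n_temp,
                          cudaStream_t stream) {
  for (int i = 0; i < n_temp; ++i) {
    void* buf = args[arg_idx[i]];  // analogous to get_item_in_cuda_kernel_args
    cudaMemsetAsync(buf, /*value=*/0, static_cast<size_t>(sizes[i]), stream);
  }
  // ... the original kernel call (call_extern_api) would follow here,
  // on the same stream, so it observes the zeroed buffers.
}
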
4 changes: 4 additions & 0 deletions paddle/cinn/backends/codegen_device_util.h
@@ -280,6 +280,9 @@ struct CollectBucketStrategyHostFunctionVisitor
     infer_shape_func_body_stmts.insert(
         infer_shape_func_body_stmts.end(),
         op->infer_shape_func.as_lowered_func()->body);
+    if (temp_space_infer_shape_body_.defined()) {
+      infer_shape_func_body_stmts.push_back(temp_space_infer_shape_body_);
+    }

     std::vector<ir::Argument> infer_shape_arguments = {
         ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
@@ -307,6 +310,7 @@
  private:
   std::vector<ir::Expr> buckets_;
   std::vector<ir::Expr> arg_defs_;
+  ir::Expr temp_space_infer_shape_body_;

   ir::Var kernel_args_;
   ir::Var kernel_args_num_;

32 changes: 0 additions & 32 deletions paddle/cinn/backends/codegen_invoke_module.cc
@@ -62,38 +62,6 @@ llvm::Value* CodeGenInvokeModule::LowerInvokeFunc(
   return f_;
 }

-llvm::Value* CodeGenInvokeModule::LowerParseArgsValueCall(
-    const ir::Call* call_ir) {
-  auto ret_type = CinnTypeToLLVMType(Int(64), m_);
-  std::vector<llvm::Type*> args_type;
-  PADDLE_ENFORCE_EQ(
-      call_ir->read_args.size(),
-      2,
-      ::common::errors::InvalidArgument(
-          "The number of arguments of ParseArgsValue should be 2"));
-  PADDLE_ENFORCE_EQ(call_ir->read_args[0].is_var() &&
-                        call_ir->read_args[0].as_var()->type().is_cpp_handle(),
-                    true,
-                    ::common::errors::InvalidArgument(
-                        "The first read argument must be a variable "
-                        "with a C++ handle type."));
-
-  PADDLE_ENFORCE_EQ(call_ir->read_args[1].type().is_int(32),
-                    true,
-                    ::common::errors::InvalidArgument(
-                        "The second read argument must be of type int32."));
-  args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
-  args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
-
-  auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
-  auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);
-
-  std::vector<llvm::Value*> call_args;
-  call_args.push_back(std::addressof(*f_->arg_begin()));
-  call_args.push_back(b_->getInt32(call_ir->read_args[1].as_int32()));
-  return b_->CreateCall(call_func, call_args);
-}
-
 llvm::Value* CodeGenSwitchHost::LowerInnerCaseCall(const ir::Call* op) {
   std::vector<llvm::Value*> ll_function_args;
   std::transform(f_->arg_begin(),

13 changes: 1 addition & 12 deletions paddle/cinn/backends/codegen_invoke_module.h
@@ -43,19 +43,8 @@ class CodeGenInvokeModule : public CodeGenLLVM {
     return LowerInvokeFunc(func);
   }

-  llvm::Value *Visit(const ir::Call *op) override {
-    // TODO(Hongqing-work): change intrinsic name to get_value_in_kernel_args
-    if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
-      return LowerParseArgsValueCall(op);
-    } else {
-      return CodeGenLLVM::Visit(op);
-    }
-  }
-
  protected:
   llvm::Value *LowerInvokeFunc(const ir::_LoweredFunc_ *func);
-
-  llvm::Value *LowerParseArgsValueCall(const ir::Call *call_ir);
 };

 class CodeGenHost : public CodeGenInvokeModule {
@@ -80,7 +69,7 @@ class CodeGenSwitchHost : public CodeGenInvokeModule {
   // only support call of args get function and inner case host function call
   llvm::Value *Visit(const ir::Call *op) override {
     if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
-      return CodeGenInvokeModule::LowerParseArgsValueCall(op);
+      return CodeGenLLVM::Visit(op);
     } else {
       return LowerInnerCaseCall(op);
     }

76 changes: 6 additions & 70 deletions paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -511,53 +511,6 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
 llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) {
   SymbolTableGuard symbol_table_guard(*symbol_table_);

-  do {
-    break;
-    llvm::BasicBlock *preheader_bb = b_->GetInsertBlock();
-    auto *for_begin = llvm::BasicBlock::Create(
-        b_->getContext(), "for_begin", b_->GetInsertBlock()->getParent());
-    auto *for_body = llvm::BasicBlock::Create(
-        b_->getContext(), "for_body", b_->GetInsertBlock()->getParent());
-    auto *for_end = llvm::BasicBlock::Create(
-        b_->getContext(), "for_end", b_->GetInsertBlock()->getParent());
-
-    Br(for_begin);
-    b_->SetInsertPoint(for_begin);
-
-    auto *begin = Visit(&op->min);
-    auto *loop_value = PHI(begin->getType(), 2);
-    loop_value->addIncoming(begin, preheader_bb);
-
-    llvm::Value *old_var = GetVar(op->loop_var->name);
-    SetVar(op->loop_var->name, loop_value);
-    auto *end = Visit(&op->extent);
-    CondBr(ICmpSLT(loop_value, end), for_body, for_end);
-    b_->SetInsertPoint(for_body);
-    Visit(&op->body);
-
-    if (old_var) {
-      SetVar(op->loop_var->name, old_var);
-    } else {
-      symbol_table_->Erase(op->loop_var->name);
-    }
-
-    auto loop_next = Add(loop_value,
-                         llvm::ConstantInt::get(b_->getInt32Ty(), stride),
-                         "indvar.inc",
-                         true,
-                         true);
-    loop_value->addIncoming(loop_next, b_->GetInsertBlock());
-
-    Br(for_begin);
-    b_->SetInsertPoint(for_end);
-
-    return nullptr;
-    // llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr,
-    // op->loop_var->name); loop_var->setAlignment(llvm::Align(4));
-    // SetVar(op->loop_var->name, loop_var);
-  } while (false);
-
-  ////////////////////////////////////
   llvm::BasicBlock *preheader_bb = b_->GetInsertBlock();
   llvm::BasicBlock *exit_bb = nullptr;

@@ -814,20 +767,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
 }

 llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
-  llvm::Value *value = GetVar(op->name, false);
-  llvm::Value *result{};
-  CHECK(value) << "ir::_Var_[" << op->name << "]: value is null";
-  // TODO(fc500110) hard coding
-  if (LLVM_WillVarLowerAsPointer(op->name)) {
-    result = value;
-  } else if (value->getType()->isPointerTy() &&
-             !value->getType()->getPointerElementType()->isPointerTy()) {
-    result = Load(value, op->name + "_load");
-  } else {
-    result = value;
+  llvm::Value *value = GetVar(op->name, /* lazy= */ false);
+  // When visiting a Var that is allocated on the stack, we are actually
+  // reading its value instead of its address.
+  if (llvm::AllocaInst::classof(value)) {
+    return Load(value, op->name + "_load");
   }
-
-  return result;
+  return value;
 }

 void CodeGenLLVM::Scalarize(
@@ -1043,12 +989,6 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Buffer_ *op) {

 llvm::Value *CodeGenLLVM::Visit(const ir::_Tensor_ *op) {
   return GetVar(op->name);
-  auto *buffer_op = op->buffer.As<ir::_Buffer_>();
-  if (symbol_table_->Lookup(buffer_op->name)) {
-    return Visit(buffer_op);
-  }
-
-  return SetVar(buffer_op->name, Visit(buffer_op));
 }

 template <typename T,
@@ -1437,10 +1377,6 @@ void CodeGenLLVM::InitTarget(const Target &target) {
   naive_vec_alignment_ = GetNaiveVecAlignment(target);
 }

-bool LLVM_WillVarLowerAsPointer(const std::string &var_name) {
-  return var_name == "_args" || utils::EndsWith(var_name, "__ptr");
-}
-
 void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst,
                                   absl::string_view buffer,
                                   Expr index) {

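Reviewer note: the rewritten ir::_Var_ visitor drops the name-based LLVM_WillVarLowerAsPointer heuristic in favor of a direct type test — if the symbol resolved to a stack slot (an alloca), reading the variable means emitting a load; any other SSA value already is the variable's value. A standalone sketch of that pattern against the LLVM C++ API (ReadVar is an illustrative name; CINN's Load wrapper is replaced by IRBuilder::CreateLoad, which needs an explicit element type under opaque pointers):

#include <string>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

// Sketch: a value produced by an alloca is the *address* of a stack slot,
// so reading the variable means emitting a load; every other SSA value is
// already the variable's value and is returned unchanged.
llvm::Value* ReadVar(llvm::IRBuilder<>& b, llvm::Value* v,
                     llvm::Type* elem_ty, const std::string& name) {
  if (llvm::isa<llvm::AllocaInst>(v)) {
    return b.CreateLoad(elem_ty, v, name + "_load");
  }
  return v;
}
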
8 changes: 0 additions & 8 deletions paddle/cinn/backends/llvm/codegen_llvm.h
@@ -49,14 +49,6 @@ class LLVMIRVisitor : public ir::IRVisitorRequireReImpl<llvm::Value *> {
 #undef __m
 };

-/**
- * Tell whether a variable called \p \var_name will lowered to a pointer type in
- * LLVM.
- * @param var_name name of the variable.
- * @return a boolean.
- */
-bool LLVM_WillVarLowerAsPointer(const std::string &var_name);
-
 class SymbolTable {
  public:
   SymbolTable() = default;

36 changes: 36 additions & 0 deletions paddle/cinn/common/const_fold.h
@@ -71,5 +71,41 @@ inline std::optional<ir::Expr> TryConstFold<ir::Mul>(ir::Expr a, ir::Expr b) {
   return std::nullopt;
 }

+template <>
+inline std::optional<ir::Expr> TryConstFold<ir::Div>(ir::Expr a, ir::Expr b) {
+  const ir::IntImm* pa = a.As<ir::IntImm>();
+  const ir::IntImm* pb = b.As<ir::IntImm>();
+  const auto& rtype = a.type();
+  if (pa && pb) {
+    int64_t res = pa->value / pb->value;
+    return cinn::common::make_shared<ir::IntImm>(rtype, res);
+  }
+  if (pa) {
+    if (pa->value == 0) return a;
+  }
+  if (pb) {
+    if (pb->value == 1) return a;
+  }
+  return std::nullopt;
+}
+
+template <>
+inline std::optional<ir::Expr> TryConstFold<ir::Mod>(ir::Expr a, ir::Expr b) {
+  const ir::IntImm* pa = a.As<ir::IntImm>();
+  const ir::IntImm* pb = b.As<ir::IntImm>();
+  const auto& rtype = a.type();
+  if (pa && pb) {
+    int64_t res = pa->value % pb->value;
+    return cinn::common::make_shared<ir::IntImm>(rtype, res);
+  }
+  if (pa) {
+    if (pa->value == 0) return a;
+  }
+  if (pb) {
+    if (pb->value == 1) return ir::Zero(rtype);
+  }
+  return std::nullopt;
+}
+
 }  // namespace common
 }  // namespace cinn
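
Reviewer note: the two new specializations apply the usual integer identities — fold when both operands are immediates, plus 0 / x == 0, x / 1 == x, 0 % x == 0, and x % 1 == 0. A self-contained sketch of the same rules on a toy constant-or-symbolic type (illustrative only; the real code operates on ir::Expr and ir::IntImm, and like it, this sketch assumes the divisor is nonzero when both sides are constant):

#include <cstdint>
#include <iostream>
#include <optional>

// Toy stand-in for ir::Expr: an engaged `konst` means an integer immediate,
// a disengaged one means a symbolic expression.
struct Expr {
  std::optional<int64_t> konst;
};

// Mirrors TryConstFold<ir::Div>: returns the folded expression, or
// std::nullopt when no fold applies.
std::optional<Expr> TryFoldDiv(const Expr& a, const Expr& b) {
  if (a.konst && b.konst) return Expr{*a.konst / *b.konst};  // both constant
  if (a.konst && *a.konst == 0) return a;  // 0 / x == 0
  if (b.konst && *b.konst == 1) return a;  // x / 1 == x (x may be symbolic)
  return std::nullopt;
}

// Mirrors TryConstFold<ir::Mod>.
std::optional<Expr> TryFoldMod(const Expr& a, const Expr& b) {
  if (a.konst && b.konst) return Expr{*a.konst % *b.konst};  // both constant
  if (a.konst && *a.konst == 0) return a;        // 0 % x == 0
  if (b.konst && *b.konst == 1) return Expr{0};  // x % 1 == 0
  return std::nullopt;
}

int main() {
  Expr d = *TryFoldDiv(Expr{12}, Expr{4});
  Expr m = *TryFoldMod(Expr{12}, Expr{5});
  std::cout << *d.konst << " " << *m.konst << "\n";  // prints: 3 2
}
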
3 changes: 2 additions & 1 deletion paddle/cinn/common/dim_expr_converter.cc
@@ -153,7 +153,8 @@ struct DimExprConverterWithSymbolBindings::
     return inputs_[input_idx]->sym_shape[input_dim_idx]->GetDimExpr();
   }
   // for data binding [S0, a, b], inputs[a] is Tensor A, return A(b)
-  return inputs_[input_idx](cinn::ir::Expr(input_dim_idx));
+  return ir::Cast::Make(cinn::common::I64(),
+                        inputs_[input_idx](cinn::ir::Expr(input_dim_idx)));
 }

 DimExprToIrExprVisitorWithSymbolBinding(

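Reviewer note: the one-line fix wraps the tensor-data read in an explicit cast so that every term of the resulting shape expression is int64. A trivial standalone illustration of the width mismatch the cast avoids (all values hypothetical):

#include <cstdint>
#include <iostream>

int main() {
  int32_t b_from_tensor = 3;     // dim value read back from a shape tensor
  int64_t s0 = 5'000'000'000LL;  // symbolic dim already carried as i64
  // Mixing i32 and i64 terms would force implicit conversions inside the
  // IR; casting the tensor read up front keeps the whole expression in i64.
  int64_t extent = s0 + static_cast<int64_t>(b_from_tensor);
  std::cout << extent << "\n";  // 5000000003
}
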