[CINN] Parse dim args in codegen host #59395

Merged
20 changes: 20 additions & 0 deletions paddle/cinn/backends/codegen_cuda_host.cc
@@ -210,6 +210,26 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
return f_;
}

llvm::Value* CodeGenCUDA_Host::LowerParseArgsValueCall(
const ir::Call* call_ir) {
auto ret_type = CinnTypeToLLVMType(Int(32), m_);
std::vector<llvm::Type*> args_type;
CHECK_EQ(call_ir->read_args.size(), 2);
CHECK(call_ir->read_args[0].is_var() &&
call_ir->read_args[0].as_var()->type().is_cpp_handle());
CHECK(call_ir->read_args[1].type().is_int(32));
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));

auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);

std::vector<llvm::Value*> call_args;
call_args.push_back(std::addressof(*f_->arg_begin()));
call_args.push_back(b_->getInt32(call_ir->read_args[1].as_int32()));
return b_->CreateCall(call_func, call_args);
}

llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
std::vector<llvm::Value*> ll_function_args;
std::transform(f_->arg_begin(),
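
Note on the new lowering: LowerParseArgsValueCall maps a get_value_in_cuda_kernel_args call in the host IR to an LLVM call whose first operand is the host function's leading void* argument (the packed pod args) and whose second operand is a constant i32 index. At run time the emitted call behaves like the plain C++ below; the wrapper is illustrative, and only the declared helper comes from this PR (it lives in cinn::runtime::cuda, flattened here to keep the sketch self-contained).

#include <cstdint>

// Runtime helper added by this PR (paddle/cinn/runtime/cuda/cuda_util.h);
// declared flat here instead of inside namespace cinn::runtime::cuda.
int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);

// Illustrative wrapper: the behavior of the emitted
//   %v = call i32 @cinn_get_value_in_cuda_kernel_args(%kernel_args, idx)
int32_t ReadDimArg(void* kernel_args, int idx) {
  return cinn_get_value_in_cuda_kernel_args(kernel_args, idx);
}
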
11 changes: 10 additions & 1 deletion paddle/cinn/backends/codegen_cuda_host.h
@@ -22,6 +22,7 @@
#include <vector>

#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/runtime/intrinsic.h"

PD_DECLARE_bool(cinn_bucket_compile);

@@ -47,7 +48,13 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
}

llvm::Value *Visit(const ir::Call *op) override {
return LowerCUDAKernelCall(op);
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
return LowerParseArgsValueCall(op);
} else if (op->name == runtime::intrinsic::call_cuda_kernel) {
return LowerCUDAKernelCall(op);
} else {
CINN_NOT_IMPLEMENTED;
}
}

private:
@@ -68,6 +75,8 @@ class CodeGenCUDA_Host : public CodeGenLLVM {

llvm::Value *LowerHostFunc(const ir::_LoweredFunc_ *func);

llvm::Value *LowerParseArgsValueCall(const ir::Call *call_ir);

llvm::Value *LowerCUDAKernelCall(const ir::Call *op);
};

20 changes: 20 additions & 0 deletions paddle/cinn/backends/codegen_cuda_util.cc
@@ -108,6 +108,26 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
}

void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
ir::Expr func) {
std::vector<ir::Argument> args = func.as_lowered_func_ref()->args;
for (int i = 0; i < args.size(); ++i) {
if (args[i].is_var()) {
ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(32),
runtime::intrinsic::get_value_in_cuda_kernel_args,
{kernel_args_, ir::Expr(i)},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
ir::Expr stmt = ir::Let::Make(ir::Expr(args[i].var_arg()),
call_get_value_in_kernel_args);
arg_defs_.push_back(stmt);
}
}
}

Expr detail::CollectBucketStrategyHostFunctionVisitor::CreateDeviceFunction(
ir::Expr expr, ir::Expr predicate) {
auto copied = ir::ir_utils::IRCopy(expr);
10 changes: 9 additions & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
@@ -162,30 +162,38 @@ struct CollectBucketStrategyHostFunctionVisitor
CHECK_EQ(op->functions.size(), op->predicates.size());
for (int i = 0; i < op->functions.size(); ++i) {
ProcessLoweredFunc(op->functions[i], op->predicates[i]);
if (i == 0) {
ProcessArgs(op->functions[i]);
}
}

std::vector<ir::Argument> arguments = {
ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
ir::Argument(kernel_stream_, ir::Argument::IO::kOutput)};
std::vector<ir::Expr> body_stmts(arg_defs_);
body_stmts.insert(body_stmts.end(), buckets_.begin(), buckets_.end());
ir::Expr host_func =
ir::_LoweredFunc_::Make(op->functions[0].as_lowered_func()->name,
arguments,
ir::Block::Make(buckets_),
ir::Block::Make(body_stmts),
{});
host_module_builder.AddFunctionWithoutOptim(
host_func.as_lowered_func_ref());
}

void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);

void ProcessArgs(ir::Expr func);

Expr CreateDeviceFunction(ir::Expr expr, ir::Expr predicate);

inline std::string GenDeviceKernelName(const std::string& fn_name,
ir::Expr predicate);

private:
std::vector<ir::Expr> buckets_;
std::vector<ir::Expr> arg_defs_;

ir::Var kernel_args_;
ir::Var kernel_args_num_;
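
Taken together, the collected host function for a group now has the fixed argument list (kernel_args, kernel_args_num, kernel_stream), and its body is arg_defs_ followed by the predicate-guarded buckets. A hand-written behavioral mock of that stub is sketched below; the real function is assembled as CINN IR and lowered by CodeGenCUDA_Host, so the function name, the predicate, the indices, and the launch helper are all illustrative (the two scalars assume a single buffer argument ahead of them, matching the test at the end of this diff).

#include <cstdint>

// Runtime helper from paddle/cinn/runtime/cuda/cuda_util.h (namespace dropped
// for the sketch).
int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);

// Hypothetical stand-in for one bucket's cinn_call_cuda_kernel(...) branch.
void LaunchBucketKernel(void* kernel_args, int kernel_args_num, void* stream);

// Behavioral sketch of the generated host stub, not literal codegen output.
void fn_group_0_host(void* kernel_args, int kernel_args_num, void* kernel_stream) {
  // arg_defs_: each scalar (Var) argument is read back out of the packed pod
  // args by its position in the lowered function's argument list.
  int32_t fake_symbol1 = cinn_get_value_in_cuda_kernel_args(kernel_args, 1);
  int32_t fake_symbol2 = cinn_get_value_in_cuda_kernel_args(kernel_args, 2);
  // buckets_: predicate-guarded kernel calls follow the definitions.
  if (fake_symbol1 <= 4096 && fake_symbol2 <= 128) {  // illustrative predicate
    LaunchBucketKernel(kernel_args, kernel_args_num, kernel_stream);
  }
}
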
11 changes: 11 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -34,6 +34,7 @@
PD_DECLARE_bool(cinn_use_cuda_vectorize);
PD_DECLARE_bool(cinn_enable_map_expr);
PD_DECLARE_bool(cinn_enable_map_expr_schedule);
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
namespace hlir {
@@ -501,6 +502,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
}

// add fake symbolic args for test
if (FLAGS_cinn_bucket_compile) {
group_func_args.emplace_back(ir::_Var_::Make("fake_symbol1", Int(32)),
ir::Argument::IO::kOutput);
group_func_args.emplace_back(ir::_Var_::Make("fake_symbol2", Int(32)),
ir::Argument::IO::kOutput);
group->output_names.push_back("fake_symbol1");
group->output_names.push_back("fake_symbol2");
}

#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
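
Because the fake symbols are appended as kOutput arguments and pushed into group->output_names, the runtime instruction now looks them up by name in the pod-args map, so callers have to supply matching entries or execution cannot resolve the arguments. A minimal caller-side sketch of that contract follows; the header path is assumed and the values simply mirror the updated test at the end of this diff.

#include <cstdint>
#include <map>
#include <string>

#include "paddle/cinn/runtime/cinn_runtime.h"  // cinn_pod_value_t (path assumed)

// Illustrative: the extra entries a caller must supply once the lowered group
// function carries the two test-only scalar arguments.
void AddFakeDimArgs(std::map<std::string, cinn_pod_value_t>* name2podargs) {
  name2podargs->emplace("fake_symbol1", cinn_pod_value_t(int32_t(4096)));
  name2podargs->emplace("fake_symbol2", cinn_pod_value_t(int32_t(128)));
}
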
8 changes: 8 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -426,6 +426,14 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
}

CINN_REGISTER_HELPER(cinn_cuda_host_api) {
using cinn::runtime::cuda::cinn_get_value_in_cuda_kernel_args;
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_cuda_kernel_args,
cinn::common::DefaultHostTarget())
.SetRetType<int64_t>()
.AddInputType<void *>() // args
.AddInputType<int>() // index
.End();

using cinn::runtime::cuda::cinn_call_cuda_kernel;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_kernel,
cinn::common::DefaultHostTarget())
5 changes: 5 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.cc
@@ -78,6 +78,11 @@ class CublasHandle {
cublasHandle_t cuhandle;
};

int32_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
return args[idx].operator int32_t();
}

void cinn_call_cuda_kernel(void *kernel_fn,
void *v_args,
int num_args,
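
The helper itself only reinterprets the opaque args pointer as the packed cinn_pod_value_t array and converts one element back to int32. A standalone usage sketch is below; generated host code is the real caller, the include paths are assumed, and the values match the updated test.

#include <cstdint>

#include "paddle/cinn/runtime/cinn_runtime.h"    // cinn_buffer_t, cinn_pod_value_t (path assumed)
#include "paddle/cinn/runtime/cuda/cuda_util.h"  // cinn_get_value_in_cuda_kernel_args

// Illustrative only: pack two scalar dims behind the buffer slot and read them
// back by index, which is what the generated host stub does at run time.
void Demo(cinn_buffer_t* buffer) {
  using cinn::runtime::cuda::cinn_get_value_in_cuda_kernel_args;
  cinn_pod_value_t pod_args[] = {cinn_pod_value_t(buffer),
                                 cinn_pod_value_t(int32_t(4096)),
                                 cinn_pod_value_t(int32_t(128))};
  int32_t dim0 = cinn_get_value_in_cuda_kernel_args(pod_args, 1);  // 4096
  int32_t dim1 = cinn_get_value_in_cuda_kernel_args(pod_args, 2);  // 128
  (void)dim0;
  (void)dim1;
}
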
2 changes: 2 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.h
@@ -95,6 +95,8 @@ void cinn_call_cuda_memcpy(void* v_args,
size_t count,
void* stream = nullptr);

int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);

/**
* Call a CUDA compiled kernel.
*
3 changes: 3 additions & 0 deletions paddle/cinn/runtime/intrinsic.h
@@ -104,6 +104,9 @@ static const char* print_debug_args_repr = "cinn_print_debug_args";

static const char* call_cuda_kernel = "cinn_call_cuda_kernel";

static const char* get_value_in_cuda_kernel_args =
"cinn_get_value_in_cuda_kernel_args";

static const char* pod_values_to_array_repr = "pod_values_to_array";

static const char* get_address_repr = "get_address";
8 changes: 7 additions & 1 deletion test/cpp/pir/cinn/compilation_task_test.cc
@@ -108,7 +108,13 @@ TEST(CompilationTask, CompileGroup) {
auto runtime_program = ir_compiler.Build(groups);

// Step 3: Execute Runtime Instruction and check Scope.
ASSERT_NO_THROW(runtime_program->Execute());
std::string arg_name = "var_0";
cinn_buffer_t* buffer = scope->GetTensor(arg_name)->buffer();
std::map<std::string, cinn_pod_value_t> name2podargs = {
{arg_name, cinn_pod_value_t(buffer)},
{"fake_symbol1", cinn_pod_value_t(int32_t(4096))},
{"fake_symbol2", cinn_pod_value_t(int32_t(128))}};
ASSERT_NO_THROW(runtime_program->Execute(&name2podargs));
for (auto& var_name : scope->var_names()) {
std::string name = {var_name.begin(), var_name.end()};
int64_t numel = scope->GetTensor(name)->shape().numel();