Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Parse dim args in codegen host #59395

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
20 changes: 20 additions & 0 deletions paddle/cinn/backends/codegen_cuda_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,26 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
return f_;
}

llvm::Value* CodeGenCUDA_Host::LowerParseArgsValueCall(
    const ir::Call* call_ir) {
  // Validate the IR call shape: exactly two read args, the first a
  // cpp-handle variable (the packed kernel-args pointer) and the second
  // an int32 index into that packed array.
  CHECK_EQ(call_ir->read_args.size(), 2);
  CHECK(call_ir->read_args[0].is_var() &&
        call_ir->read_args[0].as_var()->type().is_cpp_handle());
  CHECK(call_ir->read_args[1].type().is_int(32));

  // Declare (or reuse) the extern runtime helper: int32 fn(void*, int32).
  std::vector<llvm::Type*> param_types{
      CinnTypeToLLVMType(type_of<void*>(), m_),
      CinnTypeToLLVMType(type_of<int32_t>(), m_)};
  auto* fn_type = llvm::FunctionType::get(
      CinnTypeToLLVMType(Int(32), m_), param_types, /*isVarArg=*/false);
  auto callee = m_->getOrInsertFunction(call_ir->name, fn_type);

  // Forward the host function's first LLVM argument (the packed kernel
  // args) along with the constant element index.
  std::vector<llvm::Value*> fn_args{
      std::addressof(*f_->arg_begin()),
      b_->getInt32(call_ir->read_args[1].as_int32())};
  return b_->CreateCall(callee, fn_args);
}

llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
std::vector<llvm::Value*> ll_function_args;
std::transform(f_->arg_begin(),
Expand Down
11 changes: 10 additions & 1 deletion paddle/cinn/backends/codegen_cuda_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/runtime/intrinsic.h"

PD_DECLARE_bool(cinn_bucket_compile);

Expand All @@ -47,7 +48,13 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
}

llvm::Value *Visit(const ir::Call *op) override {
  // Dispatch host-side intrinsic calls by name; anything else is
  // unsupported in the CUDA host code generator.
  const auto &callee = op->name;
  if (callee == runtime::intrinsic::call_cuda_kernel) {
    return LowerCUDAKernelCall(op);
  }
  if (callee == runtime::intrinsic::get_value_in_cuda_kernel_args) {
    return LowerParseArgsValueCall(op);
  }
  CINN_NOT_IMPLEMENTED;
}

private:
Expand All @@ -68,6 +75,8 @@ class CodeGenCUDA_Host : public CodeGenLLVM {

llvm::Value *LowerHostFunc(const ir::_LoweredFunc_ *func);

llvm::Value *LowerParseArgsValueCall(const ir::Call *call_ir);

llvm::Value *LowerCUDAKernelCall(const ir::Call *op);
};

Expand Down
20 changes: 20 additions & 0 deletions paddle/cinn/backends/codegen_cuda_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
}

void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
    ir::Expr func) {
  // For every scalar (ir::Var) argument of the lowered function, emit a
  //   let arg = cinn_get_value_in_cuda_kernel_args(kernel_args_, i)
  // statement so the generated host wrapper can recover dim values from
  // the packed kernel-args array. Buffer arguments are left untouched.
  //
  // Fixes vs. original: bind `args` by const reference instead of copying
  // the whole vector, and avoid the signed/unsigned comparison in the loop.
  const std::vector<ir::Argument>& args = func.as_lowered_func_ref()->args;
  for (int i = 0; i < static_cast<int>(args.size()); ++i) {
    if (!args[i].is_var()) {
      continue;  // only scalar dim args need extraction
    }
    ir::Expr call_get_value_in_kernel_args =
        ir::Call::Make(Int(32),
                       runtime::intrinsic::get_value_in_cuda_kernel_args,
                       {kernel_args_, ir::Expr(i)},
                       {},
                       ir::CallType::Extern,
                       ir::FunctionRef(),
                       0);
    ir::Expr stmt = ir::Let::Make(ir::Expr(args[i].var_arg()),
                                  call_get_value_in_kernel_args);
    arg_defs_.push_back(stmt);
  }
}

Expr detail::CollectBucketStrategyHostFunctionVisitor::CreateDeviceFunction(
ir::Expr expr, ir::Expr predicate) {
auto copied = ir::ir_utils::IRCopy(expr);
Expand Down
10 changes: 9 additions & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,38 @@ struct CollectBucketStrategyHostFunctionVisitor
CHECK_EQ(op->functions.size(), op->predicates.size());
for (int i = 0; i < op->functions.size(); ++i) {
ProcessLoweredFunc(op->functions[i], op->predicates[i]);
if (i == 0) {
ProcessArgs(op->functions[i]);
}
}

std::vector<ir::Argument> arguments = {
ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
ir::Argument(kernel_stream_, ir::Argument::IO::kOutput)};
std::vector<ir::Expr> body_stmts(arg_defs_);
body_stmts.insert(body_stmts.end(), buckets_.begin(), buckets_.end());
ir::Expr host_func =
ir::_LoweredFunc_::Make(op->functions[0].as_lowered_func()->name,
arguments,
ir::Block::Make(buckets_),
ir::Block::Make(body_stmts),
{});
host_module_builder.AddFunctionWithoutOptim(
host_func.as_lowered_func_ref());
}

void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);

void ProcessArgs(ir::Expr func);

Expr CreateDeviceFunction(ir::Expr expr, ir::Expr predicate);

inline std::string GenDeviceKernelName(const std::string& fn_name,
ir::Expr predicate);

private:
std::vector<ir::Expr> buckets_;
std::vector<ir::Expr> arg_defs_;

ir::Var kernel_args_;
ir::Var kernel_args_num_;
Expand Down
11 changes: 11 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PD_DECLARE_bool(cinn_use_cuda_vectorize);
PD_DECLARE_bool(cinn_enable_map_expr);
PD_DECLARE_bool(cinn_enable_map_expr_schedule);
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
namespace hlir {
Expand Down Expand Up @@ -501,6 +502,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
}

// add fake symbolic args for test
if (FLAGS_cinn_bucket_compile) {
group_func_args.emplace_back(ir::_Var_::Make("fake_symbol1", Int(32)),
ir::Argument::IO::kOutput);
group_func_args.emplace_back(ir::_Var_::Make("fake_symbol2", Int(32)),
ir::Argument::IO::kOutput);
group->output_names.push_back("fake_symbol1");
group->output_names.push_back("fake_symbol2");
}

#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
Expand Down
8 changes: 8 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
}

CINN_REGISTER_HELPER(cinn_cuda_host_api) {
using cinn::runtime::cuda::cinn_get_value_in_cuda_kernel_args;
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_cuda_kernel_args,
cinn::common::DefaultHostTarget())
.SetRetType<int64_t>()
.AddInputType<void *>() // args
.AddInputType<int>() // index
.End();

using cinn::runtime::cuda::cinn_call_cuda_kernel;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_kernel,
cinn::common::DefaultHostTarget())
Expand Down
5 changes: 5 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class CublasHandle {
cublasHandle_t cuhandle;
};

// Read the idx-th element of the packed kernel-args array as an int32.
// `v_args` points to an array of cinn_pod_value_t; no bounds checking is
// performed — the caller guarantees `idx` is in range.
int32_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
  cinn_pod_value_t &arg = static_cast<cinn_pod_value_t *>(v_args)[idx];
  return arg.operator int32_t();
}

void cinn_call_cuda_kernel(void *kernel_fn,
void *v_args,
int num_args,
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ void cinn_call_cuda_memcpy(void* v_args,
size_t count,
void* stream = nullptr);

int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);

/**
* Call a CUDA compiled kernel.
*
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/runtime/intrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ static const char* print_debug_args_repr = "cinn_print_debug_args";

static const char* call_cuda_kernel = "cinn_call_cuda_kernel";

static const char* get_value_in_cuda_kernel_args =
"cinn_get_value_in_cuda_kernel_args";

static const char* pod_values_to_array_repr = "pod_values_to_array";

static const char* get_address_repr = "get_address";
Expand Down
8 changes: 7 additions & 1 deletion test/cpp/pir/cinn/compilation_task_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,13 @@ TEST(CompilationTask, CompileGroup) {
auto runtime_program = ir_compiler.Build(groups);

// Step 3: Execute Runtime Instruction and check Scope.
ASSERT_NO_THROW(runtime_program->Execute());
std::string arg_name = "var_0";
cinn_buffer_t* buffer = scope->GetTensor(arg_name)->buffer();
std::map<std::string, cinn_pod_value_t> name2podargs = {
{arg_name, cinn_pod_value_t(buffer)},
{"fake_symbol1", cinn_pod_value_t(int32_t(4096))},
{"fake_symbol2", cinn_pod_value_t(int32_t(128))}};
ASSERT_NO_THROW(runtime_program->Execute(&name2podargs));
for (auto& var_name : scope->var_names()) {
std::string name = {var_name.begin(), var_name.end()};
int64_t numel = scope->GetTensor(name)->shape().numel();
Expand Down