-
Notifications
You must be signed in to change notification settings - Fork 263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.78】为神经网络编译器 CINN 增加 gather、gather_nd、scatter、scatter_nd 算子 #210
Conversation
我在实现过程中发现需要使用cinn_host_find_value_nd函数,因此参考已有的cinn_host_find_value函数实现了相应函数,已在提案中添加此部分。 |
好的,近期会添加cuda的测试 |
我参考host的相关测试,创建了cinn/runtime/cuda/cuda_intrinsics_test.cc文件,但是在其中直接使用callextern会报错 codegen_llvm.cc:655] Check failed: callee Unknown function referenced. [cinn_cuda_index_add]。 TEST(find_value_nd, basic) {
Expr M(10), N(20);
Placeholder<float> x("x", {M, N});
auto y = Compute({N}, [&](Expr i) { return CallExtern("cinn_cuda_find_int_nd", {x, M, x({Expr(5), Expr(3)}), i, N}); });
auto stages = CreateStages({y});
auto jit = backends::SimpleJIT::Create();
ir::Module::Builder builder("module1", common::DefaultNVGPUTarget());
auto fn = Lower("fn", stages, {x, y});
LOG(INFO) << "fn:\n" << fn;
builder.AddFunction(fn);
jit->Link(builder.Build());
auto fn_ptr = jit->Lookup("fn");
auto fnp = reinterpret_cast<lower_func_ptr_t>(fn_ptr);
ASSERT_TRUE(fnp);
auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build();
auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build();
fnp(args.data(), args.size());
auto* x_buf_data = reinterpret_cast<float*>(x_buf->memory);
auto* out_buf_data = reinterpret_cast<int*>(out_buf->memory);
for (int i = 0; i < out_buf->num_elements(); i++) {
LOG_FIRST_N(INFO, 3) << out_buf_data[i];
if (out_buf_data[i] != -1) {
ASSERT_NEAR(x_buf_data[out_buf_data[i] * 20 + i], x_buf_data[5 * 20 + 3], 1e-5);
}
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
基本实现已完成:PaddlePaddle/CINN#897