Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.78】为神经网络编译器 CINN 增加 gather、gather_nd、scatter、scatter_nd 算子 #210

Merged
merged 36 commits into from
Sep 16, 2022

Conversation

zrr1999
Copy link
Member

@zrr1999 zrr1999 commented Aug 17, 2022

基本实现已完成:PaddlePaddle/CINN#897

@paddle-bot
Copy link

paddle-bot bot commented Aug 17, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请检查PR提交格式和内容是否完备,具体请参考示例模版
Your PR has been submitted. Thanks for your contribution!
Please check its format and content. For this, you can refer to Template and Demo.

@zrr1999 zrr1999 marked this pull request as ready for review August 17, 2022 03:11
@zrr1999 zrr1999 changed the title 【Hackathon No.69】为神经网络编译器 CINN 增加 gather、gather_nd、scatter、scatter_nd 算子 【Hackathon No.78】为神经网络编译器 CINN 增加 gather、gather_nd、scatter、scatter_nd 算子 Aug 17, 2022
@zrr1999
Copy link
Member Author

zrr1999 commented Aug 24, 2022

我在实现过程中发现需要使用cinn_host_find_value_nd函数,因此参考已有的cinn_host_find_value函数实现了相应函数,已在提案中添加此部分。
具体实现见https://github.com/PaddlePaddle/CINN/pull/897/files ,host已添加相应测试,cuda因为实现方法几乎一致,暂时还没有添加测试。

@zrr1999
Copy link
Member Author

zrr1999 commented Aug 29, 2022

image

格式似乎还是不太对,如上图,是我看到的公式,别的一些地方格式也不对。你可以用我之前comment的view file方法预览一下

这部分已经修改完毕,我把例子改成了代码形式

@zrr1999
Copy link
Member Author

zrr1999 commented Aug 29, 2022

我在实现过程中发现需要使用cinn_host_find_value_nd函数,因此参考已有的cinn_host_find_value函数实现了相应函数,已在提案中添加此部分。 具体实现见https://github.com/PaddlePaddle/CINN/pull/897/files ,host已添加相应测试,cuda因为实现方法几乎一致,暂时还没有添加测试。

RFC除了格式,别的应该问题不大了。代码方面麻烦添加cuda测试,因为实际软件工程中,哪怕实现方法几乎的代码,测试时就有可能出问题。请确保一下没问题。

好的,近期会添加cuda的测试

@zrr1999
Copy link
Member Author

zrr1999 commented Aug 30, 2022

我在实现过程中发现需要使用cinn_host_find_value_nd函数,因此参考已有的cinn_host_find_value函数实现了相应函数,已在提案中添加此部分。 具体实现见https://github.com/PaddlePaddle/CINN/pull/897/files ,host已添加相应测试,cuda因为实现方法几乎一致,暂时还没有添加测试。

RFC除了格式,别的应该问题不大了。代码方面麻烦添加cuda测试,因为实际软件工程中,哪怕实现方法几乎的代码,测试时就有可能出问题。请确保一下没问题。

我参考host的相关测试,创建了cinn/runtime/cuda/cuda_intrinsics_test.cc文件,但是在其中直接使用callextern会报错 codegen_llvm.cc:655] Check failed: callee Unknown function referenced. [cinn_cuda_index_add]。
我参考了scatterindex等pr在实现相应cuda函数时均未添加测试,所以我不确定是否可以用类似host写测试的方法写cuda的测试,希望可以有一个例子。

TEST(find_value_nd, basic) {
  Expr M(10), N(20);
  Placeholder<float> x("x", {M, N});
  auto y = Compute({N}, [&](Expr i) { return CallExtern("cinn_cuda_find_int_nd", {x, M, x({Expr(5), Expr(3)}), i, N}); });

  auto stages = CreateStages({y});

  auto jit = backends::SimpleJIT::Create();

  ir::Module::Builder builder("module1", common::DefaultNVGPUTarget());

  auto fn = Lower("fn", stages, {x, y});
  LOG(INFO) << "fn:\n" << fn;

  builder.AddFunction(fn);

  jit->Link(builder.Build());

  auto fn_ptr = jit->Lookup("fn");
  auto fnp    = reinterpret_cast<lower_func_ptr_t>(fn_ptr);
  ASSERT_TRUE(fnp);

  auto* x_buf   = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
  auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build();
  auto args     = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build();
  fnp(args.data(), args.size());

  auto* x_buf_data   = reinterpret_cast<float*>(x_buf->memory);
  auto* out_buf_data = reinterpret_cast<int*>(out_buf->memory);

  for (int i = 0; i < out_buf->num_elements(); i++) {
    LOG_FIRST_N(INFO, 3) << out_buf_data[i];
    if (out_buf_data[i] != -1) {
      ASSERT_NEAR(x_buf_data[out_buf_data[i] * 20 + i], x_buf_data[5 * 20 + 3], 1e-5);
    }
  }
}

@zrr1999
Copy link
Member Author

zrr1999 commented Sep 1, 2022

@zhhsplendid

Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhhsplendid zhhsplendid merged commit 022dcc5 into PaddlePaddle:master Sep 16, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants