This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【PaddlePaddle Hackathon No.78】add gather, gather_nd, scatter and scatter_nd op #897

Merged · 44 commits · Sep 22, 2022
Commits (44, all by zrr1999):
b30a7da  add gather and scatter op (Aug 8, 2022)
9675f38  fix bug (Aug 21, 2022)
6d02455  fix bug (Aug 22, 2022)
0a8872a  fix bug (Aug 22, 2022)
dabe72c  fix bug (Aug 23, 2022)
3b4e976  fix bug (Aug 24, 2022)
cce1792  fix bug (Aug 24, 2022)
abca2db  fix bug (Aug 24, 2022)
417b389  fix bug (Aug 24, 2022)
f3a607c  fix bug (Aug 24, 2022)
1a39c01  fix bug (Aug 24, 2022)
107c519  fix bug (Aug 24, 2022)
bd6cd48  fix bug (Aug 24, 2022)
97d9578  fix bug (Aug 24, 2022)
4e0c3d5  fix bug (Aug 24, 2022)
ba87eca  fix bug (Aug 24, 2022)
94834cf  fix bug (Aug 24, 2022)
6520697  fix bug (Aug 24, 2022)
b5cfedb  fix bug (Aug 24, 2022)
8b08876  fix bug (Aug 24, 2022)
feaa87b  fix bug (Aug 24, 2022)
42c1754  fix bug (Aug 25, 2022)
6fde20f  fix bug (Aug 25, 2022)
4dc115a  Merge branch 'develop' into gather (Aug 29, 2022)
3867a21  fix bug (Aug 29, 2022)
a2c4cf4  Merge branch 'develop' into gather (Aug 30, 2022)
f6d5ca1  fix bug (Aug 30, 2022)
21b3787  Merge branch 'develop' into gather (Sep 12, 2022)
3d2e079  fix bugs (Sep 12, 2022)
3025621  fix bugs (Sep 14, 2022)
7d656a8  Merge branch 'develop' into gather (Sep 14, 2022)
c0c97d3  fix bugs (Sep 14, 2022)
0c68371  fix bugs (Sep 14, 2022)
ad91b7c  fix bugs (Sep 14, 2022)
9ad4e2b  add pybind (Sep 15, 2022)
658b569  fix bugs (Sep 15, 2022)
e0ea015  fix bugs (Sep 15, 2022)
8d25e23  fix bugs (Sep 16, 2022)
58eb499  fix bugs (Sep 16, 2022)
08714e9  add SetRandint function (Sep 18, 2022)
f2d5f12  fix bugs (Sep 18, 2022)
bdb4ade  fix bugs (Sep 19, 2022)
8b1be8e  fix bugs (Sep 19, 2022)
671cbe5  fix bugs (Sep 21, 2022)
35 changes: 35 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -349,6 +349,41 @@ Variable NetBuilder::ReluGrad(const Variable& lhs, const Variable& rhs) {
return CustomInstr("relu_grad", {lhs, rhs}, {}).front();
}

Variable NetBuilder::Gather(const Variable& x, const Variable& index, const int& axis) {
return CustomInstr("gather", {x, index}, {{"axis", axis}}).front();
}

Variable NetBuilder::GatherNd(const Variable& x, const Variable& index, const std::vector<int>& axes) {
return CustomInstr("gather_nd", {x, index}, {{"axes", axes}}).front();
}

Variable NetBuilder::Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis) {
return CustomInstr("scatter", {src, index, out}, {{"axis", axis}}).front();
}
Variable NetBuilder::Scatter(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value,
const int& axis) {
auto out = FillConstant(shape, default_value, UniqName("fill_constant"), "float", false);
return Scatter(src, index, out, axis);
}

Variable NetBuilder::ScatterNd(const Variable& src,
const Variable& index,
const Variable& out,
const std::vector<int>& axes) {
return CustomInstr("scatter_nd", {src, index, out}, {{"axes", axes}}).front();
}
Variable NetBuilder::ScatterNd(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value,
const std::vector<int>& axes) {
auto out = FillConstant(shape, default_value, UniqName("fill_constant"), "float", false);
return ScatterNd(src, index, out, axes);
}

Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) {
if (operand->type == common::Str2Type(dtype)) {
return Identity(operand);
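For orientation, here is a minimal usage sketch of the new builder ops. This is not part of the diff: the helper name is hypothetical, the shapes mirror the tests further down, and it assumes the usual CINN frontend headers.

#include "cinn/frontend/net_builder.h"

namespace cinn {
namespace frontend {

// Sketch only: build a tiny program that gathers columns of x and
// scatters them into a fresh tensor pre-filled with a default value.
Program BuildGatherScatterExample() {
  NetBuilder builder("gather_scatter_example");
  Placeholder x     = builder.CreateInput(common::Float(32), {4, 11}, "x");
  Placeholder index = builder.CreateInput(common::Int(32), {4, 14}, "index");

  // out[b][h] = x[b][index[b][h]]: pick columns of x along axis 1.
  Variable gathered = builder.Gather(x, index, /*axis=*/1);

  // out[b][index[b][h]] = gathered[b][h], into a {4, 16} tensor of 0.0f.
  Variable scattered = builder.Scatter(gathered, index, {4, 16}, /*default_value=*/0.0f, /*axis=*/1);

  return builder.Build();
}

}  // namespace frontend
}  // namespace cinn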
26 changes: 24 additions & 2 deletions cinn/frontend/net_builder.h
@@ -262,6 +262,27 @@ class NetBuilder {
*/
Variable Clip(const std::vector<Variable>& x, const float& max, const float& min);

Variable Gather(const Variable& x, const Variable& index, const int& axis = 0);

Variable GatherNd(const Variable& x, const Variable& index, const std::vector<int>& axes = {});

Variable Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis = 0);
Variable Scatter(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value = 0,
const int& axis = 0);

Variable ScatterNd(const Variable& src,
const Variable& index,
const Variable& out,
const std::vector<int>& axes = {});
Variable ScatterNd(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value = 0,
const std::vector<int>& axes = {});

/**
* @brief This operator checks if all `x` and `y` satisfy the condition: `|x - y| <= atol + rtol * |y|`
* @param x The first variable.
@@ -793,7 +814,7 @@
const std::string& data_layout = "NCHW");

/**
* @brief Sort Variable x along the given axis. The original Variable x will not be changed.
* @brief Sort Variable x along the given axis and return sorted index. The original Variable x will not be changed.
* @param operand The variable that will be sorted.
* @param axis Specify the axis to operate on the input. Default: 0.
* @param is_ascend Sort mode.
@@ -803,7 +824,8 @@
Variable ArgSort(const Variable& operand, const int& axis, const bool& is_ascend = true);

/**
* @brief Sort Variable x along the given axis. The original Variable x will not be changed.
* @brief Sort Variable x along the given axis and return sorted variable. The original Variable x will not be
* changed.
* @param operand The variable that will be sorted.
* @param axis Specify the axis to operate on the input. Default: 0.
* @param is_ascend Sort mode.
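Since the new declarations above carry no doc comments yet, here is a rough scalar reference of the intended semantics, inferred from the tests below. The 2-D, row-major helpers are illustrative only and not part of the API.

#include <vector>

// Gather along axis 1: out[b][h] = x[b][index[b][h]].
// x is {B, W}; index and out are {B, H}.
std::vector<float> GatherAxis1Ref(const std::vector<float>& x,
                                  const std::vector<int>& index,
                                  int B, int W, int H) {
  std::vector<float> out(B * H);
  for (int b = 0; b < B; ++b)
    for (int h = 0; h < H; ++h)
      out[h + H * b] = x[index[h + H * b] + W * b];
  return out;
}

// Scatter along axis 1: start from default_value, then
// out[b][index[b][h]] = src[b][h]; on duplicate indices the last write wins.
// src and index are {B, H}; out is {B, W}.
std::vector<float> ScatterAxis1Ref(const std::vector<float>& src,
                                   const std::vector<int>& index,
                                   int B, int H, int W, float default_value) {
  std::vector<float> out(B * W, default_value);
  for (int b = 0; b < B; ++b)
    for (int h = 0; h < H; ++h)
      out[index[h + H * b] + W * b] = src[h + H * b];
  return out;
}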
254 changes: 253 additions & 1 deletion cinn/frontend/net_builder_test.cc
@@ -246,6 +246,258 @@ TEST(net_build, program_execute_clip) {
}
}

TEST(net_build, program_execute_gather) {
const int B = 4;
const int H_IN1 = 11;
const int H_IN2 = 14;

NetBuilder builder("net_builder");
Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1");
Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN2}, "In2");
Variable output = builder.Gather(input1, input2, 1);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
scope->Var<hlir::framework::Tensor>(std::string(input2.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
SetRandData<float>(input1_tensor, target);
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);
memset(input2_data, 0, sizeof(int) * B * H_IN2);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H_IN2);

float* output_data = output_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize output_data";
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_IN2; ++h) {
std::string line;
int index = h + H_IN2 * b;
float in_data = input1_data[input2_data[index] + H_IN1 * b];
float out_data = output_data[index];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(in_data, out_data);
VLOG(6) << line;
}
}
}

TEST(net_build, program_execute_gather_nd) {
const int B = 4;
const int H_IN1 = 11;
const int H_IN2 = 14;

NetBuilder builder("net_builder");
Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1");
Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN2, 1}, "In2");
Variable output = builder.GatherNd(input1, input2, {1});
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
scope->Var<hlir::framework::Tensor>(std::string(input2.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
SetRandData<float>(input1_tensor, target);
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H_IN2);

float* output_data = output_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize output_data";
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_IN2; ++h) {
std::string line;
int index = h + H_IN2 * b;
float in_data = input1_data[input2_data[index] + H_IN1 * b];
float out_data = output_data[index];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(in_data, out_data);
VLOG(6) << line;
}
}
}
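Note that with an index tensor of shape {B, H_IN2, 1} and axes = {1}, gather_nd collapses to the plain gather tested above, which is why this verification loop is identical. A small sketch of the equivalence on flattened buffers (hypothetical helper):

// gather_nd with index shape {B, H, 1} and axes = {1} reads
//   out[b][h] = x[b][index[b][h][0]],
// i.e. exactly gather along axis 1 with index shape {B, H}.
inline float GatherNdElemRef(const float* x, const int* index, int b, int h, int W, int H) {
  return x[index[h + H * b] + W * b];  // same flattening the test uses
}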

TEST(net_build, program_execute_scatter) {
const float default_value = 3.14;
const int B = 3;
const int H_IN = 4;
const int H_OUT = 11;

NetBuilder builder("net_builder");
Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN}, "In1");
Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN}, "In2");
Variable output = builder.Scatter(input1, input2, {B, H_OUT}, default_value, 1);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
scope->Var<hlir::framework::Tensor>(std::string(input2.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
SetRandData<float>(input1_tensor, target);
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);
memset(input2_data, 0, sizeof(int) * B * H_IN);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H_OUT);

float true_data[B * H_OUT];
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_OUT; ++h) {
int index = h + H_OUT * b;
true_data[index] = default_value;
}
}
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_IN; ++h) {
int index = h + H_IN * b;
true_data[input2_data[index] + H_OUT * b] = input1_data[index];
std::cout << index << " " << input2_data[index] + H_OUT * b << " " << true_data[input2_data[index] + H_OUT * b];
}
}

float* output_data = output_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize output_data";
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_OUT; ++h) {
std::string line;
int index = h + H_OUT * b;
float t_data = true_data[index];
float out_data = output_data[index];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(t_data, out_data);
VLOG(6) << line;
}
}
}
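One detail worth flagging in this test: SetRandInt fills the index tensor but the memset immediately zeroes it, so every row writes all H_IN source values to column 0, and the expected result keeps only the last write. That matches the last-write-wins order of the construction loop above; a tiny self-contained illustration with hypothetical values:

#include <cassert>
#include <vector>

// Scatter with all-zero indices: only the last write to column 0 survives.
void AllZeroIndexScatterIllustration() {
  const int H_IN = 4, W = 11;
  std::vector<float> src = {1.f, 2.f, 3.f, 4.f};   // one row of source values
  std::vector<float> out(W, 3.14f);                // pre-filled with default_value
  for (int h = 0; h < H_IN; ++h) out[0] = src[h];  // index[h] == 0 for every h
  assert(out[0] == 4.f);                           // last write wins
  assert(out[1] == 3.14f);                         // everything else untouched
}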

TEST(net_build, program_execute_scatter_nd) {
const float default_value = 3.14;
const int B = 3;
const int H_IN = 4;
const int H_OUT = 11;

NetBuilder builder("net_builder");
Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN}, "In1");
Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN, 1}, "In2");
Variable output = builder.ScatterNd(input1, input2, {B, H_OUT}, default_value, {1});
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
scope->Var<hlir::framework::Tensor>(std::string(input2.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
SetRandData<float>(input1_tensor, target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandInt(input2_tensor, target);

runtime_program->Execute();

int* input2_data;
float* input1_data;
input2_data = input2_tensor->mutable_data<int>(target);
input1_data = input1_tensor->mutable_data<float>(target);

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H_OUT);

float true_data[B * H_OUT];
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_OUT; ++h) {
int index = h + H_OUT * b;
true_data[index] = default_value;
}
}
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_IN; ++h) {
int index = h + H_IN * b;
true_data[input2_data[index] + H_OUT * b] = input1_data[index];
}
}

float* output_data = output_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize output_data";
for (int b = 0; b < B; ++b) {
for (int h = 0; h < H_OUT; ++h) {
std::string line;
int index = h + H_OUT * b;
float t_data = true_data[index];
float out_data = output_data[index];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(t_data, out_data);
VLOG(6) << line;
}
}
}

TEST(net_build, program_execute_cast) {
const int B = 4;
const int H = 7;
@@ -266,7 +518,7 @@ TEST(net_build, program_execute_cast) {
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<int>(input_tensor, target);
SetRandInt(input_tensor, target);
int* input_data = input_tensor->mutable_data<int>(target);

runtime_program->Execute();
4 changes: 4 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -1,6 +1,8 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS
gather.cc
scatter.cc
cast.cc
squeeze.cc
clip.cc
@@ -11,6 +13,8 @@ gather_srcs(cinnapi_src SRCS

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore)
cc_test(test_gather SRCS gather_test.cc DEPS cinncore)
cc_test(test_scatter SRCS scatter_test.cc DEPS cinncore)
cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)