This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【Hackathon No.73】add one_hot op #963

Merged 4 commits on Sep 29, 2022
10 changes: 10 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -395,6 +395,16 @@ Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) {
return CustomInstr("cast", {operand}, {{"dtype", dtype}}).front();
}

Variable NetBuilder::OneHot(const Variable& indices,
const Variable& on_value,
const Variable& off_value,
const int depth,
const int axis,
const std::string& dtype) {
return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}})
.front();
}

Variable NetBuilder::Squeeze(const Variable& operand, const std::vector<int>& axes) {
return CustomInstr("squeeze", {operand}, {{"axes", axes}}).front();
}
15 changes: 15 additions & 0 deletions cinn/frontend/net_builder.h
@@ -697,6 +697,21 @@ class NetBuilder {
*/
Variable Cast(const Variable& x, const std::string& dtype);

/**
* @brief Returns a one-hot tensor where the positions given by `indices` take value `on_value`,
* while all other positions take value `off_value`.
* @param indices Tensor of indices.
* @param on_value Value to fill at the positions given by `indices`. Its shape must be [1].
* @param off_value Value to fill at all other positions. Its shape must be [1].
* @param depth Depth of the one-hot dimension.
* @param axis Axis to fill.
* @param dtype Data type of the output tensor.
*/
Variable OneHot(const Variable& indices,
const Variable& on_value,
const Variable& off_value,
const int depth,
const int axis,
const std::string& dtype);

// *******************************************
// Decomposer Operator
/**
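For reviewers, a minimal usage sketch of the new builder API, mirroring the test below (the builder name and input shapes here are illustrative assumptions, not part of this PR):

NetBuilder builder("one_hot_example");
// `indices` holds integer class ids; on/off values are 1-element tensors.
Placeholder indices = builder.CreateInput(Int(32), {4, 4}, "Indices");
Placeholder on_value = builder.CreateInput(Int(32), {1}, "OnValue");
Placeholder off_value = builder.CreateInput(Int(32), {1}, "OffValue");
// depth = 11 with axis = 0 inserts a new leading dimension of size 11,
// so the output shape is {11, 4, 4} with dtype "int32".
Variable out = builder.OneHot(indices, on_value, off_value, /*depth=*/11, /*axis=*/0, "int32");
auto program = builder.Build();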
95 changes: 95 additions & 0 deletions cinn/frontend/net_builder_test.cc
@@ -1351,5 +1351,100 @@ TEST(net_build, program_execute_repeat_axis_1) {
}
}

TEST(net_build, program_execute_one_hot) {
const int M = 4;
const int N = 4;
const int on_value = 1;
const int off_value = 0;
const int depth = 11;
const int axis = 0;  // valid range: [-1, rank of indices]
const std::string dtype = "int32";

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Int(32), {M, N}, "In");
Placeholder on_value_input = builder.CreateInput(Int(32), {1}, "OnValue");
Placeholder off_value_input = builder.CreateInput(Int(32), {1}, "OffValue");
Variable output = builder.OneHot(input, on_value_input, off_value_input, depth, axis, dtype);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(on_value_input.id()));
scope->Var<hlir::framework::Tensor>(std::string(off_value_input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
const std::vector<int>& input_shape = input_tensor->shape().data();
SetRandInt(input_tensor, target);
int* input_data = input_tensor->mutable_data<int>(target);

auto on_value_tensor = scope->GetTensor(std::string(on_value_input.id()));
int* on_value_data = on_value_tensor->mutable_data<int>(target);
on_value_data[0] = on_value;

auto off_value_tensor = scope->GetTensor(std::string(off_value_input.id()));
int* off_value_data = off_value_tensor->mutable_data<int>(target);
off_value_data[0] = off_value;

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
int* output_data = output_tensor->mutable_data<int>(target);

EXPECT_EQ(output_tensor->type(), Int(32));
EXPECT_EQ(output_shape.size(), input_shape.size() + 1);

const int true_axis = axis == -1 ? static_cast<int>(input_shape.size()) : axis;  // -1 appends the one-hot axis at the end
int input_shape_index = 0;

for (int i = 0; i < output_shape.size(); i++) {
LOG(INFO) << "output_shape[" << i << "] = " << output_shape[i];
if (i == true_axis) {
EXPECT_EQ(output_shape[i], depth);
} else {
EXPECT_EQ(output_shape[i], input_shape[input_shape_index++]);
}
}

for (int i = 0; i < output_shape[0]; ++i) {
for (int j = 0; j < output_shape[1]; ++j) {
for (int k = 0; k < output_shape[2]; ++k) {
std::vector<int> s = {i, j, k};
int input_index = 0;
int output_index = 0;
int base = 1;

for (int x = s.size() - 1; x >= 0; --x) {
if (x == true_axis) {
continue;
}
input_index += base * s[x];
base = base * output_shape[x];
}

base = 1;

for (int x = s.size() - 1; x >= 0; --x) {
output_index += base * s[x];
base = base * output_shape[x];
}

if (s[true_axis] == input_data[input_index]) {
EXPECT_EQ(output_data[output_index], on_value);
} else {
EXPECT_EQ(output_data[output_index], off_value);
}
}
}
}
}

} // namespace frontend
} // namespace cinn
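The nested loops in the test above recover flattened tensor indices by accumulating row-major strides, skipping the one-hot axis when mapping an output coordinate back to the input. A self-contained sketch of that index arithmetic (function and variable names here are illustrative, not from the PR):

#include <cassert>
#include <vector>

// Row-major flattening that optionally skips one axis, mirroring the
// index computation in program_execute_one_hot above.
int FlattenIndex(const std::vector<int>& coord, const std::vector<int>& shape, int skip_axis = -1) {
  int index = 0;
  int base  = 1;
  for (int x = static_cast<int>(coord.size()) - 1; x >= 0; --x) {
    if (x == skip_axis) continue;  // the one-hot axis does not exist in the input
    index += base * coord[x];
    base *= shape[x];
  }
  return index;
}

int main() {
  // Output shape {depth=11, M=4, N=4} with the one-hot axis at 0.
  std::vector<int> out_shape = {11, 4, 4};
  // Output coordinate (d, j, k) maps to input coordinate (j, k):
  assert(FlattenIndex({5, 2, 3}, out_shape, /*skip_axis=*/0) == 2 * 4 + 3);
  // Without skipping, the same coordinate flattens over the full output:
  assert(FlattenIndex({5, 2, 3}, out_shape) == (5 * 4 + 2) * 4 + 3);
  return 0;
}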
2 changes: 2 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -13,6 +13,7 @@ gather_srcs(cinnapi_src SRCS
argmax.cc
squeeze.cc
repeat.cc
one_hot.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
@@ -26,3 +27,4 @@ cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)
cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)