From eefa645974584889a7927145fd49134ece2597fe Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sun, 25 Sep 2022 21:37:48 -0700 Subject: [PATCH 1/3] add one_hot op --- cinn/frontend/net_builder.cc | 9 + cinn/frontend/net_builder.h | 15 ++ cinn/frontend/net_builder_test.cc | 95 ++++++++++ cinn/hlir/op/contrib/CMakeLists.txt | 2 + cinn/hlir/op/contrib/one_hot.cc | 256 +++++++++++++++++++++++++++ cinn/hlir/op/contrib/one_hot.h | 38 ++++ cinn/hlir/op/contrib/one_hot_test.cc | 107 +++++++++++ cinn/hlir/op/use_ops.h | 1 + 8 files changed, 523 insertions(+) create mode 100644 cinn/hlir/op/contrib/one_hot.cc create mode 100644 cinn/hlir/op/contrib/one_hot.h create mode 100644 cinn/hlir/op/contrib/one_hot_test.cc diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index fba5eb8cd1..845f783c4f 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -391,6 +391,15 @@ Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) { return CustomInstr("cast", {operand}, {{"dtype", dtype}}).front(); } +Variable NetBuilder::OneHot(const Variable& indices, + const Variable& on_value, + const Variable& off_value, + const int depth, + const int axis, + const std::string& dtype){ + return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}}).front(); +} + Variable NetBuilder::Squeeze(const Variable& operand, const std::vector& axes) { return CustomInstr("squeeze", {operand}, {{"axes", axes}}).front(); } diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index fbfd68c574..00d9a9c8da 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -641,6 +641,21 @@ class NetBuilder { */ Variable Cast(const Variable& x, const std::string& dtype); + /** + * @brief Returns a one-hot tensor where the locations repsented by indices take value `on_value`, + * other locations take value `off_value`. + * @param on_value Value to fill at indices. Its shape must be [1]. + * @param on_value Value to fill at all other positions besides indices. Its shape must be [1] + * @param depth Depth of the one-hot dimension. + * @param axis Axis to fill. + */ + Variable OneHot(const Variable& indices, + const Variable& on_value, + const Variable& off_value, + const int depth, + const int axis, + const std::string& dtype); + // ******************************************* // Decomposer Operator /** diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 924e00060f..a9feb093b2 100755 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -887,5 +887,100 @@ TEST(net_build, program_execute_arange_int) { } } +TEST(net_build, program_execute_one_hot) { + const int M = 4; + const int N = 4; + const int on_value = 1; + const int off_value = 0; + const int depth = 11; + const int axis = 0; // [-1 , M] + const std::string dtype = "int32"; + + NetBuilder builder("net_builder"); + Placeholder input = builder.CreateInput(Int(32), {M, N}, "In"); + Placeholder on_value_input = builder.CreateInput(Int(32), {1}, "OnValue"); + Placeholder off_value_input = builder.CreateInput(Int(32), {1}, "OffValue"); + Variable output = builder.OneHot(input, on_value_input, off_value_input, depth, axis, dtype); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input.id())); + scope->Var(std::string(on_value_input.id())); + scope->Var(std::string(off_value_input.id())); + scope->Var(std::string(output->id)); + + auto input_tensor = scope->GetTensor(std::string(input.id())); + const std::vector& intput_shape = input_tensor->shape().data(); + SetRandInt(input_tensor, target); + int* input_data = input_tensor->mutable_data(target); + + auto on_value_tensor = scope->GetTensor(std::string(on_value_input.id())); + int* on_value_data = on_value_tensor->mutable_data(target); + on_value_data[0] = on_value; + + auto off_value_tensor = scope->GetTensor(std::string(off_value_input.id())); + int* off_value_data = off_value_tensor->mutable_data(target); + off_value_data[0] = off_value; + + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + int* output_data = output_tensor->mutable_data(target); + + EXPECT_EQ(output_tensor->type(), Int(32)); + EXPECT_EQ(output_shape.size(), intput_shape.size() + 1); + + const int true_axis = axis == -1 ? M : axis; + int input_shape_index = 0; + + for (int i = 0; i < output_shape.size(); i++) { + LOG(INFO) << output_shape[i]; + if (i == true_axis) { + EXPECT_EQ(output_shape[i], depth); + } else { + EXPECT_EQ(output_shape[i], intput_shape[input_shape_index++]); + } + } + + for (int i = 0; i < output_shape[0]; ++i) { + for (int j = 0; j < output_shape[1]; ++j) { + for (int k = 0; k < output_shape[2]; ++k) { + std::vector s = {i, j, k}; + int input_index = 0; + int output_index = 0; + int base = 1; + + for (int x = s.size() - 1; x >= 0; --x) { + if (x == true_axis) { + continue; + } + input_index += base * s[x]; + base = base * output_shape[x]; + } + + base = 1; + + for (int x = s.size() - 1; x >= 0; --x) { + output_index += base * s[x]; + base = base * output_shape[x]; + } + + if (s[true_axis] == input_data[input_index]) { + EXPECT_EQ(output_data[output_index], on_value); + } else { + EXPECT_EQ(output_data[output_index], off_value); + } + } + } + } +} + } // namespace frontend } // namespace cinn diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt index 5643c4cd69..2949e19071 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -9,6 +9,7 @@ gather_srcs(cinnapi_src SRCS arange.cc sort.cc squeeze.cc + one_hot.cc ) cc_test(test_cast SRCS cast_test.cc DEPS cinncore) @@ -18,3 +19,4 @@ cc_test(test_scatter SRCS scatter_test.cc DEPS cinncore) cc_test(test_clip SRCS clip_test.cc DEPS cinncore) cc_test(test_sort SRCS sort_test.cc DEPS cinncore) cc_test(test_arange SRCS arange_test.cc DEPS cinncore) +cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/one_hot.cc b/cinn/hlir/op/contrib/one_hot.cc new file mode 100644 index 0000000000..a2e405e5bb --- /dev/null +++ b/cinn/hlir/op/contrib/one_hot.cc @@ -0,0 +1,256 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/one_hot.h" + +#include + +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/hlir/pe/transform.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::CINNValuePack; + +ir::Tensor OneHot(const ir::Tensor& indices, + const ir::Tensor& on_value, + const ir::Tensor& off_value, + const int depth, + const int axis, + const Type& dtype, + const std::string& output_name) { + int ndim = static_cast(indices->shape.size()); + CHECK(axis == -1 || (0 <= axis && axis <= ndim)) << "one_hot only accepts `axis` in [-1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(depth > 0) << "one_hot only accepts `depth > 0`" + << ", but got depth = " << depth; + + CHECK(on_value->shape.size() == 1U && on_value->shape[0].as_int32() == 1U) << "The shape of on_value must be [1]"; + CHECK(off_value->shape.size() == 1U && off_value->shape[0].as_int32() == 1U) << "The shape of off_value must be [1]"; + + int true_axis = (axis == -1) ? ndim : axis; + std::vector new_shape; + int indices_index = 0; + + for (int i = 0; i < ndim + 1; ++i) { + if (i == true_axis) { + new_shape.push_back(Expr(depth)); + } else { + new_shape.push_back(indices->shape[indices_index++]); + } + } + + ir::Expr on_value_cast = ir::Cast::Make(dtype, on_value(Expr(0))); + ir::Expr off_value_cast = ir::Cast::Make(dtype, off_value(Expr(0))); + + ir::Tensor res = lang::Compute( + new_shape, + [=](const std::vector& iter) { + std::vector indices_indices; + + for (size_t i = 0; i < iter.size(); i++) { + if (static_cast(i) == true_axis) { + continue; + } + indices_indices.push_back(iter[i]); + } + + auto idx = iter[true_axis]; + auto elem = ir::Cast::Make(idx.type(), indices(indices_indices)); + return ir::Select::Make(ir::EQ::Make(elem, idx), on_value_cast, off_value_cast); + }, + common::UniqName(output_name)); + + return res; +} + +std::vector InferShapeForOneHot(const std::vector& inputs_shape, + const framework::AttrMapType& attrs) { + CHECK_EQ(inputs_shape.size(), 3UL) << "The number of one_hot's input should be 3"; + + int depth; + int axis; + + for (auto& iter : attrs) { + if (iter.first == "depth") { + depth = absl::get(iter.second); + } else if (iter.first == "axis") { + axis = absl::get(iter.second); + } + } + + const std::vector& in_shape = inputs_shape[0]; + int ndim = static_cast(in_shape.size()); + int true_axis = (axis == -1) ? in_shape.size() : axis; + int indices_index = 0; + std::vector new_shape; + + for (int i = 0; i < ndim + 1; ++i) { + if (i == true_axis) { + new_shape.push_back(depth); + } else { + new_shape.push_back(in_shape[indices_index++]); + } + } + + std::vector> res{new_shape}; + return res; +} + +std::vector InferDtypeForOneHot(const std::vector& inputs_type, const framework::AttrMapType& attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + + std::string dtype = "float32"; + if (attrs.find("dtype") != attrs.end()) { + dtype = absl::get(attrs.at("dtype")); + } + + std::vector res{common::Str2Type(dtype)}; + return res; +} + +std::shared_ptr StrategyForOneHot(const framework::NodeAttr& attrs, + const std::vector& inputs, + const std::vector& out_type, + const std::vector>& output_shapes, + const Target& target) { + int depth; + int axis; + std::string dtype = "float32"; + + for (auto& iter : attrs.attr_store) { + if (iter.first == "depth") { + depth = absl::get(iter.second); + } else if (iter.first == "axis") { + axis = absl::get(iter.second); + } else if (iter.first == "dtype") { + dtype = absl::get(iter.second); + } + } + + CHECK(depth > 0) << "one_hot only accepts `depth > 0`" + << ", but got depth = " << depth; + + framework::CINNCompute one_hot_compute([=](lang::Args args, lang::RetValue* ret) { + CHECK(!args.empty()) << "The input argument of one_hot compute is empty! Please check.\n"; + common::CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) << "at least one input tensor for transpose compute\n"; + Expr indices_expr = pack_args[0]; + Expr on_value_expr = pack_args[1]; + Expr off_value_expr = pack_args[2]; + CHECK(indices_expr.as_tensor()); + CHECK(on_value_expr.as_tensor()); + CHECK(off_value_expr.as_tensor()); + + ir::Tensor indices = indices_expr.as_tensor_ref(); + ir::Tensor on_value = on_value_expr.as_tensor_ref(); + ir::Tensor off_value = off_value_expr.as_tensor_ref(); + + std::string tensor_name = common::UniqName("T_OneHot_out"); + + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 1U); + tensor_name = pack_args[0].operator std::string(); + } + + ir::Tensor out = OneHot(indices, on_value, off_value, depth, axis, common::Str2Type(dtype), tensor_name); + + std::vector res; + auto stages = CreateStages({indices, on_value, off_value}); + stages->InsertLazily(out); + res.push_back(common::CINNValue(out)); + res.push_back(common::CINNValue(stages)); + *ret = common::CINNValuePack{res}; + }); + + framework::CINNSchedule one_hot_schedule([=](lang::Args args, lang::RetValue* ret) { + if (FLAGS_cinn_ir_schedule) { + CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n"; + common::CINNValuePack arg_pack = args[0]; + std::vector vec_ast; + for (int i = 0; i < arg_pack.size(); i++) { + if (arg_pack[i].is_expr()) { + Expr temp = arg_pack[i]; + vec_ast.emplace_back(temp); + } + } + CHECK(!vec_ast.empty()); + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + if (prod_size > 1) { + if (target.arch == Target::Arch::NVGPU) { + pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target); + } else if (target.arch == Target::Arch::X86) { + pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + } + } + std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + *ret = common::CINNValuePack{res}; + } else { + CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + } + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(one_hot_compute, one_hot_schedule, "strategy.one_hot.x86", 1); + + return strategy; +} +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(one_hot_ops) { + CINN_REGISTER_OP(one_hot) + .describe( + "Returns a one-hot tensor where the locations repsented by indices take value `on_value`, " + "other locations take value `off_value`.") + .set_num_inputs(3) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForOneHot) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForOneHot)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForOneHot)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/one_hot.h b/cinn/hlir/op/contrib/one_hot.h new file mode 100644 index 0000000000..90e21d8c01 --- /dev/null +++ b/cinn/hlir/op/contrib/one_hot.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace op { + +ir::Tensor OneHot(const ir::Tensor& indices, + const ir::Tensor& on_value, + const ir::Tensor& off_value, + const int depth, + const int axis, + const Type& dtype, + const std::string& output_name); + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/one_hot_test.cc b/cinn/hlir/op/contrib/one_hot_test.cc new file mode 100644 index 0000000000..d76769c5d0 --- /dev/null +++ b/cinn/hlir/op/contrib/one_hot_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/one_hot.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, OneHot) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + + ir::Expr m(4); + ir::Expr n(4); + const int depth = 3; + const int axis = 1; + const std::string dtype = "float32"; + + lang::Placeholder in("in", {m, n}); + lang::Placeholder on_value("on_value", {Expr(1)}); + lang::Placeholder off_value("off_value", {Expr(1)}); + + ir::Tensor res = OneHot(in, on_value, off_value, depth, axis, common::Str2Type(dtype), "test_one_hot"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_OneHot", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("OneHot_Module", target); + for (auto &f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; + + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_OneHot(void* _args, int32_t num_args) +{ + cinn_buffer_t* _test_one_hot = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _in = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 4 }); + cinn_buffer_t* _off_value = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { }); + cinn_buffer_t* _on_value = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { }); + cinn_buffer_malloc((void*)(0), _test_one_hot); + cinn_buffer_malloc((void*)(0), _in); + cinn_buffer_malloc((void*)(0), _off_value); + cinn_buffer_malloc((void*)(0), _on_value); + const int32_t* in = ((const int32_t*)(_in->memory)); + const int32_t* off_value = ((const int32_t*)(_off_value->memory)); + const int32_t* on_value = ((const int32_t*)(_on_value->memory)); + float* test_one_hot = ((float*)(_test_one_hot->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 3; j += 1) { + for (int32_t k = 0; k < 4; k += 1) { + test_one_hot[((12 * i) + ((4 * j) + k))] = (((in[((4 * i) + k)] == j)) ? ((float)(on_value[0])) : ((float)(off_value[0]))); + }; + }; + }; + cinn_buffer_free((void*)(0), _in); + cinn_buffer_free((void*)(0), _off_value); + cinn_buffer_free((void*)(0), _on_value); + cinn_buffer_free((void*)(0), _test_one_hot); +} + )ROC"; + + ASSERT_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 38f538f1d5..9fc5d11c25 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -31,3 +31,4 @@ CINN_USE_REGISTER(reduce_ops) CINN_USE_REGISTER(clip_ops) CINN_USE_REGISTER(custom_call_op) CINN_USE_REGISTER(arange_ops) +CINN_USE_REGISTER(one_hot_ops) From f6945fbb9ccd30df2c52c3ae66e2a2e31308aebb Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sun, 25 Sep 2022 22:30:56 -0700 Subject: [PATCH 2/3] format --- cinn/frontend/net_builder.cc | 13 +++++++------ cinn/frontend/net_builder.h | 2 +- cinn/frontend/net_builder_test.cc | 0 3 files changed, 8 insertions(+), 7 deletions(-) mode change 100755 => 100644 cinn/frontend/net_builder_test.cc diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 845f783c4f..fcf452a32f 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -392,12 +392,13 @@ Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) { } Variable NetBuilder::OneHot(const Variable& indices, - const Variable& on_value, - const Variable& off_value, - const int depth, - const int axis, - const std::string& dtype){ - return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}}).front(); + const Variable& on_value, + const Variable& off_value, + const int depth, + const int axis, + const std::string& dtype) { + return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}}) + .front(); } Variable NetBuilder::Squeeze(const Variable& operand, const std::vector& axes) { diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index 00d9a9c8da..ba4b0943b0 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -641,7 +641,7 @@ class NetBuilder { */ Variable Cast(const Variable& x, const std::string& dtype); - /** + /** * @brief Returns a one-hot tensor where the locations repsented by indices take value `on_value`, * other locations take value `off_value`. * @param on_value Value to fill at indices. Its shape must be [1]. diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc old mode 100755 new mode 100644 From 6d7f05808bbe30894a3b112a763640426ef7c97f Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sun, 25 Sep 2022 22:48:19 -0700 Subject: [PATCH 3/3] format code --- cinn/hlir/op/contrib/one_hot.cc | 12 ++++++------ cinn/hlir/op/contrib/one_hot_test.cc | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cinn/hlir/op/contrib/one_hot.cc b/cinn/hlir/op/contrib/one_hot.cc index a2e405e5bb..8597e3939e 100644 --- a/cinn/hlir/op/contrib/one_hot.cc +++ b/cinn/hlir/op/contrib/one_hot.cc @@ -74,13 +74,13 @@ ir::Tensor OneHot(const ir::Tensor& indices, } } - ir::Expr on_value_cast = ir::Cast::Make(dtype, on_value(Expr(0))); - ir::Expr off_value_cast = ir::Cast::Make(dtype, off_value(Expr(0))); + Expr on_value_cast = ir::Cast::Make(dtype, on_value(Expr(0))); + Expr off_value_cast = ir::Cast::Make(dtype, off_value(Expr(0))); ir::Tensor res = lang::Compute( new_shape, - [=](const std::vector& iter) { - std::vector indices_indices; + [=](const std::vector& iter) { + std::vector indices_indices; for (size_t i = 0; i < iter.size(); i++) { if (static_cast(i) == true_axis) { @@ -89,8 +89,8 @@ ir::Tensor OneHot(const ir::Tensor& indices, indices_indices.push_back(iter[i]); } - auto idx = iter[true_axis]; - auto elem = ir::Cast::Make(idx.type(), indices(indices_indices)); + Expr idx = iter[true_axis]; + Expr elem = ir::Cast::Make(idx.type(), indices(indices_indices)); return ir::Select::Make(ir::EQ::Make(elem, idx), on_value_cast, off_value_cast); }, common::UniqName(output_name)); diff --git a/cinn/hlir/op/contrib/one_hot_test.cc b/cinn/hlir/op/contrib/one_hot_test.cc index d76769c5d0..2c51165228 100644 --- a/cinn/hlir/op/contrib/one_hot_test.cc +++ b/cinn/hlir/op/contrib/one_hot_test.cc @@ -37,8 +37,8 @@ TEST(GenerateCode_Cpu, OneHot) { common::Target target = common::DefaultHostTarget(); - ir::Expr m(4); - ir::Expr n(4); + Expr m(4); + Expr n(4); const int depth = 3; const int axis = 1; const std::string dtype = "float32";