From 1bd733a2b286b77b4b7638fcc438944ca1adeab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AD=E4=B8=AA=E9=AA=A8=E5=A4=B4?= <46243324+zrr1999@users.noreply.github.com> Date: Thu, 22 Sep 2022 10:12:19 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=20No.78?= =?UTF-8?q?=E3=80=91add=20gather,=20gather=5Fnd,=20scatter=20and=20scatter?= =?UTF-8?q?=5Fnd=20op=20(#897)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 【PaddlePaddle Hackathon No.78】add gather, gather_nd, scatter and scatter_nd op --- cinn/frontend/net_builder.cc | 35 ++ cinn/frontend/net_builder.h | 26 +- cinn/frontend/net_builder_test.cc | 254 +++++++++++++- cinn/hlir/op/contrib/CMakeLists.txt | 4 + cinn/hlir/op/contrib/gather.cc | 235 +++++++++++++ cinn/hlir/op/contrib/gather.h | 34 ++ cinn/hlir/op/contrib/gather_test.cc | 154 +++++++++ cinn/hlir/op/contrib/scatter.cc | 318 ++++++++++++++++++ cinn/hlir/op/contrib/scatter.h | 44 +++ cinn/hlir/op/contrib/scatter_test.cc | 165 +++++++++ cinn/hlir/op/use_ops.h | 2 + cinn/pybind/frontend.cc | 4 +- cinn/runtime/cuda/CMakeLists.txt | 10 +- .../runtime/cuda/cinn_cuda_runtime_source.cuh | 66 ++-- cinn/utils/data_util.cc | 26 ++ cinn/utils/data_util.h | 2 + 16 files changed, 1335 insertions(+), 44 deletions(-) create mode 100644 cinn/hlir/op/contrib/gather.cc create mode 100644 cinn/hlir/op/contrib/gather.h create mode 100644 cinn/hlir/op/contrib/gather_test.cc create mode 100644 cinn/hlir/op/contrib/scatter.cc create mode 100644 cinn/hlir/op/contrib/scatter.h create mode 100644 cinn/hlir/op/contrib/scatter_test.cc diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index f14496e32a..4cf33ca9ab 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -349,6 +349,41 @@ Variable NetBuilder::ReluGrad(const Variable& lhs, const Variable& rhs) { return CustomInstr("relu_grad", {lhs, rhs}, {}).front(); } +Variable NetBuilder::Gather(const Variable& x, const Variable& index, const int& axis) { + return CustomInstr("gather", {x, index}, {{"axis", axis}}).front(); +} + +Variable NetBuilder::GatherNd(const Variable& x, const Variable& index, const std::vector& axes) { + return CustomInstr("gather_nd", {x, index}, {{"axes", axes}}).front(); +} + +Variable NetBuilder::Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis) { + return CustomInstr("scatter", {src, index, out}, {{"axis", axis}}).front(); +} +Variable NetBuilder::Scatter(const Variable& src, + const Variable& index, + const std::vector& shape, + const float& default_value, + const int& axis) { + auto out = FillConstant(shape, default_value, UniqName("fill_constant"), "float", false); + return Scatter(src, index, out, axis); +} + +Variable NetBuilder::ScatterNd(const Variable& src, + const Variable& index, + const Variable& out, + const std::vector& axes) { + return CustomInstr("scatter_nd", {src, index, out}, {{"axes", axes}}).front(); +} +Variable NetBuilder::ScatterNd(const Variable& src, + const Variable& index, + const std::vector& shape, + const float& default_value, + const std::vector& axes) { + auto out = FillConstant(shape, default_value, UniqName("fill_constant"), "float", false); + return ScatterNd(src, index, out, axes); +} + Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) { if (operand->type == common::Str2Type(dtype)) { return Identity(operand); diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index e213c83b9c..fbfd68c574 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -262,6 +262,27 @@ class NetBuilder { */ Variable Clip(const std::vector& x, const float& max, const float& min); + Variable Gather(const Variable& x, const Variable& index, const int& axis = 0); + + Variable GatherNd(const Variable& x, const Variable& index, const std::vector& axes = {}); + + Variable Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis = 0); + Variable Scatter(const Variable& src, + const Variable& index, + const std::vector& shape, + const float& default_value = 0, + const int& axis = 0); + + Variable ScatterNd(const Variable& src, + const Variable& index, + const Variable& out, + const std::vector& axes = {}); + Variable ScatterNd(const Variable& src, + const Variable& index, + const std::vector& shape, + const float& default_value = 0, + const std::vector& axes = {}); + /** * @brief This operator checks if all `x` and `y` satisfy the condition: `|x - y| <= atol + rtol * |y|` * @param x The first variable. @@ -793,7 +814,7 @@ class NetBuilder { const std::string& data_layout = "NCHW"); /** - * @brief Sort Variable x along the given axis. The original Variable x will not be changed. + * @brief Sort Variable x along the given axis and return sorted index. The original Variable x will not be changed. * @param operand The variable that will be sorted. * @param axis Specify the axis to operate on the input. Default: 0. * @param is_ascend Sort mode. @@ -803,7 +824,8 @@ class NetBuilder { Variable ArgSort(const Variable& operand, const int& axis, const bool& is_ascend = true); /** - * @brief Sort Variable x along the given axis. The original Variable x will not be changed. + * @brief Sort Variable x along the given axis and return sorted variable. The original Variable x will not be + * changed. * @param operand The variable that will be sorted. * @param axis Specify the axis to operate on the input. Default: 0. * @param is_ascend Sort mode. diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 4920f994d7..924e00060f 100755 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -246,6 +246,258 @@ TEST(net_build, program_execute_clip) { } } +TEST(net_build, program_execute_gather) { + const int B = 4; + const int H_IN1 = 11; + const int H_IN2 = 14; + + NetBuilder builder("net_builder"); + Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1"); + Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN2}, "In2"); + Variable output = builder.Gather(input1, input2, 1); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input1.id())); + scope->Var(std::string(input2.id())); + scope->Var(std::string(output->id)); + + auto input1_tensor = scope->GetTensor(std::string(input1.id())); + SetRandData(input1_tensor, target); + float* input1_data = input1_tensor->mutable_data(target); + + auto input2_tensor = scope->GetTensor(std::string(input2.id())); + SetRandInt(input2_tensor, target); + int* input2_data = input2_tensor->mutable_data(target); + memset(input2_data, 0, sizeof(int) * B * H_IN2); + + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_tensor->type(), Float(32)); + EXPECT_EQ(output_shape.size(), 2UL); + EXPECT_EQ(output_shape[0], B); + EXPECT_EQ(output_shape[1], H_IN2); + + float* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_IN2; ++h) { + std::string line; + int index = h + H_IN2 * b; + float in_data = input1_data[input2_data[index] + H_IN1 * b]; + float out_data = output_data[index]; + line += (std::to_string(out_data) + ", "); + EXPECT_EQ(in_data, out_data); + VLOG(6) << line; + } + } +} + +TEST(net_build, program_execute_gather_nd) { + const int B = 4; + const int H_IN1 = 11; + const int H_IN2 = 14; + + NetBuilder builder("net_builder"); + Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1"); + Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN2, 1}, "In2"); + Variable output = builder.GatherNd(input1, input2, {1}); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input1.id())); + scope->Var(std::string(input2.id())); + scope->Var(std::string(output->id)); + + auto input1_tensor = scope->GetTensor(std::string(input1.id())); + SetRandData(input1_tensor, target); + float* input1_data = input1_tensor->mutable_data(target); + + auto input2_tensor = scope->GetTensor(std::string(input2.id())); + SetRandInt(input2_tensor, target); + int* input2_data = input2_tensor->mutable_data(target); + + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_tensor->type(), Float(32)); + EXPECT_EQ(output_shape.size(), 2UL); + EXPECT_EQ(output_shape[0], B); + EXPECT_EQ(output_shape[1], H_IN2); + + float* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_IN2; ++h) { + std::string line; + int index = h + H_IN2 * b; + float in_data = input1_data[input2_data[index] + H_IN1 * b]; + float out_data = output_data[index]; + line += (std::to_string(out_data) + ", "); + EXPECT_EQ(in_data, out_data); + VLOG(6) << line; + } + } +} + +TEST(net_build, program_execute_scatter) { + const float default_value = 3.14; + const int B = 3; + const int H_IN = 4; + const int H_OUT = 11; + + NetBuilder builder("net_builder"); + Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN}, "In1"); + Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN}, "In2"); + Variable output = builder.Scatter(input1, input2, {B, H_OUT}, default_value, 1); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input1.id())); + scope->Var(std::string(input2.id())); + scope->Var(std::string(output->id)); + + auto input1_tensor = scope->GetTensor(std::string(input1.id())); + SetRandData(input1_tensor, target); + float* input1_data = input1_tensor->mutable_data(target); + + auto input2_tensor = scope->GetTensor(std::string(input2.id())); + SetRandInt(input2_tensor, target); + int* input2_data = input2_tensor->mutable_data(target); + memset(input2_data, 0, sizeof(int) * B * H_IN); + + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_tensor->type(), Float(32)); + EXPECT_EQ(output_shape.size(), 2UL); + EXPECT_EQ(output_shape[0], B); + EXPECT_EQ(output_shape[1], H_OUT); + + float true_data[B * H_OUT]; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_OUT; ++h) { + int index = h + H_OUT * b; + true_data[index] = default_value; + } + } + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_IN; ++h) { + int index = h + H_IN * b; + true_data[input2_data[index] + H_OUT * b] = input1_data[index]; + std::cout << index << " " << input2_data[index] + H_OUT * b << " " << true_data[input2_data[index] + H_OUT * b]; + } + } + + float* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_OUT; ++h) { + std::string line; + int index = h + H_OUT * b; + float t_data = true_data[index]; + float out_data = output_data[index]; + line += (std::to_string(out_data) + ", "); + EXPECT_EQ(t_data, out_data); + VLOG(6) << line; + } + } +} + +TEST(net_build, program_execute_scatter_nd) { + const float default_value = 3.14; + const int B = 3; + const int H_IN = 4; + const int H_OUT = 11; + + NetBuilder builder("net_builder"); + Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN}, "In1"); + Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN, 1}, "In2"); + Variable output = builder.ScatterNd(input1, input2, {B, H_OUT}, default_value, {1}); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input1.id())); + scope->Var(std::string(input2.id())); + scope->Var(std::string(output->id)); + + auto input1_tensor = scope->GetTensor(std::string(input1.id())); + SetRandData(input1_tensor, target); + + auto input2_tensor = scope->GetTensor(std::string(input2.id())); + SetRandInt(input2_tensor, target); + + runtime_program->Execute(); + + int* input2_data; + float* input1_data; + input2_data = input2_tensor->mutable_data(target); + input1_data = input1_tensor->mutable_data(target); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_tensor->type(), Float(32)); + EXPECT_EQ(output_shape.size(), 2UL); + EXPECT_EQ(output_shape[0], B); + EXPECT_EQ(output_shape[1], H_OUT); + + float true_data[B * H_OUT]; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_OUT; ++h) { + int index = h + H_OUT * b; + true_data[index] = default_value; + } + } + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_IN; ++h) { + int index = h + H_IN * b; + true_data[input2_data[index] + H_OUT * b] = input1_data[index]; + } + } + + float* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H_OUT; ++h) { + std::string line; + int index = h + H_OUT * b; + float t_data = true_data[index]; + float out_data = output_data[index]; + line += (std::to_string(out_data) + ", "); + EXPECT_EQ(t_data, out_data); + VLOG(6) << line; + } + } +} + TEST(net_build, program_execute_cast) { const int B = 4; const int H = 7; @@ -266,7 +518,7 @@ TEST(net_build, program_execute_cast) { scope->Var(std::string(output->id)); auto input_tensor = scope->GetTensor(std::string(input.id())); - SetRandData(input_tensor, target); + SetRandInt(input_tensor, target); int* input_data = input_tensor->mutable_data(target); runtime_program->Execute(); diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt index 7dd7945d5c..5643c4cd69 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -1,6 +1,8 @@ core_gather_headers() gather_srcs(cinnapi_src SRCS + gather.cc + scatter.cc cast.cc squeeze.cc clip.cc @@ -11,6 +13,8 @@ gather_srcs(cinnapi_src SRCS cc_test(test_cast SRCS cast_test.cc DEPS cinncore) cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore) +cc_test(test_gather SRCS gather_test.cc DEPS cinncore) +cc_test(test_scatter SRCS scatter_test.cc DEPS cinncore) cc_test(test_clip SRCS clip_test.cc DEPS cinncore) cc_test(test_sort SRCS sort_test.cc DEPS cinncore) cc_test(test_arange SRCS arange_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/gather.cc b/cinn/hlir/op/contrib/gather.cc new file mode 100644 index 0000000000..0f533aa4d2 --- /dev/null +++ b/cinn/hlir/op/contrib/gather.cc @@ -0,0 +1,235 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/gather.h" + +#include + +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/elementwise.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::CINNValue; +using common::CINNValuePack; + +ir::Tensor Gather(const ir::Tensor &A, const ir::Tensor &B, const int &axis, const std::string &name) { + CHECK_EQ(A->shape.size(), B->shape.size()); + auto res = Compute( + B->shape, + [=](const std::vector &indices) { + std::vector A_indices; + for (int i = 0; i < axis; ++i) { + A_indices.push_back(indices[i]); + } + A_indices.push_back(B(indices)); + for (size_t i = axis + 1; i < A->shape.size(); ++i) { + A_indices.push_back(indices[i]); + } + return lang::Identity(A(A_indices)); + }, + name); + return res; +} + +ir::Tensor GatherNd(const ir::Tensor &A, const ir::Tensor &B, const std::vector &axes, const std::string &name) { + std::vector out_shape = B->shape; + out_shape.pop_back(); + auto res = Compute( + out_shape, + [=](const std::vector &indices) { + std::vector A_indices(indices.begin(), indices.begin() + A->shape.size()); + std::vector B_indices(indices); + for (int i = 0; i < axes.size(); ++i) { + B_indices.push_back(Expr(i)); + A_indices[axes[i]] = B(B_indices); + B_indices.pop_back(); + } + return lang::Identity(A(A_indices)); + }, + name); + return res; +} + +std::shared_ptr StrategyForGather(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + auto attr_store = attrs.attr_store; + CHECK(attr_store.count("axis")) << "find no attr of axis"; + int axis = absl::get(attr_store.at("axis")); + std::string op_name("gather"); + + framework::CINNCompute gather_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto tensor_B = B.as_tensor_ref(); + auto stages = CreateStages({tensor_A, tensor_B}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", B shape: " << utils::Join(tensor_B->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("Gather_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 3U); + tensor_name = pack_args[2].operator std::string(); + } + ir::Tensor out = Gather(tensor_A, tensor_B, axis, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule gather_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of gather schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(gather_compute, gather_schedule, "strategy.gather.x86", 1); + return strategy; +} + +std::shared_ptr StrategyForGatherNd(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + auto attr_store = attrs.attr_store; + CHECK(attr_store.count("axes")) << "find no attr of axes"; + std::vector axes = absl::get>(attr_store.at("axes")); + std::string op_name("gather_nd"); + + framework::CINNCompute gather_nd_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto tensor_B = B.as_tensor_ref(); + auto stages = CreateStages({tensor_A, tensor_B}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", B shape: " << utils::Join(tensor_B->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("GatherNd_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 3U); + tensor_name = pack_args[2].operator std::string(); + } + ir::Tensor out = GatherNd(tensor_A, tensor_B, axes, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule gather_nd_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of gather_nd schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(gather_nd_compute, gather_nd_schedule, "strategy.gather_nd.x86", 1); + return strategy; +} + +std::vector> InferShapeForGather(const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; + CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) << "The inputs' dims should be equal."; + std::vector> res{inputs_shape[1]}; + return res; +} + +std::vector> InferShapeForGatherNd(const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; + std::vector output_shape(inputs_shape[1].begin(), inputs_shape[1].end() - 1); + std::vector> res{output_shape}; + return res; +} + +std::vector InferDtypeForGather(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + std::vector res{inputs_type[0]}; + return res; +} + +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(gather_ops) { + CINN_REGISTER_OP(gather) + .describe("Gather.") + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGather) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForGather)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGather)) + .set_support_level(4); + + CINN_REGISTER_OP(gather_nd) + .describe("GatherNd.") + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGatherNd) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGather)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/gather.h b/cinn/hlir/op/contrib/gather.h new file mode 100644 index 0000000000..fb4a389681 --- /dev/null +++ b/cinn/hlir/op/contrib/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace op { + +ir::Tensor Gather(const ir::Tensor& A, const ir::Tensor& B, const int& axis, const std::string& name); + +ir::Tensor GatherNd(const ir::Tensor& A, const ir::Tensor& B, const std::vector& axes, const std::string& name); + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/gather_test.cc b/cinn/hlir/op/contrib/gather_test.cc new file mode 100644 index 0000000000..ea0bedd7c2 --- /dev/null +++ b/cinn/hlir/op/contrib/gather_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/gather.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, Gather) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + + ir::Expr n(4); + ir::Expr h_in1(28); + ir::Expr h_in2(14); + + lang::Placeholder in1("in1", {n, h_in1}); + lang::Placeholder in2("in2", {n, h_in2}); + ir::Tensor res = Gather(in1, in2, 1, "test_gather_out"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Gather", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Gather_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_Gather(void* _args, int32_t num_args) +{ + cinn_buffer_t* _test_gather_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _in1 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 28 }); + cinn_buffer_t* _in2 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 14 }); + cinn_buffer_malloc((void*)(0), _test_gather_out); + cinn_buffer_malloc((void*)(0), _in1); + cinn_buffer_malloc((void*)(0), _in2); + const float* in1 = ((const float*)(_in1->memory)); + const int32_t* in2 = ((const int32_t*)(_in2->memory)); + float* test_gather_out = ((float*)(_test_gather_out->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 14; j += 1) { + test_gather_out[((14 * i) + j)] = in1[((28 * i) + in2[((14 * i) + j)])]; + }; + }; + cinn_buffer_free((void*)(0), _in1); + cinn_buffer_free((void*)(0), _in2); + cinn_buffer_free((void*)(0), _test_gather_out); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +TEST(GenerateCode_Cpu, GatherNd) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + + ir::Expr n(4); + ir::Expr h_in1(28); + ir::Expr h_in2(14); + ir::Expr w(1); + + lang::Placeholder in1("in1", {n, h_in1}); + lang::Placeholder in2("in2", {n, h_in2, w}); + ir::Tensor res = GatherNd(in1, in2, {1}, "test_gather_nd_out"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_GatherNd", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("GatherNd_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; + + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_GatherNd(void* _args, int32_t num_args) +{ + cinn_buffer_t* _test_gather_nd_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _in1 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 28 }); + cinn_buffer_t* _in2 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 14, 1 }); + cinn_buffer_malloc((void*)(0), _test_gather_nd_out); + cinn_buffer_malloc((void*)(0), _in1); + cinn_buffer_malloc((void*)(0), _in2); + const float* in1 = ((const float*)(_in1->memory)); + const int32_t* in2 = ((const int32_t*)(_in2->memory)); + float* test_gather_nd_out = ((float*)(_test_gather_nd_out->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 14; j += 1) { + test_gather_nd_out[((14 * i) + j)] = in1[((28 * i) + in2[((14 * i) + j)])]; + }; + }; + cinn_buffer_free((void*)(0), _in1); + cinn_buffer_free((void*)(0), _in2); + cinn_buffer_free((void*)(0), _test_gather_nd_out); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/scatter.cc b/cinn/hlir/op/contrib/scatter.cc new file mode 100644 index 0000000000..cd106fcd2c --- /dev/null +++ b/cinn/hlir/op/contrib/scatter.cc @@ -0,0 +1,318 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/scatter.h" + +#include + +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/elementwise.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/hlir/pe/transform.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::CINNValue; +using common::CINNValuePack; + +ir::Tensor Scatter(const ir::Tensor &A, + const ir::Tensor &B, + const ir::Tensor &C, + const common::Target &target, + const int &axis, + const std::string &name) { + CHECK_EQ(A->shape.size(), B->shape.size()); + CHECK_EQ(A->shape.size(), C->shape.size()); + + std::string extern_fun_name; + if (target.arch == common::Target::Arch::NVGPU) { + extern_fun_name.assign("cinn_cuda_find_int_nd"); + } else if (target.arch == common::Target::Arch::X86) { + extern_fun_name.assign("cinn_host_find_int_nd"); + } else { + LOG(FATAL) << "Scatter only support X86 and NVGPU ! Please Check.\n"; + } + + int pos_axis = axis; + if (pos_axis < 0) { + pos_axis += C->shape.size(); + } + + ir::Tensor transpose_B; + if (pos_axis == A->shape.size() - 1) { + transpose_B = B; + } else { + std::vector new_axes; + for (int i = 0; i < A->shape.size(); ++i) { + if (i != pos_axis) { + new_axes.push_back(i); + } + } + new_axes.push_back(pos_axis); + transpose_B = pe::Transpose(B, new_axes, B->name + "_index_transpose"); + } + auto res = Compute( + C->shape, + [=](const std::vector &indices) { + Expr offset(0); + for (int i = 0; i < indices.size(); i++) { + if (i != pos_axis) { + offset = offset * C->shape[i] + indices[i]; + } + } + auto B_shape_axis = B->shape[pos_axis]; + offset = common::AutoSimplify(offset * B_shape_axis); + auto idx = lang::CallExtern(extern_fun_name, {transpose_B, B_shape_axis, indices[pos_axis], offset, Expr(1)}); + std::vector A_indices(indices); + A_indices[pos_axis] = idx; + auto keep = ir::EQ::Make(idx, Expr(-1)); + return ir::Select::Make(keep, C(indices), A(A_indices)); + }, + name); + return res; +} + +ir::Tensor ScatterNd(const ir::Tensor &A, + const ir::Tensor &B, + const ir::Tensor &C, + const common::Target &target, + const std::vector &axes, + const std::string &name) { + CHECK(!A->shape.empty()); + CHECK_EQ(A->shape.size() + 1, B->shape.size()); + CHECK_EQ(A->shape.size() + axes.size() - 1, C->shape.size()); + + std::string extern_fun_name; + if (target.arch == common::Target::Arch::NVGPU) { + extern_fun_name.assign("cinn_cuda_find_int_nd"); + } else if (target.arch == common::Target::Arch::X86) { + extern_fun_name.assign("cinn_host_find_int_nd"); + } else { + LOG(FATAL) << "ScatterNd only support X86 and NVGPU ! Please Check.\n"; + } + + std::vector pos_axes; + for (auto axis : axes) { + if (axis < 0) { + pos_axes.push_back(axis + C->shape.size()); + } else { + pos_axes.push_back(axis); + } + } + + auto res = Compute( + C->shape, + [=](const std::vector &indices) { + auto offset = Expr(0); + std::vector A_indices; + for (int i = 0; i < indices.size(); i++) { + if (std::find(pos_axes.begin(), pos_axes.end(), i) == pos_axes.end()) { + offset = offset * C->shape[i] + indices[i]; + A_indices.push_back(indices[i]); + } + } + offset = offset * B->shape[B->shape.size() - 2] * B->shape[B->shape.size() - 1]; + auto keep = Expr(true); + std::vector idx; + for (int i = 0; i < pos_axes.size(); ++i) { + auto cur_idx = lang::CallExtern(extern_fun_name, + {B, + B->shape[B->shape.size() - 2], + indices[pos_axes[i]], + common::AutoSimplify(offset + Expr(i)), + Expr(static_cast(pos_axes.size()))}); + if (idx.empty()) { + idx.push_back(cur_idx); + A_indices.push_back(cur_idx); + } else { + keep = ir::And::Make(keep, ir::EQ::Make(idx[0], cur_idx)); + idx[0] = cur_idx; + } + } + keep = common::AutoSimplify(ir::And::Make(keep, ir::EQ::Make(idx[0], Expr(-1)))); + return ir::Select::Make(keep, C(indices), A(A_indices)); + }, + name); + return res; +} + +std::shared_ptr StrategyForScatter(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + auto attr_store = attrs.attr_store; + CHECK(attr_store.count("axis")) << "find no attr of axis"; + int axis = absl::get(attr_store.at("axis")); + std::string op_name("scatter"); + + framework::CINNCompute scatter_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 3U) << "3 input tensors for " << op_name << " compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + Expr C = pack_args[2]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + CHECK(C.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto tensor_B = B.as_tensor_ref(); + auto tensor_C = C.as_tensor_ref(); + auto stages = CreateStages({tensor_A, tensor_B, tensor_C}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", B shape: " << utils::Join(tensor_B->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("Scatter_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 4U); + tensor_name = pack_args[3].operator std::string(); + } + ir::Tensor out = Scatter(tensor_A, tensor_B, tensor_C, target, axis, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule scatter_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of scatter schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(scatter_compute, scatter_schedule, "strategy.scatter.x86", 1); + return strategy; +} + +std::shared_ptr StrategyForScatterNd(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + auto attr_store = attrs.attr_store; + CHECK(attr_store.count("axes")) << "find no attr of axis"; + std::vector axes = absl::get>(attr_store.at("axes")); + std::string op_name("scatter_nd"); + + framework::CINNCompute scatter_nd_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 3U) << "3 input tensors for " << op_name << " compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + Expr C = pack_args[2]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + CHECK(C.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto tensor_B = B.as_tensor_ref(); + auto tensor_C = C.as_tensor_ref(); + auto stages = CreateStages({tensor_A, tensor_B, tensor_C}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", B shape: " << utils::Join(tensor_B->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("ScatterNd_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 4U); + tensor_name = pack_args[3].operator std::string(); + } + ir::Tensor out = ScatterNd(tensor_A, tensor_B, tensor_C, target, axes, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule scatter_nd_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of scatter_nd schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(scatter_nd_compute, scatter_nd_schedule, "strategy.scatter_nd.x86", 1); + return strategy; +} + +std::vector> InferShapeForScatter(const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 3U) << "The input's shape size should be 3! Please check again."; + std::vector> res{inputs_shape[2]}; + return res; +} + +std::vector InferDtypeForScatter(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 3U) << "The input's type size should be 3! Please check again."; + CHECK_EQ(inputs_type[1], Int(32)) << "The index's type should be int! Please check again."; + std::vector res{inputs_type[2]}; + return res; +} + +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(scatter_ops) { + CINN_REGISTER_OP(scatter) + .describe("Scatter.") + .set_num_inputs(3) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScatter) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForScatter)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForScatter)) + .set_support_level(4); + + CINN_REGISTER_OP(scatter_nd) + .describe("ScatterNd.") + .set_num_inputs(3) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScatterNd) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForScatter)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForScatter)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/scatter.h b/cinn/hlir/op/contrib/scatter.h new file mode 100644 index 0000000000..aa8cbcc8ae --- /dev/null +++ b/cinn/hlir/op/contrib/scatter.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace op { + +ir::Tensor Scatter(const ir::Tensor& A, + const ir::Tensor& B, + const ir::Tensor& out, + const common::Target& target, + const int& axis, + const std::string& name); + +ir::Tensor ScatterNd(const ir::Tensor& A, + const ir::Tensor& B, + const ir::Tensor& out, + const common::Target& target, + const std::vector& axes, + const std::string& name); + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/scatter_test.cc b/cinn/hlir/op/contrib/scatter_test.cc new file mode 100644 index 0000000000..51441a4cdc --- /dev/null +++ b/cinn/hlir/op/contrib/scatter_test.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/scatter.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, Scatter) { + common::Context::Global().ResetNameId(); + + auto target = common::DefaultHostTarget(); + + ir::Expr n(4); + ir::Expr h_in(8); + ir::Expr h_out(14); + + lang::Placeholder in1("in1", {n, h_in}); + lang::Placeholder in2("in2", {n, h_in}); + lang::Placeholder out("out", {n, h_out}); + ir::Tensor res = Scatter(in1, in2, out, target, 1, "test_scatter_out"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Scatter", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Scatter_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; + + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_Scatter(void* _args, int32_t num_args) +{ + cinn_buffer_t* _test_scatter_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _in1 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 8 }); + cinn_buffer_t* _in2 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 8 }); + cinn_buffer_t* _out = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 14 }); + cinn_buffer_malloc((void*)(0), _test_scatter_out); + cinn_buffer_malloc((void*)(0), _in1); + cinn_buffer_malloc((void*)(0), _in2); + cinn_buffer_malloc((void*)(0), _out); + const float* in1 = ((const float*)(_in1->memory)); + const int32_t* in2 = ((const int32_t*)(_in2->memory)); + const float* out = ((const float*)(_out->memory)); + float* test_scatter_out = ((float*)(_test_scatter_out->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 14; j += 1) { + test_scatter_out[((14 * i) + j)] = (((cinn_host_find_int_nd(_in2, 8, j, (8 * i), 1) == -1)) ? out[((14 * i) + j)] : in1[((8 * i) + cinn_host_find_int_nd(_in2, 8, j, (8 * i), 1))]); + }; + }; + cinn_buffer_free((void*)(0), _in1); + cinn_buffer_free((void*)(0), _in2); + cinn_buffer_free((void*)(0), _out); + cinn_buffer_free((void*)(0), _test_scatter_out); +} +)ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +TEST(GenerateCode_Cpu, ScatterNd) { + common::Context::Global().ResetNameId(); + + auto target = common::DefaultHostTarget(); + + ir::Expr n(4); + ir::Expr h_in(8); + ir::Expr h_out(14); + + lang::Placeholder in1("in1", {n, h_in}); + lang::Placeholder in2("in2", {n, h_in, ir::Expr(1)}); + lang::Placeholder out("out", {n, h_out}); + ir::Tensor res = ScatterNd(in1, in2, out, target, {1}, "test_scatter_out"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Scatter", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Scatter_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; + + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_Scatter(void* _args, int32_t num_args) +{ + cinn_buffer_t* _test_scatter_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _in1 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 8 }); + cinn_buffer_t* _in2 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 8, 1 }); + cinn_buffer_t* _out = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 4, 14 }); + cinn_buffer_malloc((void*)(0), _test_scatter_out); + cinn_buffer_malloc((void*)(0), _in1); + cinn_buffer_malloc((void*)(0), _in2); + cinn_buffer_malloc((void*)(0), _out); + const float* in1 = ((const float*)(_in1->memory)); + const int32_t* in2 = ((const int32_t*)(_in2->memory)); + const float* out = ((const float*)(_out->memory)); + float* test_scatter_out = ((float*)(_test_scatter_out->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 14; j += 1) { + test_scatter_out[((14 * i) + j)] = (((cinn_host_find_int_nd(_in2, 8, j, (8 * i), 1) == -1)) ? out[((14 * i) + j)] : in1[((8 * i) + cinn_host_find_int_nd(_in2, 8, j, (8 * i), 1))]); + }; + }; + cinn_buffer_free((void*)(0), _in1); + cinn_buffer_free((void*)(0), _in2); + cinn_buffer_free((void*)(0), _out); + cinn_buffer_free((void*)(0), _test_scatter_out); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 5c3e6b3c0b..38f538f1d5 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -22,6 +22,8 @@ CINN_USE_REGISTER(broadcast_ops) CINN_USE_REGISTER(broadcast_grad_ops) CINN_USE_REGISTER(elementwise_ops) CINN_USE_REGISTER(transform_ops) +CINN_USE_REGISTER(gather_ops) +CINN_USE_REGISTER(scatter_ops) CINN_USE_REGISTER(cast_ops) CINN_USE_REGISTER(sort_ops) CINN_USE_REGISTER(squeeze_ops) diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index f403265903..dc51bf47f9 100644 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -536,7 +536,9 @@ void BindFrontend(pybind11::module *m) { py::arg("output_shape") = std::vector{}) .def("cast", &NetBuilder::Cast, py::arg("x"), py::arg("dtype")) .def("clip", &NetBuilder::Clip, py::arg("x"), py::arg("max"), py::arg("min")) - .def("arange", &NetBuilder::Arange, py::arg("start"), py::arg("end"), py::arg("step"), py::arg("dtype")); + .def("arange", &NetBuilder::Arange, py::arg("start"), py::arg("end"), py::arg("step"), py::arg("dtype")) + .def("gather", &NetBuilder::Gather, py::arg("x"), py::arg("index"), py::arg("axis")) + .def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index"), py::arg("axes")); auto computation = py::class_>(*m, "Computation"); py::class_(computation, "CompileOptions") diff --git a/cinn/runtime/cuda/CMakeLists.txt b/cinn/runtime/cuda/CMakeLists.txt index ffa4c25136..612d5f6607 100755 --- a/cinn/runtime/cuda/CMakeLists.txt +++ b/cinn/runtime/cuda/CMakeLists.txt @@ -1,15 +1,15 @@ if (NOT WITH_CUDA) return() -endif() +endif () core_gather_headers() gather_srcs(cinnapi_src SRCS - cuda_module.cc - cuda_util.cc - cuda_intrinsics.cc - ) + cuda_module.cc + cuda_util.cc + cuda_intrinsics.cc + ) nv_test(test_cuda_module SRCS cuda_module_test.cc DEPS cinncore) nv_library(cuda_runtime SRCS cinn_cuda_runtime_source.cuh) diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index 9bcdee96e1..77f73cf062 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -1,7 +1,6 @@ /** * \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen. */ - #define CINN_FLT_MAX 3.402823e+38f #define CINN_FLT_MIN -3.402823e+38f @@ -50,18 +49,15 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left #define cinn_shuffle_function(offset, op, init) \ shfl_res = __shfl_down_sync(mask, tmp_val, offset, 32); \ shfl_res = threadIdx.x % 32 + offset < lane ? shfl_res : init; \ - tmp_val = op(tmp_val, shfl_res); - -#define cinn_warp_shuffle_internal_kernel(TYPE, value, op, init) \ - TYPE tmp_val = value, shfl_res; \ - unsigned int mask = __activemask(); \ - unsigned int lane = __popc(mask); \ - cinn_shuffle_function(16, op, init) \ - cinn_shuffle_function(8, op, init) \ - cinn_shuffle_function(4, op, init) \ - cinn_shuffle_function(2, op, init) \ - cinn_shuffle_function(1, op, init) \ - tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \ + tmp_val = op(tmp_val, shfl_res); + +#define cinn_warp_shuffle_internal_kernel(TYPE, value, op, init) \ + TYPE tmp_val = value, shfl_res; \ + unsigned int mask = __activemask(); \ + unsigned int lane = __popc(mask); \ + cinn_shuffle_function(16, op, init) cinn_shuffle_function(8, op, init) cinn_shuffle_function(4, op, init) \ + cinn_shuffle_function(2, op, init) cinn_shuffle_function(1, op, init) tmp_val = \ + __shfl_sync(mask, tmp_val, 0, 32); \ return tmp_val; __device__ inline float cinn_warp_shuffle_sum_internal(const float value) { @@ -106,28 +102,28 @@ __device__ inline float cinn_warp_reduce_avg(const float *buf, int offset, int e } #define cinn_block_reduce_internal_kernel(TYPE, value, init_value, cinn_warp_shuffle_internal) \ - int warp_id = threadIdx.x / 32; \ - __shared__ TYPE tmp[32]; \ - if (warp_id == 0) { \ - tmp[threadIdx.x] = init_value; \ - } \ - TYPE tmp_val = cinn_warp_shuffle_internal(value); \ - if (blockDim.x <= 32) { \ - return tmp_val; \ - } \ - __syncthreads(); \ - if (threadIdx.x % 32 == 0) { \ - tmp[warp_id] = tmp_val; \ - } \ - __syncthreads(); \ - if (warp_id == 0) { \ - tmp_val = tmp[threadIdx.x]; \ - tmp_val = cinn_warp_shuffle_internal(tmp_val); \ - if (threadIdx.x == 0) { \ - tmp[0] = tmp_val; \ - } \ - } \ - __syncthreads(); \ + int warp_id = threadIdx.x / 32; \ + __shared__ TYPE tmp[32]; \ + if (warp_id == 0) { \ + tmp[threadIdx.x] = init_value; \ + } \ + TYPE tmp_val = cinn_warp_shuffle_internal(value); \ + if (blockDim.x <= 32) { \ + return tmp_val; \ + } \ + __syncthreads(); \ + if (threadIdx.x % 32 == 0) { \ + tmp[warp_id] = tmp_val; \ + } \ + __syncthreads(); \ + if (warp_id == 0) { \ + tmp_val = tmp[threadIdx.x]; \ + tmp_val = cinn_warp_shuffle_internal(tmp_val); \ + if (threadIdx.x == 0) { \ + tmp[0] = tmp_val; \ + } \ + } \ + __syncthreads(); \ return tmp[0]; // block reduce sum internal diff --git a/cinn/utils/data_util.cc b/cinn/utils/data_util.cc index 6ae079129d..56b5cdebd5 100644 --- a/cinn/utils/data_util.cc +++ b/cinn/utils/data_util.cc @@ -14,8 +14,34 @@ #include "cinn/utils/data_util.h" +#include "iostream" + namespace cinn { +void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, int seed) { + if (seed == -1) { + std::random_device rd; + seed = rd(); + } + std::default_random_engine engine(seed); + std::uniform_int_distribution dist(0, 10); + size_t num_ele = tensor->shape().numel(); + std::vector random_data(num_ele); + for (size_t i = 0; i < num_ele; i++) { + random_data[i] = static_cast(dist(engine)); // All random data + } + + auto* data = tensor->mutable_data(target); +#ifdef CINN_WITH_CUDA + if (target == common::DefaultNVGPUTarget()) { + cudaMemcpy(data, random_data.data(), num_ele * sizeof(int), cudaMemcpyHostToDevice); + return; + } +#endif + CHECK(target == common::DefaultHostTarget()); + std::copy(random_data.begin(), random_data.end(), data); +} + template <> void SetRandData(hlir::framework::Tensor tensor, const common::Target& target, int seed) { if (seed == -1) { diff --git a/cinn/utils/data_util.h b/cinn/utils/data_util.h index b4aa4d005c..d1d765c549 100644 --- a/cinn/utils/data_util.h +++ b/cinn/utils/data_util.h @@ -23,6 +23,8 @@ #endif namespace cinn { +void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, int seed = -1); + template void SetRandData(hlir::framework::Tensor tensor, const common::Target& target, int seed = -1);