This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit: update
Shixiaowei02 committed Feb 14, 2023
1 parent 1e6857e, commit 8677775
Showing 8 changed files with 372 additions and 3 deletions.
2 changes: 2 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -676,5 +676,7 @@ Variable NetBuilder::Cholesky(const Variable& x, bool upper) {
return CustomInstr("cholesky", {x}, {{"upper", upper}}).front();
}

std::vector<Variable> NetBuilder::TopK(const Variable& x, int k) { return CustomInstr("top_k", {x}, {{"k", k}}); }

} // namespace frontend
} // namespace cinn
18 changes: 18 additions & 0 deletions cinn/frontend/net_builder.h
@@ -999,6 +999,24 @@ class NetBuilder {
*/
Variable Cholesky(const Variable& x, bool upper = false);

/**
 * @brief Returns the values and indices of the k largest or smallest elements along the
 * given axis. If the input is a 1-D tensor, it finds the k largest or smallest values
 * and indices; for a tensor of higher rank, it computes the top k values and indices
 * along the axis. Note that this builder method currently exposes only x and k; the
 * axis, largest, and sorted attributes take their default values.
 * @param x Input tensor.
 * @param k The number of top elements to look for along the axis.
 * @param axis Axis to compute indices along. The effective range is [-R, R), where R is
 * x.ndim. When axis < 0, it works the same way as axis + R. Default is -1.
 * @param largest If set to true, the elements are sorted in descending order; otherwise
 * in ascending order. Default is true.
 * @param sorted Controls whether to return the elements in sorted order. Default is
 * true. On GPU devices, the values are always returned sorted.
 * @return The values and indices. The values have the same data type as the input x;
 * the indices are int32.
 */
std::vector<Variable> TopK(const Variable& x, int k);

private:
CINN_DISALLOW_COPY_AND_ASSIGN(NetBuilder);
};
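
For context, a minimal sketch of how the new builder method might be called. This is illustrative only, not part of the commit; it assumes the usual NetBuilder entry points (CreateInput, Build), and the shape and names are hypothetical:

// Hypothetical usage sketch of NetBuilder::TopK.
#include "cinn/frontend/net_builder.h"

void BuildTopKDemo() {
  cinn::frontend::NetBuilder builder("topk_demo");
  // float32 input of shape [4, 6]; identifiers are illustrative.
  auto x = builder.CreateInput(cinn::common::Float(32), {4, 6}, "x");
  // Top-2 along the default (last) axis: outs[0] holds the values,
  // outs[1] holds the int32 indices into x.
  auto outs = builder.TopK(x, 2);
  auto program = builder.Build();
}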
51 changes: 51 additions & 0 deletions cinn/frontend/op_mappers/paddle/top_k.cc
@@ -0,0 +1,51 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/common_utils.h"

namespace cinn {
namespace frontend {
namespace paddle_mappers {

void TopKOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
CHECK_EQ(op_desc.Output("Indices").size(), 1UL);
auto indices_name = op_desc.Output("Indices").front();

CHECK(op_desc.HasAttr("k"));
int k = op_desc.GetAttr<int>("k");

auto x = ctx.GetVar(x_name);
auto outs = ctx.Builder()->TopK(x, k);
auto& out = outs[0];
auto& indices = outs[1];

ctx.AddVar(out_name, out, true);
ctx.AddVar(indices_name, indices, true);
ctx.AddVarModelToProgram(out_name, out->id, true);
ctx.AddVarModelToProgram(indices_name, indices->id, true);
}

} // namespace paddle_mappers
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(paddle_top_k) {
CINN_REGISTER_OP_MAPPER(topk, cinn::frontend::paddle_mappers::TopKOpMapper)
return true;
}
1 change: 1 addition & 0 deletions cinn/frontend/op_mappers/use_op_mappers.h
@@ -51,6 +51,7 @@ CINN_USE_REGISTER(paddle_gaussian_random)
CINN_USE_REGISTER(paddle_uniform_random)
CINN_USE_REGISTER(paddle_one_hot)
CINN_USE_REGISTER(paddle_cumsum)
CINN_USE_REGISTER(paddle_top_k)

CINN_USE_REGISTER(science_broadcast)
CINN_USE_REGISTER(science_transform)
8 changes: 5 additions & 3 deletions cinn/hlir/framework/op_lowering.cc
@@ -1292,7 +1292,7 @@ std::vector<ir::LoweredFunc> OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo

auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
-    VLOG(3) << "cinn_inputs.push_back " << node_data->id();
+    LOG(INFO) << "cinn_inputs.push_back " << node_data->id();
group->output_names.push_back(node_data->id());
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
@@ -1324,7 +1324,9 @@ std::vector<ir::LoweredFunc> OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo
for (int i = 0; i < pack->size() - 1; i++) {
ir::Expr temp = pack[i];
// checkout whether the tensor is with buffer.
LOG(INFO) << "compute pack i = " << i << ", defined = " << temp.as_tensor_ref()->buffer.defined();
if (!temp.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) {
LOG(INFO) << "inputs here!";
inputs.push_back(temp.as_tensor_ref());
temp.as_tensor_ref()->WithBuffer();
args.emplace_back(temp.as_tensor_ref()->buffer, ir::Argument::IO::kOutput);
@@ -1350,8 +1352,8 @@ std::vector<ir::LoweredFunc> OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo
ir::Expr func_body = expr_pack[0];
std::vector<std::string> input_output_nodes(group->input_names);
input_output_nodes.insert(input_output_nodes.end(), group->output_names.begin(), group->output_names.end());
-  VLOG(6) << "func.size() = " << func.size() << ", expr_pack.size() = " << expr_pack.size();
-  VLOG(6) << "args.size() = " << args.size() << ", input_output_nodes.size() = " << input_output_nodes.size();
+  LOG(INFO) << "func.size() = " << func.size() << ", expr_pack.size() = " << expr_pack.size();
+  LOG(INFO) << "args.size() = " << args.size() << ", input_output_nodes.size() = " << input_output_nodes.size();
if (args.size() > input_output_nodes.size()) {
args = lang::GetArgs(func_body, input_output_nodes);
}
227 changes: 227 additions & 0 deletions cinn/hlir/op/contrib/sort.cc
100755 → 100644
@@ -254,6 +254,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(const framework::NodeA
tensor_name = pack_args[1].operator std::string();
}
ir::Tensor out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);

std::vector<CINNValue> res;
stages->InsertLazily(out);
res.push_back(CINNValue(out));
@@ -327,6 +328,222 @@ std::vector<Type> InferDtypeForArgSort(const std::vector<Type> &inputs_type, con
return {Int(32)};
}

std::vector<ir::Tensor> ArgSort(const ir::Tensor &A,
const common::Target &target,
poly::StageMap stages,
const int &axis,
const bool &is_ascend,
const std::string &name,
int k) {
std::string find_func_name;
std::string index_func_name;
if (target.arch == common::Target::Arch::NVGPU) {
index_func_name.assign("cinn_cuda_");
find_func_name.assign("cinn_cuda_find_int_nd");
} else if (target.arch == common::Target::Arch::X86) {
index_func_name.assign("cinn_host_");
find_func_name.assign("cinn_host_find_int_nd");
} else {
LOG(FATAL) << "ArgSort only supports X86 and NVGPU! Please check.\n";
}
if (is_ascend) {
index_func_name.append("lt_num_float");
} else {
index_func_name.append("gt_num_float");
}
int pos_axis = axis;
if (pos_axis < 0) {
pos_axis += A->shape.size();
}
std::vector<cinn::ir::Expr> shape;
for (int i = 0; i < A->shape.size(); ++i) {
if (i == pos_axis) {
shape.emplace_back(Expr(k));
} else {
shape.emplace_back(A->shape[i]);
}
}
auto positions = Compute(
A->shape,
[=](const std::vector<Expr> &indices) {
Expr offset(0);
Expr stride(1);
for (int i = 0; i < indices.size(); i++) {
if (i < pos_axis) {
offset = offset * A->shape[i] + indices[i];
} else if (i == pos_axis) {
offset = offset * A->shape[i];
} else {
offset = offset * A->shape[i] + indices[i];
stride = stride * A->shape[i];
}
}
offset = common::AutoSimplify(offset);
stride = common::AutoSimplify(stride);
auto A_shape_axis = A->shape[pos_axis];
return lang::CallExtern(index_func_name, {A, A_shape_axis, A(indices), offset, stride});
},
name + "_temp");
auto res = Compute(
shape,
[=](const std::vector<Expr> &indices) {
Expr offset(0);
Expr stride(1);
for (int i = 0; i < indices.size(); i++) {
if (i < pos_axis) {
offset = offset * A->shape[i] + indices[i];
} else if (i == pos_axis) {
offset = offset * A->shape[i];
} else {
offset = offset * A->shape[i] + indices[i];
stride = stride * A->shape[i];
}
}
offset = common::AutoSimplify(offset);
stride = common::AutoSimplify(stride);

auto A_shape_axis = A->shape[pos_axis];
auto idx = lang::CallExtern(find_func_name, {positions, A_shape_axis, indices[pos_axis], offset, stride});
return idx;
},
name + "_idx");
return {res, positions};
}
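
The offset/stride arithmetic in both Compute lambdas flattens a multi-dimensional index in row-major order while zeroing the coordinate on the sort axis: offset then addresses the first element of the lane being sorted, and stride is the step between consecutive elements of that lane. A standalone sketch of the same computation (plain C++, illustrative only, not CINN API):

#include <cstdio>
#include <vector>

// Row-major offset/stride for a "lane" along `axis`, mirroring the index
// math in the ArgSort Compute lambdas above.
void LaneOffsetStride(const std::vector<int>& shape, const std::vector<int>& idx, int axis) {
  long offset = 0, stride = 1;
  for (int i = 0; i < static_cast<int>(idx.size()); ++i) {
    if (i == axis) {
      offset = offset * shape[i];  // drop the coordinate on the sort axis
    } else {
      offset = offset * shape[i] + idx[i];
      if (i > axis) stride *= shape[i];
    }
  }
  std::printf("offset=%ld stride=%ld\n", offset, stride);
}

int main() {
  // For shape [2, 3, 4], index (1, 2, 3), axis = 1:
  // offset = ((0*2 + 1)*3)*4 + 3 = 15 and stride = 4, so the sorted lane
  // is flat elements 15, 19, 23.
  LaneOffsetStride({2, 3, 4}, {1, 2, 3}, 1);
  return 0;
}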

ir::Tensor TopK(const ir::Tensor &sort_index,
const ir::Tensor &A,
const common::Target &target,
poly::StageMap stages,
const int &axis,
const bool &is_ascend,
const std::string &name) {
int pos_axis = axis;
if (pos_axis < 0) {
pos_axis += A->shape.size();
}
auto res = Compute(
sort_index->shape,
[=](const std::vector<Expr> &indices) {
std::vector<Expr> A_indices(indices);
A_indices[pos_axis] = sort_index(indices);
return A(A_indices);
},
name + "_out");
return res;
}
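
The gather above is the standard take-along-axis pattern: each output element copies A at the same coordinates, except that the coordinate on the sort axis is replaced by the sorted index. An illustrative plain-C++ equivalent for a 2-D input with axis = -1 (not CINN API):

#include <vector>

// out[i][j] = a[i][idx[i][j]], i.e. take-along-axis on the last dimension,
// matching the Compute lambda in the TopK helper above.
std::vector<std::vector<float>> TakeAlongLastAxis(const std::vector<std::vector<float>>& a,
                                                  const std::vector<std::vector<int>>& idx) {
  std::vector<std::vector<float>> out(idx.size());
  for (size_t i = 0; i < idx.size(); ++i) {
    for (int j : idx[i]) out[i].push_back(a[i][j]);
  }
  return out;
}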

std::shared_ptr<framework::OpStrategy> StrategyForTopK(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
auto attr_store = attrs.attr_store;
int axis = -1;
bool is_ascend = true;

auto it = attr_store.find("k");
CHECK(it != attr_store.end());
int k = absl::get<int>(it->second);

framework::CINNCompute topk_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input arguments of TopK compute are empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensor is required for TopK compute\n";
Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto tensor_A = A.as_tensor_ref();
auto stages = CreateStages({tensor_A});
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
auto tensor_name = UniqName("TopK_out");
if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(pack_args.size(), 3U);
CHECK(pack_args[1].is_string());
}
auto indices = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name, k);
ir::Tensor out = TopK(indices[0], tensor_A, target, stages, axis, is_ascend, tensor_name);

std::vector<CINNValue> res;
for (auto &r : {out, indices[0], indices[1]}) {
res.emplace_back(CINNValue{r});
stages->InsertLazily(r);
}
// stages->InsertLazily(indices[1]);
res.emplace_back(stages);
CHECK(!out_type.empty()) << "Output type of TopK is empty! Please check.\n";
*ret = CINNValuePack{res};
});

framework::CINNSchedule topk_schedule([=](lang::Args args, lang::RetValue *ret) {
if (FLAGS_cinn_ir_schedule) {
CHECK(!args.empty()) << "The input argument of topk_schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
/*
Expr expr_1 = arg_pack[1];
auto tmp_1 = ir_sch.GetBlock(expr_1.as_tensor()->name);
ir_sch.ComputeInline(tmp_1);
*/

/* tmp var
Expr expr_2 = arg_pack[2];
auto tmp_2 = ir_sch.GetBlock(expr_2.as_tensor()->name);
ir_sch.ComputeInline(tmp_2);
*/
long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies<int>());
if (prod_size > 1) {
if (target.arch == Target::Arch::NVGPU) {
pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
} else if (target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
}
}
LOG(INFO) << "sch expr 0: " << ir_sch.GetModule().GetExprs().at(0);
// LOG(INFO) << "sch expr 1: " << ir_sch.GetModule().GetExprs().at(1);
std::vector<common::CINNValue> res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of topk_schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
}
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(topk_compute, topk_schedule, "strategy.topk", 1);
return strategy;
}

std::vector<std::vector<int>> InferShapeForTopK(const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 1UL) << "The input's shape size should be 1! Please check again.";
auto res = inputs_shape;
auto k_it = attrs.find("k");
CHECK(k_it != attrs.end()) << "The attr k of topk does not exist.";
int k = absl::get<int>(k_it->second);
res[0].back() = k;
return {res[0], res[0]};
}
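
Shape inference replaces only the last dimension with k, consistent with the fixed axis = -1 in StrategyForTopK, and both outputs get the same shape. A minimal sketch of the rule with concrete numbers (hypothetical helper, not part of the commit):

#include <vector>

// Mirror of InferShapeForTopK: for input shape [4, 6] and k = 2, both the
// values tensor and the indices tensor come out as [4, 2].
std::vector<std::vector<int>> TopKOutShapes(std::vector<int> in, int k) {
  in.back() = k;
  return {in, in};  // {values_shape, indices_shape}
}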

std::vector<Type> InferDtypeForTopK(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again.";
std::vector<Type> res{inputs_type[0], Int(32)};
return res;
}

} // namespace op
} // namespace hlir
} // namespace cinn
@@ -350,5 +567,15 @@ CINN_REGISTER_HELPER(sort_ops) {
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgSort))
.set_support_level(4);

CINN_REGISTER_OP(top_k)
.describe("Return the values and indices of the k largest elements along the given axis.")
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForTopK)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTopK))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTopK))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
.set_support_level(4);

return true;
}
1 change: 1 addition & 0 deletions cinn/pybind/frontend.cc
@@ -548,6 +548,7 @@ void BindFrontend(pybind11::module *m) {
.def("argmax", &NetBuilder::Argmax, py::arg("x"), py::arg("axis"), py::arg("keep_dim") = false)
.def("argmin", &NetBuilder::Argmin, py::arg("x"), py::arg("axis"), py::arg("keep_dim") = false)
.def("lookup_table", &NetBuilder::LookupTable, py::arg("table"), py::arg("ids"), py::arg("padding_idx"))
.def("top_k", &NetBuilder::TopK, py::arg("x"), py::arg("k"))
.def("one_hot",
&NetBuilder::OneHot,
py::arg("indices"),