Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix compute and schedule func of sort and argsoft
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Feb 22, 2023
1 parent a9630cc commit b3d1a47
Show file tree
Hide file tree
Showing 17 changed files with 524 additions and 53 deletions.
2 changes: 2 additions & 0 deletions cinn/frontend/decomposer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gather_srcs(cinnapi_src SRCS
batch_norm.cc
conv2d_grad.cc
norm.cc
top_k.cc
)

cc_library(decomposer_test_helper SRCS test_helper.cc DEPS cinncore)
Expand All @@ -17,6 +18,7 @@ cc_test(test_elementwise_decomposer SRCS elementwise_test.cc DEPS cinncore decom
cc_test(test_broadcast_decomposer SRCS broadcast_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_batch_norm_decomposer SRCS batch_norm_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_norm_decomposer SRCS norm_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_top_k_decomposer SRCS top_k_test.cc DEPS cinncore decomposer_test_helper)
endif()
if(WITH_CUDNN)
cc_test(test_conv2d_grad_decomposer SRCS conv2d_grad_test.cc DEPS cinncore decomposer_test_helper)
Expand Down
51 changes: 51 additions & 0 deletions cinn/frontend/decomposer/top_k.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {
namespace decomposer {

void top_k(const Instruction& instr, const DecomposerContext& context) {
CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 2UL) << "2 output tensors for " << instr->op_type;
auto x = instr->inputs[0];
auto output = instr->outputs[0];
auto indices = instr->outputs[1];

auto* builder = context.builder();
int k = instr.GetAttrs<int>("k");
CHECK_GT(k, 0) << "The attribute k must be greater than 0.";
int axis = instr.GetAttrs<int>("axis");

auto sort_tmp = builder->Sort(x, axis, false);
auto sort_out = builder->Slice(sort_tmp, {axis}, {0}, {k});
auto argsort_tmp = builder->ArgSort(x, axis, false);
auto argsort_out = builder->Cast(builder->Slice(argsort_tmp, {axis}, {0}, {k}), "int64");

// map the the output of decomposed operator to the original.
context.MapOutToOrigin(sort_out, output);
context.MapOutToOrigin(argsort_out, indices);
}

} // namespace decomposer
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(top_k_decomposer) {
CINN_DECOMPOSER_REGISTER(top_k, cinn::frontend::decomposer::top_k);
return true;
}
55 changes: 55 additions & 0 deletions cinn/frontend/decomposer/top_k_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "cinn/frontend/decomposer/test_helper.h"

namespace cinn::frontend {

TEST(Decomposer, top_k_decomposer) {
NetBuilder net_builder("top_k_decomposer");
std::unordered_set<std::string> output_names;
{
auto x = net_builder.CreateInput(Float(32), {10, 5}, "x");
auto y = net_builder.TopK(x, 1, -1, true);
output_names.insert(y[0]->id);
output_names.insert(y[1]->id);
}
auto program = net_builder.Build();

auto target = common::DefaultTarget();
RunDecomposer(&program, target);

auto graph = std::make_shared<hlir::framework::Graph>(program, output_names, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");

auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto run_program = gc.Build();

std::vector<float> x(10 * 5);
InitRandomVector<float>(&x, 10 * 5, 0.0f, 1.0f, 1e-3);
std::vector<std::pair<std::string, std::vector<float>>> inputs = {{"x", x}};
for (auto& input : inputs) {
scope->Var<hlir::framework::Tensor>(input.first);
auto tensor = scope->GetTensor(input.first);
auto* data = tensor->mutable_data<float>(target);
CopyFromVector(input.second, tensor, target);
}
run_program->Execute();
}

} // namespace cinn::frontend
1 change: 1 addition & 0 deletions cinn/frontend/decomposer/use_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ CINN_USE_REGISTER(batch_norm_train_decomposer)
CINN_USE_REGISTER(batch_norm_grad_decomposer)
CINN_USE_REGISTER(conv2d_grad_decomposer)
CINN_USE_REGISTER(norm_decomposer)
CINN_USE_REGISTER(top_k_decomposer)
4 changes: 4 additions & 0 deletions cinn/frontend/net_builder.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -685,5 +685,9 @@ Variable NetBuilder::Norm(const Variable& x, int axis, float epsilon) {
return instr.GetOutput(0);
}

std::vector<Variable> NetBuilder::TopK(const Variable& x, int k, int axis, bool largest) {
return CustomInstr("top_k", {x}, {{"k", k}, {"axis", axis}, {"largest", largest}});
}

} // namespace frontend
} // namespace cinn
16 changes: 16 additions & 0 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,22 @@ class NetBuilder {
*/
Variable Norm(const Variable& x, int axis = -1, float epsilon = 1e-12f);

/**
* @brief Return values and indices of the k largest or smallest at the optional axis.
* If the input is a 1-D Tensor, finds the k largest or smallest values and indices.
* If the input is a Tensor with higher rank, this operator computes the top k values
* and indices along the axis.
* @param x Input tensor.
* @param k The number of top elements to look for along the axis.
* @param axis Axis to compute indices along. The effective range is [-R, R), where R is
* x.ndim. when axis < 0, it works the same way as axis + R. Default is -1.
* @param largest largest is a flag, if set to true, algorithm will sort by descending
* order, otherwise sort by ascending order. Default is True.
* @return The values and indices. The value data type is the same as the input x. The
* indices data type is int64.
*/
std::vector<Variable> TopK(const Variable& x, int k, int axis, bool largest);

private:
CINN_DISALLOW_COPY_AND_ASSIGN(NetBuilder);
};
Expand Down
48 changes: 48 additions & 0 deletions cinn/frontend/op_mappers/paddle/top_k.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/common_utils.h"

namespace cinn {
namespace frontend {
namespace paddle_mappers {

void TopKOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
CHECK_EQ(op_desc.Output("Indices").size(), 1UL);
auto indices_name = op_desc.Output("Indices").front();
auto x = ctx.GetVar(x_name);

CHECK(op_desc.HasAttr("k"));
auto k = utils::GetAttrOrDefault<int>(op_desc, "k");
auto outs = ctx.Builder()->TopK(x, k, -1, true);

ctx.AddVar(out_name, outs[0]);
ctx.AddVarModelToProgram(out_name, outs[0]->id);
ctx.AddVar(indices_name, outs[1]);
ctx.AddVarModelToProgram(indices_name, outs[1]->id);
}

} // namespace paddle_mappers
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(paddle_top_k) {
CINN_REGISTER_OP_MAPPER(top_k, cinn::frontend::paddle_mappers::TopKOpMapper)
return true;
}
1 change: 1 addition & 0 deletions cinn/frontend/op_mappers/use_op_mappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ CINN_USE_REGISTER(paddle_reduce)
CINN_USE_REGISTER(paddle_atan)
CINN_USE_REGISTER(paddle_gaussian_random)
CINN_USE_REGISTER(paddle_uniform_random)
CINN_USE_REGISTER(paddle_top_k)
CINN_USE_REGISTER(paddle_one_hot)
CINN_USE_REGISTER(paddle_cumsum)
CINN_USE_REGISTER(paddle_norm)
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/argmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ Tensor Argmax(const Tensor &in_tensor,
output_shape.push_back(Expr(1));
}

auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index");
auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index").at(0);
auto res = Compute(
output_shape,
[=](const std::vector<Expr> &indices) {
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/argmin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Tensor Argmin(const Tensor &in_tensor,
if (output_shape.empty()) {
output_shape.push_back(Expr(1));
}
auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index");
auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index").at(0);
auto res = Compute(
output_shape,
[=](const std::vector<Expr> &indices) {
Expand Down
Loading

0 comments on commit b3d1a47

Please sign in to comment.