This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【Hackathon 4th No.83】Add resize op #1306

Merged · 17 commits · Apr 7, 2023
4 changes: 4 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -607,6 +607,10 @@ Variable NetBuilder::Repeat(const Variable& x, int repeats, int axis) {
return CustomInstr("repeat", {x}, {{"repeats", repeats}, {"axis", axis}}).front();
}

Variable NetBuilder::Resize(const Variable& x, const std::vector<int>& out_shape, const std::string& mode) {
return CustomInstr("resize", {x}, {{"out_shape", out_shape}, {"mode", mode}}).front();
}

std::vector<Variable> NetBuilder::BatchNorm(const Variable& a,
const Variable& scale,
const Variable& bias,
9 changes: 9 additions & 0 deletions cinn/frontend/net_builder.h
@@ -575,6 +575,15 @@ class NetBuilder {
*/
Variable Repeat(const Variable& x, int repeats, int axis);

/**
* @brief Resize operator, which performs 2D scaling of the input to the given size.
* @param x An input variable; the data layout of the input is NCHW.
* @param out_shape The output size to which the image will be resized.
* @param mode The scaling method to use, one of [nearest, bilinear, bicubic].
* @return The resized result.
*/
Variable Resize(const Variable& x, const std::vector<int>& out_shape, const std::string& mode);

// *******************************************
// Broadcast operator
/**
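For reference, a minimal usage sketch of the new NetBuilder::Resize API (illustrative only; the CreateInput call, type, and shapes here are assumptions for this example, not part of the diff):

NetBuilder builder("resize_test");
// NCHW input; scale H x W from 32x32 up to 64x64 with bilinear interpolation.
auto x = builder.CreateInput(common::Float(32), {1, 3, 32, 32}, "x");
auto out = builder.Resize(x, /*out_shape=*/{64, 64}, /*mode=*/"bilinear");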
1 change: 1 addition & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -15,6 +15,7 @@ gather_srcs(cinnapi_src SRCS
uniform_random.cc
cholesky.cc
triangular_solve.cc
resize.cc
)

cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
248 changes: 248 additions & 0 deletions cinn/hlir/op/contrib/resize.cc
@@ -0,0 +1,248 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/resize.h"

#include <gflags/gflags.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/transform.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValuePack;

#define __get_pixel(input, h, w, n, c, y, x) \
input({n, \
c, \
common::AutoSimplify(ir::Max::Make(ir::Min::Make(y, h - Expr(1)), Expr(0))), \
common::AutoSimplify(ir::Max::Make(ir::Min::Make(x, w - Expr(1)), Expr(0)))})
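
// Note: __get_pixel clamps the requested (y, x) into [0, h - 1] x [0, w - 1]
// via Max(Min(...)) before indexing, so out-of-range taps are edge-padded
// rather than wrapped.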

ir::Tensor Resize(const ir::Tensor &input,
const common::Target &target,
const std::vector<int> &out_shape,
const std::string &mode,
const std::string &output_name) {
int ndim = static_cast<int>(input->shape.size());
CHECK_EQ(ndim, 4U) << "The dimension of x must be 4.";
CHECK_EQ(out_shape.size(), 2U) << "The length of out_shape must be 2.";
CHECK(out_shape[0] > 0 && out_shape[1] > 0) << "Each element of out_shape must be greater than 0.";
CHECK(mode == "nearest" || mode == "bilinear" || mode == "bicubic")
<< "Resize only supports `nearest`, `bilinear` and `bicubic` mode.";

std::string func_name;

if (target.arch == common::Target::Arch::NVGPU) {
func_name.assign("cinn_cuda_resize_");
} else if (target.arch == common::Target::Arch::X86) {
func_name.assign("cinn_host_resize_");
} else {
LOG(FATAL) << "Resize only supports X86 and NVGPU! Please check.\n";
}

if (mode == "bilinear") {
func_name.append("bilinear");
} else if (mode == "bicubic") {
func_name.append("bicubic");
}
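
// For bilinear and bicubic modes, the per-pixel interpolation is delegated to
// target-specific extern runtime functions (func_name resolves to
// cinn_host_resize_* on X86 or cinn_cuda_resize_* on NVGPU); the nearest mode
// is computed inline in the lambda below.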

Expr in_h = input->shape[2];
Expr in_w = input->shape[3];
Expr out_h = Expr(out_shape[0]);
Expr out_w = Expr(out_shape[1]);

std::vector<Expr> new_shape = {input->shape[0], input->shape[1], out_h, out_w};
ir::Tensor res = lang::Compute(
{new_shape},
[=](const std::vector<Expr> &indices) {
Expr out_y = indices[2];
Expr out_x = indices[3];
Expr value;

if (mode == "nearest") {
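// Map each output pixel back to the source grid by the in/out scale and
// truncate: in_y = floor(in_h / out_h * out_y). For example, with in_h = 4
// and out_h = 8, out_y = 5 reads input row floor(0.5 * 5) = 2.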
Expr in_y = ir::Cast::Make(common::F32(), in_h) / ir::Cast::Make(common::F32(), out_h) *
ir::Cast::Make(common::F32(), out_y);
Expr in_x = ir::Cast::Make(common::F32(), in_w) / ir::Cast::Make(common::F32(), out_w) *
ir::Cast::Make(common::F32(), out_x);
Expr in_y_int = ir::Cast::Make(common::Int(32), lang::Floor(in_y));
Expr in_x_int = ir::Cast::Make(common::Int(32), lang::Floor(in_x));
std::vector<Expr> in_indices = {indices[0], indices[1], in_y_int, in_x_int};
value = input(in_indices);

} else if (mode == "bilinear") {
value = lang::CallExtern(
func_name, {input, input->shape[1], in_h, in_w, out_h, out_w, indices[0], indices[1], out_y, out_x});

} else if (mode == "bicubic") {
value = lang::CallExtern(
func_name, {input, input->shape[1], in_h, in_w, out_h, out_w, indices[0], indices[1], out_y, out_x});
}

return value;
},
common::UniqName(output_name));

return res;
}

std::vector<std::vector<int>> InferShapeForResize(const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape[0].size(), 4U) << "The input's shape size should be 4! Please check again.";
framework::shape_t x_shape = inputs_shape[0];
std::vector<int> new_shape, out_shape;
new_shape.push_back(x_shape[0]);
new_shape.push_back(x_shape[1]);
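// Keep N and C; the spatial dims come from the out_shape attribute, so an
// input of {N, C, H, W} with out_shape {h, w} infers to {N, C, h, w}.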

if (attrs.find("out_shape") != attrs.end()) {
out_shape = absl::get<std::vector<int>>(attrs.at("out_shape"));
}

CHECK_EQ(out_shape.size(), 2U) << "The length of out_shape must be 2.";
new_shape.push_back(out_shape[0]);
new_shape.push_back(out_shape[1]);

return {new_shape};
}

std::vector<Type> InferDtypeForResize(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
std::vector<Type> res{inputs_type[0]};
return res;
}

std::shared_ptr<framework::OpStrategy> StrategyForResize(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
std::vector<int> out_shape;
std::string mode = "bilinear";

for (auto &iter : attrs.attr_store) {
if (iter.first == "out_shape") {
out_shape = absl::get<std::vector<int>>(iter.second);
} else if (iter.first == "mode") {
mode = absl::get<std::string>(iter.second);
}
}

CHECK(mode == "nearest" || mode == "bilinear" || mode == "bicubic")
<< "Resize only supports `nearest`, `bilinear` and `bicubic` mode.";

framework::CINNCompute resize_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input arguments of Resize compute are empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensor is required for Resize compute.\n";
Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
std::string tensor_name = common::UniqName("T_Resize_out");

if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(pack_args.size(), 2U);
tensor_name = pack_args[1].operator std::string();
}

ir::Tensor out = Resize(tensor_A, target, out_shape, mode, tensor_name);

std::vector<common::CINNValue> res;
auto stages = CreateStages({tensor_A});
stages->InsertLazily(out);
res.push_back(common::CINNValue(out));
res.push_back(common::CINNValue(stages));
*ret = common::CINNValuePack{res};
});

framework::CINNSchedule resize_schedule([=](lang::Args args, lang::RetValue *ret) {
if (FLAGS_cinn_ir_schedule) {
CHECK(!args.empty()) << "The input argument of resize schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies<int>());
if (prod_size > 1) {
if (target.arch == Target::Arch::NVGPU) {
pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
} else if (target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
}
}
std::vector<common::CINNValue> res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of resize schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
}
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(resize_compute, resize_schedule, "strategy.resize.x86", 1);

return strategy;
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(resize_ops) {
CINN_REGISTER_OP(resize)
.describe(" ")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForResize)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForResize))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForResize))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
.set_support_level(4);

return true;
}
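
To make the nearest path concrete, the lowered computation is equivalent to this standalone CPU sketch (plain C++ written for illustration; it is not part of the diff):

#include <cmath>
#include <cstdio>
#include <vector>

// Nearest-neighbor NCHW resize mirroring the lowered compute:
// in_y = floor(in_h / out_h * out_y), in_x = floor(in_w / out_w * out_x).
std::vector<float> ResizeNearest(const std::vector<float> &in, int n, int c,
                                 int in_h, int in_w, int out_h, int out_w) {
  std::vector<float> out(static_cast<size_t>(n) * c * out_h * out_w);
  for (int b = 0; b < n; ++b)
    for (int ch = 0; ch < c; ++ch)
      for (int y = 0; y < out_h; ++y)
        for (int x = 0; x < out_w; ++x) {
          int in_y = static_cast<int>(std::floor(1.f * in_h / out_h * y));
          int in_x = static_cast<int>(std::floor(1.f * in_w / out_w * x));
          out[((b * c + ch) * out_h + y) * out_w + x] =
              in[((b * c + ch) * in_h + in_y) * in_w + in_x];
        }
  return out;
}

int main() {
  // A 1x1x2x2 input upscaled to 4x4: each source pixel becomes a 2x2 block.
  std::vector<float> in = {1, 2, 3, 4};
  std::vector<float> out = ResizeNearest(in, 1, 1, 2, 2, 4, 4);
  for (int y = 0; y < 4; ++y) {
    for (int x = 0; x < 4; ++x) std::printf("%g ", out[y * 4 + x]);
    std::printf("\n");
  }
  return 0;
}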
36 changes: 36 additions & 0 deletions cinn/hlir/op/contrib/resize.h
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

ir::Tensor Resize(const ir::Tensor &x,
const common::Target &target,
const std::vector<int> &out_shape,
const std::string &mode,
const std::string &output_name);

} // namespace op
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
@@ -39,3 +39,4 @@ CINN_USE_REGISTER(uniform_random_ops)
CINN_USE_REGISTER(cholesky_ops)
CINN_USE_REGISTER(triangular_solve_ops)
CINN_USE_REGISTER(op_external_api)
CINN_USE_REGISTER(resize_ops)
1 change: 1 addition & 0 deletions cinn/pybind/frontend.cc
@@ -545,6 +545,7 @@ void BindFrontend(pybind11::module *m) {
py::arg("strides") = std::vector<int>{},
py::arg("decrease_axis") = std::vector<int>{})
.def("reverse", &NetBuilder::Reverse, py::arg("x"), py::arg("axis"))
.def("resize", &NetBuilder::Resize, py::arg("x"), py::arg("out_shape"), py::arg("mode") = "bilinear")
.def("select", &NetBuilder::Select, py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
.def("split", &NetBuilder::Split, py::arg("x"), py::arg("num_or_sections"), py::arg("axis") = 0)
.def("gather", &NetBuilder::Gather, py::arg("x"), py::arg("index"), py::arg("axis") = 0)