From fa2aac312944a5cd04a5657f31b06572437e26d7 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Sun, 10 Jul 2022 13:22:44 +0800 Subject: [PATCH 1/6] add squeeze base op --- cinn/frontend/base_builder.cc | 7 ++ cinn/frontend/base_builder.h | 2 + cinn/hlir/op/transform.cc | 121 ++++++++++++++++++++++++++++++++++ cinn/hlir/pe/transform.cc | 43 ++++++++++++ cinn/hlir/pe/transform.h | 4 ++ cinn/pybind/frontend.cc | 1 + 6 files changed, 178 insertions(+) diff --git a/cinn/frontend/base_builder.cc b/cinn/frontend/base_builder.cc index 0e744a556a..689fac8438 100644 --- a/cinn/frontend/base_builder.cc +++ b/cinn/frontend/base_builder.cc @@ -224,6 +224,13 @@ Variable BaseBuilder::Reshape(const Variable& operand, const std::vector& s return instr.GetOutput(0); } +Variable BaseBuilder::Squeeze(const Variable& operand) { + Instruction instr("squeeze", {operand}); + InferShape(instr); + AppendInstruction(instr); + return instr.GetOutput(0); +} + Variable BaseBuilder::Transpose(const Variable& operand, const std::vector& axis) { Instruction instr("transpose", {operand}); instr.SetAttr("axis", axis); diff --git a/cinn/frontend/base_builder.h b/cinn/frontend/base_builder.h index 938f7264d8..a6f11eb70d 100644 --- a/cinn/frontend/base_builder.h +++ b/cinn/frontend/base_builder.h @@ -88,6 +88,8 @@ class BaseBuilder { Variable Reshape(const Variable& operand, const std::vector& shape); + Variable Squeeze(const Variable& operand); + Variable Transpose(const Variable& operand, const std::vector& axis); Variable Slice(const Variable& operand, diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc index 155286f486..1361fdbab0 100755 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -414,6 +414,114 @@ std::vector> InferLayoutForReshape(const std::vector StrategyForSqueeze(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute squeeze_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of Squeeze compute is empty! Please check.\n"; + CINNValuePack a = args[0]; + CHECK_GE(a.size(), 1U) << "at least 1 input tensors for Squeeze compute\n"; + Expr A = a[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); +// auto attr_store = attrs.attr_store; +// CHECK(attr_store.count("shape")) << "find no attr of shape"; +// std::vector new_shape = absl::get>(attr_store.at("shape")); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + ir::Tensor out = pe::Squeeze(tensor_A, stages, UniqName("Squeeze_out")); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) << "Output type of Squeeze is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule squeeze_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of reshape schedule is empty! 
Please check.\n"; + CINNValuePack arg_pack = args[0]; + int arg_size = arg_pack.size(); + poly::StageMap stages = arg_pack.back(); + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + if (target.arch == Target::Arch::NVGPU) { + pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes[0], target); + } else if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes[0], target); + } + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(squeeze_compute, squeeze_schedule, "strategy.squeeze.x86", 1); + return strategy; +} + +std::vector> InferShapeForSqueeze(const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; + std::vector output_shape; + int tensor_size = 1; + for (auto s : inputs_shape[0]) { + if (s != 1) { + output_shape.push_back(s); + } + tensor_size *= s; + } + CHECK(!output_shape.empty()) << "infer_shape for squeeze turns out to be empty. Please check\n"; + int flag_index = -1; + for (int i = 0; i < output_shape.size(); i++) { + if (output_shape[i] > 0) { + CHECK_EQ(tensor_size % output_shape[i], 0) + << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i]; + tensor_size /= output_shape[i]; + } else if (output_shape[i] == 0) { + CHECK_LT(i, inputs_shape[0].size()) + << "In op reshape, when attribute shape[i] == 0, shape[i] = input_shape[i]. But now the size of input_shape " + "<= i, which is incompatible. Please check!"; + output_shape[i] = inputs_shape[0][i]; + CHECK_EQ(tensor_size % output_shape[i], 0) + << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i]; + tensor_size /= output_shape[i]; + } else if (output_shape[i] == -1 && flag_index == -1) { + flag_index = i; + } else if (output_shape[i] == -1) { + LOG(FATAL) << "More than one -1 in output_shape of op reshape."; + } else { + LOG(FATAL) << "Unsupported output_shape " << output_shape[i]; + } + } + if (flag_index >= 0) output_shape[flag_index] = tensor_size; + std::vector> res{output_shape}; + return res; +} + +std::vector InferDtypeForSqueeze(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + std::vector res{inputs_type[0]}; + return res; +} + +std::vector> InferLayoutForSqueeze(const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! 
Please check again."; + std::vector new_input_layouts = input_layouts; + if (input_shapes[0].size() > 4) { + // alter input layout back + new_input_layouts[0] = "NCHW"; + VLOG(3) << "alter input layout from " << input_layouts[0] << " to " << new_input_layouts[0]; + } + return {new_input_layouts, new_input_layouts}; +} + std::shared_ptr StrategyForSplit(const framework::NodeAttr &attrs, const std::vector &inputs, const std::vector &out_type, @@ -1892,6 +2000,19 @@ CINN_REGISTER_HELPER(transform_ops) { .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) .set_support_level(4); + CINN_REGISTER_OP(squeeze) + .describe("This operator is used to squeeze input tensor X.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSqueeze) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSqueeze)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSqueeze)) +#ifndef CINN_WITH_CUDA + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForSqueeze)) +#endif + .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) + .set_support_level(4); + CINN_REGISTER_OP(split) .describe("This operator is used to split tensors X to 'sections' sub-tensor on specified axis.") .set_num_inputs(1) diff --git a/cinn/hlir/pe/transform.cc b/cinn/hlir/pe/transform.cc index 0d74d0641e..59c9ddf7fa 100755 --- a/cinn/hlir/pe/transform.cc +++ b/cinn/hlir/pe/transform.cc @@ -150,6 +150,49 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const return res; } +ir::Tensor Squeeze(const ir::Tensor& A, + poly::StageMap stages, + const std::string& name) { + std::vector new_expr_shape; + std::vector A_expr_shape = A->shape; + for (auto& i : A_expr_shape) { + CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; + if(i != Expr(1)){ + new_expr_shape.push_back(i); + } + } + auto out = Identity(A->Reshape(new_expr_shape, stages), name).front(); + return out; +} + +ir::Tensor Squeeze(const ir::Tensor& A, + const std::string& name) { + std::vector new_expr_shape; + std::vector A_expr_shape = A->shape; + for (auto& i : A_expr_shape) { + CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; + if(i != Expr(1)){ + new_expr_shape.push_back(i); + } + } + auto res = Compute( + new_expr_shape, + [=](const std::vector& indice) { + Expr offset = Expr(0); + for (int i = 0; i < indice.size(); i++) { + offset = offset * new_expr_shape[i] + indice[i]; + } + std::vector indice_a; + for (int i = A_expr_shape.size() - 1; i >= 0; i--) { + auto temp = offset % A_expr_shape[i]; + indice_a.insert(indice_a.begin(), common::AutoSimplify(temp)); + offset = (offset - temp) / A_expr_shape[i]; + } + return A(indice_a); + }, + name); + return res; +} std::vector Split(const ir::Tensor& A, int axis, const std::vector>& output_shapes, diff --git a/cinn/hlir/pe/transform.h b/cinn/hlir/pe/transform.h index 47c1cb0342..6206bbfa85 100755 --- a/cinn/hlir/pe/transform.h +++ b/cinn/hlir/pe/transform.h @@ -52,6 +52,10 @@ ir::Tensor Reshape(const ir::Tensor& A, poly::StageMap stages, const std::string& name); +ir::Tensor Squeeze(const ir::Tensor& A, + poly::StageMap stages, + const std::string& name); + ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const std::string& name = UniqName("T_Transform_Matmul_out")); diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index 3ec8362370..0455b46850 100755 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc 
@@ -381,6 +381,7 @@ void BindFrontend(pybind11::module *m) { py::arg("out_shape"), py::arg("broadcast_axes")) .def("reshape", &BaseBuilder::Reshape, py::arg("a"), py::arg("shape")) + .def("squeeze", &BaseBuilder::Squeeze, py::arg("a")) .def("transpose", &BaseBuilder::Transpose, py::arg("a"), py::arg("axis")) .def("slice", &BaseBuilder::Slice, From 0e617213b3f1eb3127e68498c9c96625d8ea54e0 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Thu, 14 Jul 2022 21:06:10 +0800 Subject: [PATCH 2/6] add paddle_to_program --- cinn/frontend/op_mappers/paddle/CMakeLists.txt | 1 + cinn/frontend/paddle_model_to_program.cc | 14 ++++++++++++++ cinn/frontend/paddle_model_to_program.h | 2 ++ cinn/frontend/syntax.cc | 6 ++++++ cinn/frontend/syntax.h | 7 +++++++ 5 files changed, 30 insertions(+) diff --git a/cinn/frontend/op_mappers/paddle/CMakeLists.txt b/cinn/frontend/op_mappers/paddle/CMakeLists.txt index 6a9b6b4a32..2f1bd482d7 100644 --- a/cinn/frontend/op_mappers/paddle/CMakeLists.txt +++ b/cinn/frontend/op_mappers/paddle/CMakeLists.txt @@ -13,4 +13,5 @@ gather_srcs(cinnapi_src SRCS slice.cc dropout.cc transpose.cc + squeeze.cc reshape.cc) diff --git a/cinn/frontend/paddle_model_to_program.cc b/cinn/frontend/paddle_model_to_program.cc index eb7f7df8a3..c875423583 100644 --- a/cinn/frontend/paddle_model_to_program.cc +++ b/cinn/frontend/paddle_model_to_program.cc @@ -170,6 +170,20 @@ void PaddleModelToProgram::AddOpMapper_reshape2() { }; } +void PaddleModelToProgram::AddOpMapper_squeeze2() { + op_mappers_["squeeze2"] = [&](const paddle::cpp::OpDesc& op_desc) { + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + auto x = GetVar(utils::TransValidVarName(x_name)); + VLOG(4) << "x shape: " << utils::Join(x->shape, ","); + auto out = program_->squeeze(x); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + AddVar(utils::TransValidVarName(out_name), out); + var_model_to_program_map_[out_name] = out->id; + }; +} + void PaddleModelToProgram::AddOpMapper_concat() { op_mappers_["concat"] = [&](const paddle::cpp::OpDesc& op_desc) { int input_size = op_desc.Input("X").size(); diff --git a/cinn/frontend/paddle_model_to_program.h b/cinn/frontend/paddle_model_to_program.h index c84636699d..7d5e63fedb 100644 --- a/cinn/frontend/paddle_model_to_program.h +++ b/cinn/frontend/paddle_model_to_program.h @@ -63,6 +63,7 @@ class PaddleModelToProgram { AddOpMapper_dropout_infer(); AddOpMapper_matmul(); AddOpMapper_reshape2(); + AddOpMapper_squeeze2(); AddOpMapper_concat(); AddOpMapper_assign(); AddOpMapper_fill_constant(); @@ -96,6 +97,7 @@ class PaddleModelToProgram { void AddOpMapper_dropout_infer(); void AddOpMapper_matmul(); void AddOpMapper_reshape2(); + void AddOpMapper_squeeze2(); void AddOpMapper_concat(); void AddOpMapper_assign(); void AddOpMapper_fill_constant(); diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc index ff1f67698f..9d836e3ca1 100755 --- a/cinn/frontend/syntax.cc +++ b/cinn/frontend/syntax.cc @@ -457,6 +457,12 @@ Variable Program::reshape(const Variable& a, const std::vector& shape) { return instr.GetOutput(0); } +Variable Program::squeeze(const Variable& a) { + Instruction instr("squeeze", {a}); + AppendInstruction(instr); + return instr.GetOutput(0); +} + Variable Program::concat(const std::vector& input_vars, int axis) { Instruction instr("concat", input_vars); instr.SetAttr("axis", axis); diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index 05c78fed27..71400c98e6 100755 --- 
a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -271,6 +271,13 @@ struct Program { */ Variable reshape(const Variable& a, const std::vector& shape); + /** + * Squeeze a tensor. + * @param a The input tensor. + * @return The squeezed output tensor. + */ + Variable squeeze(const Variable& a); + /** * Concat tensors. * @param input_vars The input tensors. From c4c92821e3103093d67c986512112e38d85a31a7 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Thu, 14 Jul 2022 21:07:25 +0800 Subject: [PATCH 3/6] add axis --- cinn/frontend/base_builder.cc | 3 ++- cinn/frontend/base_builder.h | 2 +- cinn/hlir/pe/transform.cc | 24 ++++++++++++++++++++---- cinn/hlir/pe/transform.h | 1 + cinn/pybind/frontend.cc | 2 +- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/cinn/frontend/base_builder.cc b/cinn/frontend/base_builder.cc index 689fac8438..64500fa401 100644 --- a/cinn/frontend/base_builder.cc +++ b/cinn/frontend/base_builder.cc @@ -224,8 +224,9 @@ Variable BaseBuilder::Reshape(const Variable& operand, const std::vector& s return instr.GetOutput(0); } -Variable BaseBuilder::Squeeze(const Variable& operand) { +Variable BaseBuilder::Squeeze(const Variable& operand, const std::vector& axis) { Instruction instr("squeeze", {operand}); + instr.SetAttr("axis", axis); InferShape(instr); AppendInstruction(instr); return instr.GetOutput(0); diff --git a/cinn/frontend/base_builder.h b/cinn/frontend/base_builder.h index a6f11eb70d..7865282822 100644 --- a/cinn/frontend/base_builder.h +++ b/cinn/frontend/base_builder.h @@ -88,7 +88,7 @@ class BaseBuilder { Variable Reshape(const Variable& operand, const std::vector& shape); - Variable Squeeze(const Variable& operand); + Variable Squeeze(const Variable& operand, const std::vector& axis); Variable Transpose(const Variable& operand, const std::vector& axis); diff --git a/cinn/hlir/pe/transform.cc b/cinn/hlir/pe/transform.cc index 59c9ddf7fa..e04159277d 100755 --- a/cinn/hlir/pe/transform.cc +++ b/cinn/hlir/pe/transform.cc @@ -151,16 +151,32 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const } ir::Tensor Squeeze(const ir::Tensor& A, + const std::vector& axis, poly::StageMap stages, const std::string& name) { std::vector new_expr_shape; std::vector A_expr_shape = A->shape; - for (auto& i : A_expr_shape) { - CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; - if(i != Expr(1)){ - new_expr_shape.push_back(i); + CHECK_EQ(axis.size(), A_expr_shape.size()); + if (axis){ + for (auto& a : axis) { + CHECK_EQ(A_expr_shape[a], Expr(1)); + A_expr_shape[a] = Expr(0); + } + for (auto& i : A_expr_shape) { + CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; + if(i != Expr(0)){ + new_expr_shape.push_back(i); + } + } + }else{ + for (auto& i : A_expr_shape) { + CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; + if(i != Expr(1)){ + new_expr_shape.push_back(i); + } } } + auto out = Identity(A->Reshape(new_expr_shape, stages), name).front(); return out; } diff --git a/cinn/hlir/pe/transform.h b/cinn/hlir/pe/transform.h index 6206bbfa85..198763ea45 100755 --- a/cinn/hlir/pe/transform.h +++ b/cinn/hlir/pe/transform.h @@ -53,6 +53,7 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::string& name); ir::Tensor Squeeze(const ir::Tensor& A, + const std::vector& axis, poly::StageMap stages, const std::string& name); diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index 0455b46850..89d2ad5d54 100755 --- a/cinn/pybind/frontend.cc +++ 
b/cinn/pybind/frontend.cc @@ -381,7 +381,7 @@ void BindFrontend(pybind11::module *m) { py::arg("out_shape"), py::arg("broadcast_axes")) .def("reshape", &BaseBuilder::Reshape, py::arg("a"), py::arg("shape")) - .def("squeeze", &BaseBuilder::Squeeze, py::arg("a")) + .def("squeeze", &BaseBuilder::Squeeze, py::arg("a"), py::arg("axis")) .def("transpose", &BaseBuilder::Transpose, py::arg("a"), py::arg("axis")) .def("slice", &BaseBuilder::Slice, From 6c6ee7959567f350887b6f6c0528b315620de5c6 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Wed, 27 Jul 2022 16:17:15 +0800 Subject: [PATCH 4/6] add axes --- cinn/hlir/op/transform.cc | 39 +++++++++++++++++++++++++++++++-------- cinn/hlir/pe/transform.cc | 9 +++++---- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc index 1361fdbab0..40f30b4190 100755 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -419,6 +419,9 @@ std::shared_ptr StrategyForSqueeze(const framework::NodeAttr &attrs, const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { + CHECK(attrs.attr_store.count("axes")) << "find no attr of axes"; + std::vector axes = absl::get>(attrs.attr_store.at("axes")); + framework::CINNCompute squeeze_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input arguments of Squeeze compute is empty! Please check.\n"; CINNValuePack a = args[0]; @@ -426,14 +429,11 @@ std::shared_ptr StrategyForSqueeze(const framework::NodeAttr &attrs, Expr A = a[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); -// auto attr_store = attrs.attr_store; -// CHECK(attr_store.count("shape")) << "find no attr of shape"; -// std::vector new_shape = absl::get>(attr_store.at("shape")); auto tensor_A = A.as_tensor_ref(); auto stages = CreateStages({tensor_A}); VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - ir::Tensor out = pe::Squeeze(tensor_A, stages, UniqName("Squeeze_out")); + ir::Tensor out = pe::Squeeze(tensor_A, axes, stages, UniqName("Squeeze_out")); std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); @@ -465,14 +465,37 @@ std::shared_ptr StrategyForSqueeze(const framework::NodeAttr &attrs, std::vector> InferShapeForSqueeze(const std::vector> &inputs_shape, const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! 
Please check again."; + std::vector axes; + for (auto &iter : attrs) { + if (iter.first == "axes") { + axes = absl::get>(iter.second); + break; + } + } + std::vector output_shape; int tensor_size = 1; - for (auto s : inputs_shape[0]) { - if (s != 1) { - output_shape.push_back(s); + if (axes.size()!=0){ + std::vector temp_shape = inputs_shape[0]; + for (auto& a : axes) { + CHECK(a& new_shape, const } ir::Tensor Squeeze(const ir::Tensor& A, - const std::vector& axis, + const std::vector& axes, poly::StageMap stages, const std::string& name) { std::vector new_expr_shape; std::vector A_expr_shape = A->shape; - CHECK_EQ(axis.size(), A_expr_shape.size()); - if (axis){ - for (auto& a : axis) { + if (axes.size()!=0){ + for (auto& a : axes) { + CHECK(a& axes, const std::string& name) { std::vector new_expr_shape; std::vector A_expr_shape = A->shape; From 5bfa596fe55771f7800b0463c23c94b4a4cc9e48 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Wed, 27 Jul 2022 16:17:55 +0800 Subject: [PATCH 5/6] add axes --- cinn/frontend/syntax.cc | 3 ++- cinn/frontend/syntax.h | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc index 9d836e3ca1..04b8f110ae 100755 --- a/cinn/frontend/syntax.cc +++ b/cinn/frontend/syntax.cc @@ -457,8 +457,9 @@ Variable Program::reshape(const Variable& a, const std::vector& shape) { return instr.GetOutput(0); } -Variable Program::squeeze(const Variable& a) { +Variable Program::squeeze(const Variable& a, const std::vector& axes) { Instruction instr("squeeze", {a}); + instr.SetAttr("axes", axes); AppendInstruction(instr); return instr.GetOutput(0); } diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index 71400c98e6..2bd526804e 100755 --- a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -274,9 +274,10 @@ struct Program { /** * Squeeze a tensor. * @param a The input tensor. + * @param axis The tensor's axis we specified. * @return The squeezed output tensor. */ - Variable squeeze(const Variable& a); + Variable squeeze(const Variable& a, const std::vector& axes); /** * Concat tensors. From a600c1bf3140bc4d47c12ca3ec44a5aecea80c46 Mon Sep 17 00:00:00 2001 From: tczrr1999 <2742392377@qq.com> Date: Wed, 27 Jul 2022 16:19:19 +0800 Subject: [PATCH 6/6] add paddle_model_to_program squeeze --- cinn/frontend/op_mappers/paddle/squeeze.cc | 74 ++++++++++++++++++++++ cinn/frontend/op_mappers/use_op_mappers.h | 1 + cinn/frontend/paddle_model_to_program.cc | 3 +- 3 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 cinn/frontend/op_mappers/paddle/squeeze.cc diff --git a/cinn/frontend/op_mappers/paddle/squeeze.cc b/cinn/frontend/op_mappers/paddle/squeeze.cc new file mode 100644 index 0000000000..4b07b8aa87 --- /dev/null +++ b/cinn/frontend/op_mappers/paddle/squeeze.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/backends/cuda_util.h" +#include "cinn/frontend/op_mapper_registry.h" +#include "cinn/frontend/op_mappers/common_utils.h" + +namespace cinn { +namespace frontend { +namespace paddle_mappers { + +void SqueezeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + auto x = ctx.GetVar(x_name); + + auto axes = utils::GetAttrOrDefault>(op_desc, "axes"); + + VLOG(4) << "x shape: " << cinn::utils::Join(x->shape, ","); + + auto out = ctx.Builder()->Squeeze(x, axes); + + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + ctx.AddVar(out_name, out); + ctx.AddVarModelToProgram(out_name, out->id); +} + +void SqueezeGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { + auto get_input_var = [&op_desc, &ctx](const std::string& op_name) { + CHECK_EQ(op_desc.Input(op_name).size(), 1UL); + auto var_name = op_desc.Input(op_name).front(); + return ctx.GetVar(var_name); + }; + + auto get_output_name = [&op_desc](const std::string& op_name) { + CHECK_EQ(op_desc.Output(op_name).size(), 1UL); + return op_desc.Output(op_name).front(); + }; + + auto dout = get_input_var(paddle::GradVarName("Out")); + VLOG(4) << "dout shape: " << cinn::utils::Join(dout->shape, ","); + + auto x = get_input_var("X"); + VLOG(4) << "x shape: " << cinn::utils::Join(x->shape, ","); + + auto out = ctx.Builder()->Reshape(dout, x->shape); + + auto out_name = get_output_name(paddle::GradVarName("X")); + ctx.AddVar(out_name, out); + ctx.AddVarModelToProgram(out_name, out->id); +} + +} // namespace paddle_mappers +} // namespace frontend +} // namespace cinn + +CINN_REGISTER_HELPER(paddle_squeeze) { + CINN_REGISTER_OP_MAPPER(squeeze, cinn::frontend::paddle_mappers::SqueezeOpMapper) + + CINN_REGISTER_OP_MAPPER(squeeze_grad, cinn::frontend::paddle_mappers::SqueezeGradOpMapper) + return true; +} diff --git a/cinn/frontend/op_mappers/use_op_mappers.h b/cinn/frontend/op_mappers/use_op_mappers.h index aa89807321..a968146cfc 100644 --- a/cinn/frontend/op_mappers/use_op_mappers.h +++ b/cinn/frontend/op_mappers/use_op_mappers.h @@ -30,6 +30,7 @@ CINN_USE_REGISTER(paddle_pool2d) CINN_USE_REGISTER(paddle_conv2d) CINN_USE_REGISTER(paddle_transpose) CINN_USE_REGISTER(paddle_reshape) +CINN_USE_REGISTER(paddle_squeeze) CINN_USE_REGISTER(science_broadcast) CINN_USE_REGISTER(science_transform) diff --git a/cinn/frontend/paddle_model_to_program.cc b/cinn/frontend/paddle_model_to_program.cc index c875423583..f06ecb2012 100644 --- a/cinn/frontend/paddle_model_to_program.cc +++ b/cinn/frontend/paddle_model_to_program.cc @@ -175,8 +175,9 @@ void PaddleModelToProgram::AddOpMapper_squeeze2() { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); auto x = GetVar(utils::TransValidVarName(x_name)); + std::vector axes = op_desc.GetAttr>("axes"); VLOG(4) << "x shape: " << utils::Join(x->shape, ","); - auto out = program_->squeeze(x); + auto out = program_->squeeze(x, axes); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); AddVar(utils::TransValidVarName(out_name), out);
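
For reference, a minimal usage sketch of the API this series adds. It only exercises the call shapes visible in the diffs above (BaseBuilder::Squeeze with an axis list, and the Program::squeeze entry point from syntax.cc); the wrapper functions, the input variable `x`, and its shape are illustrative assumptions, not part of the patch.

    // Sketch only: call shapes of the squeeze APIs added in this series.
    // `x` is assumed to be an existing frontend Variable of shape [1, 64, 1, 32];
    // how it is created (builder input helpers, placeholders, ...) is outside
    // the scope of these patches.
    #include "cinn/frontend/base_builder.h"
    #include "cinn/frontend/syntax.h"

    namespace cinn {
    namespace frontend {

    Variable BuilderSqueezeSketch(BaseBuilder& builder, const Variable& x) {
      // Remove the two size-1 dimensions; the result should have shape [64, 32].
      // Per the pe::Squeeze implementation above, an empty axis list would
      // squeeze every size-1 dimension instead.
      return builder.Squeeze(x, /*axis=*/{0, 2});
    }

    Variable ProgramSqueezeSketch(Program& program, const Variable& x) {
      // The Program-level entry point added in syntax.cc mirrors the builder call.
      return program.squeeze(x, /*axes=*/{0, 2});
    }

    }  // namespace frontend
    }  // namespace cinn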