From 37f708c55df73d132bd69464ab58dd432f5ad1eb Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Fri, 2 Dec 2022 06:26:11 +0000
Subject: [PATCH 1/4] add support_tensor for code_gen to static graph

---
 .../fluid/operators/generator/CMakeLists.txt  | 17 +++--
 .../fluid/operators/generator/generate_op.py  | 35 +++++++++
 .../generator/templates/operator_utils.c.j2   |  4 +
 paddle/fluid/operators/multinomial_op.cc      | 73 -------------------
 paddle/phi/api/yaml/legacy_ops.yaml           | 10 +--
 paddle/phi/api/yaml/op_compat.yaml            |  7 ++
 paddle/phi/api/yaml/ops.yaml                  |  9 +++
 7 files changed, 67 insertions(+), 88 deletions(-)
 delete mode 100644 paddle/fluid/operators/multinomial_op.cc

diff --git a/paddle/fluid/operators/generator/CMakeLists.txt b/paddle/fluid/operators/generator/CMakeLists.txt
index 7ebc1300345ad..39d90aab5af19 100644
--- a/paddle/fluid/operators/generator/CMakeLists.txt
+++ b/paddle/fluid/operators/generator/CMakeLists.txt
@@ -108,18 +108,23 @@ execute_process(
     --op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
     --output_op_path "${generated_op_path}.tmp"
     --output_arg_map_path "${generated_argument_mapping_path}.tmp"
+  RESULT_VARIABLE _result)
+if(${_result})
+  message(FATAL_ERROR "operator codegen failed, exiting.")
+endif()
+
+execute_process(
+  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
   COMMAND
     ${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
    ./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
    ./parsed_ops/sparse_backward.parsed.yaml --output_op_path
    "${generated_sparse_ops_path}.tmp" --output_arg_map_path
    "${generated_sparse_argument_mapping_path}.tmp"
-  RESULT_VARIABLE _results)
-foreach(_result in ${_results})
-  if(${_result})
-    message(FATAL_ERROR "operator codegen failed, exiting.")
-  endif()
-endforeach()
+  RESULT_VARIABLE _result)
+if(${_result})
+  message(FATAL_ERROR "sparse operator codegen failed, exiting.")
+endif()

 if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
   execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py
index 8746b83a325bd..fb279e7769ce2 100644
--- a/paddle/fluid/operators/generator/generate_op.py
+++ b/paddle/fluid/operators/generator/generate_op.py
@@ -66,6 +66,25 @@ def restruct_io(op):
     return op


+def process_support_tensor(op_item, support_tensor_list):
+    scalar_map = {
+        'Scalar': 'float',
+        'Scalar(float)': 'float',
+        'Scalar(int)': 'int',
+        'Scalar(int64_t)': 'int64_t',
+    }
+    for attr_item in op_item['attrs']:
+        if support_tensor_list and attr_item['name'] in support_tensor_list:
+            attr_type = attr_item['typename']
+            assert (
+                attr_type in scalar_map
+            ), f"{op_item['name']}'s support_tensor in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of Scalar, Scalar(float), Scalar(int) or Scalar(int64_t), but now is {attr_type}."
+ attr_item['typename'] = scalar_map[attr_type] + attr_item['is_support_tensor'] = True + else: + attr_item['is_support_tensor'] = False + + # replace name of op and params for OpMaker def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict): def get_op_and_op_name(op_item): @@ -91,12 +110,22 @@ def update_op_attr_name(attrs, attrs_alias_map): if new_op_name != op_name: forward_op_item['op_name'] = op_name + support_tensor_list = None + + if 'support_tensor' in op_args: + support_tensor_list = [ + item.strip() for item in op_args['support_tensor'].split(",") + ] + process_support_tensor(forward_op_item, support_tensor_list) + if 'backward' in op_args and has_backward: backward_op_list = op_args['backward'].split(',') _, bw_op_name = get_op_and_op_name(backward_op_list[0]) forward_op_item['backward'] = bw_op_name backward_op_item['op_name'] = bw_op_name + process_support_tensor(backward_op_item, support_tensor_list) + # for double grad if len(backward_op_list) > 1: ( @@ -114,6 +143,8 @@ def update_op_attr_name(attrs, attrs_alias_map): double_grad_item['forward']['attrs'], op_args['attrs'] ) + process_support_tensor(double_grad_item, support_tensor_list) + # for triple grad if len(backward_op_list) > 2: ( @@ -132,6 +163,10 @@ def update_op_attr_name(attrs, attrs_alias_map): op_args['attrs'], ) + process_support_tensor( + triple_grad_item, support_tensor_list + ) + key_set = ['inputs', 'attrs', 'outputs'] args_map = {} for key in key_set: diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 1fef16fc23462..70b593ff16ea2 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -74,6 +74,10 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty {% if "default_value" in attr %} .SetDefault({{process_default_value(attr)}}) {%- endif %} + {% if "is_support_tensor" in attr and attr["is_support_tensor"] %} + + .SupportTensor() + {%- endif %} {%- endmacro %} {# process default value for attributes, some attribute has different types and different default values in op & opmaker #} diff --git a/paddle/fluid/operators/multinomial_op.cc b/paddle/fluid/operators/multinomial_op.cc deleted file mode 100644 index 663e86137309d..0000000000000 --- a/paddle/fluid/operators/multinomial_op.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/
-
-#include
-#include
-#include
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/unary.h"
-
-namespace paddle {
-namespace operators {
-
-class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "A tensor contains probabilities of categories");
-    AddOutput("Out", "The output tensor of multinomial op");
-    AddAttr<int>("num_samples", "number of the generated samples")
-        .SetDefault(1)
-        .SupportTensor();
-    AddAttr<bool>("replacement", "can a category be sampled more than once")
-        .SetDefault(false);
-    AddComment(R"DOC(
-This OP returns a Tensor filled with the sampled categoris according to Multinomial probabilities.
-
-      Out ~ Multinomial(X)
-
-)DOC");
-  }
-};
-
-class MultinomialOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-DECLARE_INFER_SHAPE_FUNCTOR(multinomial,
-                            MultinomialInferShapeFunctor,
-                            PD_INFER_META(phi::MultinomialInferMeta));
-REGISTER_OPERATOR(
-    multinomial,
-    ops::MultinomialOp,
-    ops::MultinomialOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    MultinomialInferShapeFunctor);
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 1bc0fc7f0aa43..93d088d5d9edc 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -287,7 +287,7 @@
   backward : bilinear_tensor_product_grad

 - op : bincount
-  args: (Tensor x, Tensor weights, Scalar minlength)
+  args: (Tensor x, Tensor weights, Scalar(int) minlength = 0)
   output: Tensor(out)
   infer_meta:
     func: BincountInferMeta
@@ -1377,14 +1377,6 @@
   func : multiclass_nms3
   optional : rois_num

-- op : multinomial
-  args : (Tensor x, Scalar num_samples, bool replacement)
-  output : Tensor(out)
-  infer_meta :
-    func : MultinomialInferMeta
-  kernel :
-    func : multinomial
-
 - op : multiplex
   args : (Tensor[] inputs, Tensor index)
   output : Tensor
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 5640ca7eb8b0f..d2c28bc32eabe 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -809,6 +809,13 @@
   extra :
     attrs : [bool use_mkldnn = false]

+- op : multinomial
+  inputs :
+    {x : X}
+  outputs :
+    out : Out
+  support_tensor : num_samples
+
 - op : multiply (elementwise_mul)
   backward : multiply_grad (elementwise_mul_grad)
   extra :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 10b6645c61667..e8653fc79e6e3 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -665,6 +665,15 @@
     func : maxout
   backward : maxout_grad

+- op : multinomial
+  args : (Tensor x, Scalar(int) num_samples = 1, bool replacement = false)
+  output : Tensor(out)
+  infer_meta :
+    func : MultinomialInferMeta
+  kernel :
+    func : multinomial
+    data_type : x
+
 - op : mv
   args : (Tensor x, Tensor vec)
   output : Tensor

From a00fe7b383dd7c751c9027afe75b4f6c9e594cb5 Mon Sep
17 00:00:00 2001 From: zyfncg Date: Tue, 6 Dec 2022 08:19:15 +0000 Subject: [PATCH 2/4] support code-gen for int_array --- paddle/fluid/operators/crop_tensor_op.cc | 320 ------------------ paddle/fluid/operators/generator/filters.py | 35 +- .../fluid/operators/generator/generate_op.py | 105 ++++-- .../operators/generator/generate_sparse_op.py | 6 + .../generator/templates/operator_utils.c.j2 | 61 ++-- paddle/fluid/operators/top_k_v2_op.cc | 140 -------- paddle/phi/api/yaml/backward.yaml | 21 ++ paddle/phi/api/yaml/legacy_backward.yaml | 20 -- paddle/phi/api/yaml/legacy_ops.yaml | 19 -- paddle/phi/api/yaml/op_compat.yaml | 36 ++ paddle/phi/api/yaml/ops.yaml | 20 ++ paddle/phi/ops/compat/crop_tensor_sig.cc | 73 ---- paddle/phi/ops/compat/top_k_sig.cc | 42 --- 13 files changed, 238 insertions(+), 660 deletions(-) delete mode 100644 paddle/fluid/operators/crop_tensor_op.cc delete mode 100644 paddle/fluid/operators/top_k_v2_op.cc delete mode 100644 paddle/phi/ops/compat/crop_tensor_sig.cc delete mode 100644 paddle/phi/ops/compat/top_k_sig.cc diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc deleted file mode 100644 index b74aaf8cb22a2..0000000000000 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ /dev/null @@ -1,320 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/
-
-#include "paddle/fluid/framework/op_registry.h"
-
-// TODO(freeliuzc): Delete old infershape
-// New infershape has already in unary.h and backward.h
-
-namespace paddle {
-namespace operators {
-
-class CropTensorOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensor");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CropTensor");
-    auto x_dim = ctx->GetInputDim("X");
-    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
-    auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
-    if (ctx->HasInputs("ShapeTensor")) {
-      // top prority shape
-      auto inputs_name = ctx->Inputs("ShapeTensor");
-      PADDLE_ENFORCE_GT(
-          inputs_name.size(),
-          0,
-          platform::errors::InvalidArgument(
-              "The number of elements of the input 'ShapeTensor' for "
-              "CropTensor must be greater than zero, "
-              "but the value received is %d.",
-              inputs_name.size()));
-      auto out_dims = std::vector<int64_t>(inputs_name.size(), -1);
-      for (size_t i = 0; i < shape.size(); ++i) {
-        if (shape[i] > 0) {
-          out_dims[i] = static_cast<int64_t>(shape[i]);
-        } else {
-          if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
-            out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
-          }
-        }
-      }
-      ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-
-      return;
-    }
-
-    if (ctx->HasInput("Shape")) {
-      auto shape_dim = ctx->GetInputDim("Shape");
-      PADDLE_ENFORCE_EQ(shape_dim.size(),
-                        1,
-                        platform::errors::InvalidArgument(
-                            "The number of dimensions of the input "
-                            "'Shape' for CropTensor must be 1, "
-                            "but the value received is %d.",
-                            shape_dim.size()));
-      PADDLE_ENFORCE_EQ(shape_dim[0],
-                        x_dim.size(),
-                        platform::errors::InvalidArgument(
-                            "The number of elements (%d) of the input 'Shape' "
-                            "for CropTensor must be equal to the number of"
-                            " dimensions (%d) of the input.",
-                            shape_dim[0],
-                            x_dim.size()));
-      if (ctx->IsRuntime()) {
-        // If true, set the shape of Output(Out) according to Input(Shape) in
-        // CropKernel with ExecutionContext. Also check LoD in
-        // CropKernel.
-        ctx->ShareLoD("X", /*->*/ "Out");
-      } else {
-        auto out_dims = std::vector<int64_t>(shape_dim[0], -1);
-        ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-      }
-      return;
-    }
-    PADDLE_ENFORCE_EQ(
-        int64_t(shape.size()),
-        x_dim.size(),
-        platform::errors::InvalidArgument(
-            "The number of elements (%d) of attribute 'shape' for "
-            "CropTensor must be equal to the number of "
-            "dimensions (%d) of the input.",
-            shape.size(),
-            x_dim.size()));
-    std::vector<int64_t> out_shape(shape.size(), -1);
-    for (size_t i = 0; i < shape.size(); ++i) {
-      if (shape[i] > 0) {
-        out_shape[i] = static_cast<int64_t>(shape[i]);
-      } else {
-        if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
-          out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
-        }
-      }
-    }
-    ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
-  }
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
-  }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name,
-      const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
-        var_name == "Shape" || var_name == "Offsets") {
-      return expected_kernel_type;
-    }
-
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
-};
-
-class CropTensorOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "The input of pad op. "
-             "The input should be a k-D tensor(k > 0 and k < 7).");
-    AddInput("Shape",
-             "The input used to describe shape of output, which is a "
-             "1-D vector whose size equals to the rank of input 'X'. The "
-             "elements data type must be int. It has a higher priority than "
-             "the shape attribute")
-        .AsDispensable();
-    AddInput("Offsets",
-             "The input used to describe offsets in runtime, which is a "
-             "1-D vector whose size equals to the rank of input 'X'. The "
-             "elements data type must be int. It has a higher priority than "
-             "the offsets attribute")
-        .AsDispensable();
-    AddInput("ShapeTensor",
-             "(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
-             "use this. The shape of the tensor in vector MUST BE [1]. "
-             "It has the highest priority compare with Input(Shape) and "
-             "attr(shape).")
-        .AsDuplicable()
-        .AsDispensable();
-    AddInput("OffsetsTensor",
-             "(vector<Tensor<int32>>, optional). If provided, crop_tensor will "
-             "use this. The shape of the tensor in vector MUST BE [1]. "
-             "It has the highest priority compare with Input(Offsets) and "
-             "attr(offsets).")
-        .AsDuplicable()
-        .AsDispensable();
-    AddOutput("Out",
-              "The output of crop_tensor op, "
-              "which is of the same dimensions as X.");
-    AddAttr<std::vector<int>>("offsets",
-                              "A list describing offsets to be cropped. "
-                              "The size of offsets list should be the same as "
-                              "the dimension size of input X.")
-        .SetDefault(std::vector<int>());
-    AddAttr<std::vector<int>>("shape",
-                              "A list describing the shape of output. "
-                              "The size of shape list should be the same as "
-                              "the dimension size of input X.")
-        .SetDefault(std::vector<int>());
-    AddComment(R"DOC(
-CropTensor Operator.
-
-Crop input into output, as specified by offsets and shape.
-
-There are three ways to set the offsets:
-1. Input 'OffsetsTensor: It is a tensor list. It should be set as a list that
-   contains tensor variable in python configure script.
-   This way is suitable for dynamic offsets.
-2. Input 'Offsets': It is a variable and can be output of other operators.
-   This way is suitable for dynamic offsets.
-3. Attribute 'offsets': It will be set in python configure script. This way
-   is suitable for fixed offsets.
-
-You CANNOT use these three ways at the same time. An exception will be raised
-if input 'OffsetsTensor' or 'Offset' is configured and meanwhile the attribute 'offsets' is
-not empty.
-
-There are three ways to set shape:
-1. Input 'ShapeTensor': It is a tensor list. It should be set as a list that contains
-   tensor variable in python configure script. This way is suitable
-   for dynamic shape.
-2. Input 'Shape': It is a Variable and can be output of other operators. This way is suitable
-   for dynamic shape.
-2. Attribute 'shape': crop input X into the shape described by a list. The size of shape
-   list should be the same as the dimension size of input X. This way is
-   suitable for fixed shape.
-
-The input should be a k-D tensor(k > 0 and k < 7). As an example:
-
-Case 1:
-Given
-
-    X = [[0, 1, 2, 0, 0]
-         [0, 3, 4, 0, 0]
-         [0, 0, 0, 0, 0]],
-
-and
-
-    offsets = [0, 1],
-
-and
-
-    shape = [2, 2],
-
-we get:
-
-    Out = [[1, 2],
-           [3, 4]].
-
-
-Case 2:
-Given
-
-    X = [[0, 1, 2, 5, 0]
-         [0, 3, 4, 6, 0]
-         [0, 0, 0, 0, 0]],
-
-and offsets is a list that contains tensor variable,
-in runtime offses_var' s value is 1.
-
-    offsets = [0, offsets_var],
-
-and shape is a list that contains tensor variable,
-in runtime dim's value is 2.
-
-    shape = [dim, 3]
-
-we get:
-
-    Out = [[1, 2, 5],
-           [3, 4, 6]].
-)DOC");
-  }
-};
-
-class CropTensorOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensorGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "CropTensorGrad");
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-  }
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
-  }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name,
-      const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    if (var_name == "ShapeTensor" || var_name == "OffsetsTensor" ||
-        var_name == "Shape" || var_name == "Offsets") {
-      return expected_kernel_type;
-    }
-
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
-};
-
-template <typename T>
-class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("crop_tensor_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    if (this->HasInput("OffsetsTensor")) {
-      op->SetInput("OffsetsTensor", this->Input("OffsetsTensor"));
-    }
-    if (this->HasInput("Offsets")) {
-      op->SetInput("Offsets", this->Input("Offsets"));
-    }
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OPERATOR(crop_tensor,
-                  ops::CropTensorOp,
-                  ops::CropTensorOpMaker,
-                  ops::CropTensorGradOpMaker<paddle::framework::OpDesc>,
-                  ops::CropTensorGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
diff --git a/paddle/fluid/operators/generator/filters.py b/paddle/fluid/operators/generator/filters.py
index e39a3122d538e..6a40b936f060f 100644
--- a/paddle/fluid/operators/generator/filters.py
+++ b/paddle/fluid/operators/generator/filters.py
@@ -114,17 +114,44 @@ def to_input_name(s):
     return match.group(2)


+def to_scalar_tensor_name(attr):
+    if 'tensor_name' in attr:
+        return attr['tensor_name']
+    return to_pascal_case(attr['name']) + 'Tensor'
+
+
+def to_int_array_tensor_name(attr):
+    if 'tensor_name' in attr:
+        return attr['tensor_name']
+    return to_pascal_case(attr['name']) + 'Tensor'
+
+
+def to_int_array_tensors_name(attr):
+    if 'tensors_name' in attr:
+        return attr['tensors_name']
+    return to_pascal_case(attr['name']) + 'TensorList'
+
+
 def cartesian_prod_attrs(attrs):
     items = []
     for attr in attrs:
         type_name = attr["typename"]
         name = attr["name"]
         if type_name == "Scalar":
-            items.append((name, "{}Tensor".format(name)))
+            items.append((name, to_scalar_tensor_name(attr)))
         elif type_name == "IntArray":
-            items.append(
-                (name, "{}Tensor".format(name), "{}TensorList".format(name))
-            )
+            if 'tensor_name' not in attr and 'manual_flag' in attr:
+                items.append((name, to_int_array_tensors_name(attr)))
+            elif 'tensors_name' not in attr and 'manual_flag' in attr:
+                items.append((name, to_int_array_tensor_name(attr)))
+            else:
+                items.append(
+                    (
+                        name,
+                        to_int_array_tensor_name(attr),
+                        to_int_array_tensors_name(attr),
+                    )
+                )
         else:
             items.append((name,))
diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py
index fb279e7769ce2..dad1da1b32379 100644
--- a/paddle/fluid/operators/generator/generate_op.py
+++ b/paddle/fluid/operators/generator/generate_op.py
@@ -20,10 +20,13 @@
 from filters import (
     cartesian_prod_mapping,
     to_input_name,
+    to_int_array_tensor_name,
+    to_int_array_tensors_name,
     to_op_attr_type,
     to_opmaker_name,
     to_opmaker_name_cstr,
     to_pascal_case,
+    to_scalar_tensor_name,
 )
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from parse_utils import to_named_dict
@@ -48,6 +51,9 @@
 env.filters["to_op_attr_type"] = to_op_attr_type
 env.filters["to_opmaker_name"] = to_opmaker_name
 env.filters["to_pascal_case"] = to_pascal_case
+env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
+env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
+env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
 env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
@@ -66,23 +72,74 @@ def restruct_io(op):
     return op


-def process_support_tensor(op_item, support_tensor_list):
+def process_scalar(op_item, scalar_configs):
     scalar_map = {
         'Scalar': 'float',
         'Scalar(float)': 'float',
         'Scalar(int)': 'int',
         'Scalar(int64_t)': 'int64_t',
     }
-    for attr_item in op_item['attrs']:
-        if support_tensor_list and attr_item['name'] in support_tensor_list:
-            attr_type = attr_item['typename']
-            assert (
-                attr_type in scalar_map
-            ), f"{op_item['name']}'s support_tensor in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of Scalar, Scalar(float), Scalar(int) or Scalar(int64_t), but now is {attr_type}."
-            attr_item['typename'] = scalar_map[attr_type]
-            attr_item['is_support_tensor'] = True
-        else:
-            attr_item['is_support_tensor'] = False
+    if scalar_configs is not None:
+        for attr_item in op_item['attrs']:
+            if attr_item['name'] in scalar_configs:
+                attr_type = attr_item['typename']
+                assert (
+                    attr_type in scalar_map
+                ), f"{op_item['name']}'s scalar in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of Scalar, Scalar(float), Scalar(int) or Scalar(int64_t), but now is {attr_type}."
+
+                scalar_config = scalar_configs[attr_item['name']]
+                attr_item['is_support_tensor'] = (
+                    True
+                    if 'support_tensor' in scalar_config
+                    and scalar_config['support_tensor']
+                    else False
+                )
+                if attr_item['is_support_tensor']:
+                    attr_item['typename'] = (
+                        scalar_config['data_type']
+                        if 'data_type' in scalar_config
+                        else scalar_map[attr_type]
+                    )
+                else:
+                    attr_item['tensor_name'] = scalar_config['tensor_name']
+
+
+def process_int_array(op_item, int_array_configs):
+    data_type_map = {
+        'int': 'std::vector<int>',
+        'int64_t': 'std::vector<int64_t>',
+    }
+    if int_array_configs is not None:
+        for attr_item in op_item['attrs']:
+            if attr_item['name'] in int_array_configs:
+                attr_type = attr_item['typename']
+                assert (
+                    attr_item['typename'] == "IntArray"
+                ), f"{op_item['name']}'s int_array in op_compat.yaml is error, the data_type of {attr_item['name']} is expected to be one of IntArray, but now is {attr_type}."
+
+                int_array_config = int_array_configs[attr_item['name']]
+                attr_item['is_support_tensor'] = (
+                    True
+                    if 'support_tensor' in int_array_config
+                    and int_array_config['support_tensor']
+                    else False
+                )
+                if attr_item['is_support_tensor']:
+                    attr_item['typename'] = (
+                        int_array_config['data_type']
+                        if 'data_type' in int_array_config
+                        else 'std::vector<int64_t>'
+                    )
+                else:
+                    attr_item['manual_flag'] = True
+                    if 'tensor_name' in int_array_config:
+                        attr_item['tensor_name'] = int_array_config[
+                            'tensor_name'
+                        ]
+                    if 'tensors_name' in int_array_config:
+                        attr_item['tensors_name'] = int_array_config[
+                            'tensors_name'
+                        ]


 # replace name of op and params for OpMaker
@@ -110,13 +167,16 @@ def update_op_attr_name(attrs, attrs_alias_map):
         if new_op_name != op_name:
             forward_op_item['op_name'] = op_name

-        support_tensor_list = None
+        scalar_configs = None
+        int_array_configs = None

-        if 'support_tensor' in op_args:
-            support_tensor_list = [
-                item.strip() for item in op_args['support_tensor'].split(",")
-            ]
-        process_support_tensor(forward_op_item, support_tensor_list)
+        if 'scalar' in op_args:
+            scalar_configs = op_args['scalar']
+        if 'int_array' in op_args:
+            int_array_configs = op_args['int_array']
+
+        process_scalar(forward_op_item, scalar_configs)
+        process_int_array(forward_op_item, int_array_configs)

         if 'backward' in op_args and has_backward:
             backward_op_list = op_args['backward'].split(',')
@@ -124,7 +184,8 @@ def update_op_attr_name(attrs, attrs_alias_map):
             forward_op_item['backward'] = bw_op_name
             backward_op_item['op_name'] = bw_op_name

-            process_support_tensor(backward_op_item, support_tensor_list)
+            process_scalar(backward_op_item, scalar_configs)
+            process_int_array(backward_op_item, int_array_configs)

             # for double grad
             if len(backward_op_list) > 1:
@@ -143,7 +204,8 @@ def update_op_attr_name(attrs, attrs_alias_map):
                     double_grad_item['forward']['attrs'], op_args['attrs']
                 )

-                process_support_tensor(double_grad_item, support_tensor_list)
+                process_scalar(double_grad_item, scalar_configs)
+                process_int_array(double_grad_item, int_array_configs)

                 # for triple grad
                 if len(backward_op_list) > 2:
@@ -163,9 +225,8 @@ def update_op_attr_name(attrs, attrs_alias_map):
                         op_args['attrs'],
                     )

-                    process_support_tensor(
-                        triple_grad_item, support_tensor_list
-                    )
+                    process_scalar(triple_grad_item, scalar_configs)
+                    process_int_array(triple_grad_item, int_array_configs)

         key_set = ['inputs', 'attrs', 'outputs']
         args_map = {}
diff --git a/paddle/fluid/operators/generator/generate_sparse_op.py b/paddle/fluid/operators/generator/generate_sparse_op.py
index 0f04e6130840c..1af4ee493b65d 100644
--- a/paddle/fluid/operators/generator/generate_sparse_op.py
+++ b/paddle/fluid/operators/generator/generate_sparse_op.py
@@ -20,10 +20,13 @@
 from filters import (
     cartesian_prod_mapping,
     to_input_name,
+    to_int_array_tensor_name,
+    to_int_array_tensors_name,
     to_op_attr_type,
     to_opmaker_name,
     to_opmaker_name_cstr,
     to_pascal_case,
+    to_scalar_tensor_name,
 )
 from generate_op import process_invoke_op
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -49,6 +52,9 @@
 env.filters["to_op_attr_type"] = to_op_attr_type
 env.filters["to_opmaker_name"] = to_opmaker_name
 env.filters["to_pascal_case"] = to_pascal_case
+env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
+env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
+env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
 env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
index 70b593ff16ea2..9fa313ffb3391 100644
--- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2
+++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
@@ -61,14 +61,18 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name
 {% set name = attr["name"] %}
 {% set typename = attr["typename"] %}
 {% if typename is scalar %}
-AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
+AddInput("{{attr | to_scalar_tensor_name}}", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
   .AsDispensable();
 {% elif typename == "IntArray" %}{# the type has been renamed #}
-AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.")
+  {% if 'tensor_name' in attr or 'manual_flag' not in attr %}
+AddInput("{{attr | to_int_array_tensor_name}}", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.")
   .AsDispensable();
-AddInput("{{name | to_pascal_case}}TensorList", "attribute {{i}} for {{op_name}} op from list fo 0D integer Tensors.")
+  {% endif %}
+  {% if 'tensors_name' in attr or 'manual_flag' not in attr %}
+AddInput("{{attr | to_int_array_tensors_name}}", "attribute {{i}} for {{op_name}} op from list fo 0D integer Tensors.")
   .AsDuplicable()
   .AsDispensable();
+  {% endif %}
 {% endif %}
 AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_type}}), attribute {{i}} for {{op_name}} op.")
 {% if "default_value" in attr %}
@@ -108,7 +112,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
   paddle::small_vector<const char*> attrs;
 {% for attr in op["attrs"]%}
   {% filter indent(2)%}
-  {{get_an_attr(attr)}};
+  {{get_an_attr(attr)}}
  {% endfilter %}
 {% endfor %}
   {{get_output_list(op["outputs"], kernel_args)}};
@@ -163,7 +167,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
   paddle::small_vector<const char*> attrs;
 {% for attr in op["attrs"]%}
   {% filter indent(2)%}
-  {{get_an_attr(attr)}};
+  {{get_an_attr(attr)}}
  {% endfilter %}
 {% endfor %}
   {{get_output_list(op["outputs"], kernel_args)}};
@@ -206,21 +210,28 @@ paddle::small_vector<const char*> inputs {
 {% set typename = attr["typename"] %}
 {% set name = attr["name"] %}
 {% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #}
-attrs.emplace_back(
-  ctx.HasInput("{{name | to_pascal_case}}")
-    ? "{{name | to_pascal_case}}Tensor"
-    : "{{name}}"
-)
+attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}");
 {%- elif typename == "IntArray" %}
+  {% if 'tensor_name' in attr and 'tensors_name' not in attr %}
 attrs.emplace_back(
-  ctx.HasInput("{{name | to_pascal_case}}Tensor")
-    ? "{{name | to_pascal_case}}Tensor"
-    : ctx.InputSize("{{name | to_pascal_case}}TensorList") > 0
-      ? "{{name | to_pascal_case}}TensorList"
-      : "{{name}}"
-)
+  ctx.HasInput("{{attr | to_int_array_tensor_name}}")
+    ? "{{attr | to_int_array_tensor_name}}"
+    : "{{name}}");
+  {% elif 'tensor_name' not in attr and 'tensors_name' in attr %}
+attrs.emplace_back(
+  ctx.InputSize("{{attr | to_int_array_tensors_name}}") > 0
+    ? "{{attr | to_int_array_tensors_name}}"
+    : "{{name}}");
+  {% else %}
+attrs.emplace_back(
+  ctx.HasInput("{{attr | to_int_array_tensor_name}}")
+    ? "{{attr | to_int_array_tensor_name}}"
+    : ctx.InputSize("{{attr | to_int_array_tensors_name}}") > 0
+      ? "{{attr | to_int_array_tensors_name}}"
+      : "{{name}}");
+  {%- endif %}
 {%- else %}
-attrs.emplace_back("{{name}}")
+attrs.emplace_back("{{name}}");
 {%- endif %}
 {%- endmacro %}
@@ -398,10 +409,20 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker
 {% set attr_name = attr["name"] %}
 {% if attr_name in forward_attr_names %}
   {% if attr["typename"] == "IntArray" %}
-    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
-    grad_op->SetInput("{{attr_name | to_pascal_case}}TensorList", this->Input("{{attr_name | to_pascal_case}}TensorList"));
+    {% if 'tensor_name' in attr or 'manual_flag' not in attr %}
+    if (this->HasInput("{{attr | to_int_array_tensor_name}}")) {
+      grad_op->SetInput("{{attr | to_int_array_tensor_name}}", this->Input("{{attr | to_int_array_tensor_name}}"));
+    }
+    {% endif %}
+    {% if 'tensors_name' in attr or 'manual_flag' not in attr %}
+    if (this->HasInput("{{attr | to_int_array_tensors_name}}")) {
+      grad_op->SetInput("{{attr | to_int_array_tensors_name}}", this->Input("{{attr | to_int_array_tensors_name}}"));
+    }
+    {% endif %}
   {% elif attr["typename"] == "Scalar" %}
-    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
+    if (this->HasInput("{{attr | to_scalar_tensor_name}}")) {
+      grad_op->SetInput("{{attr | to_scalar_tensor_name}}", this->Input("{{attr | to_scalar_tensor_name}}"));
+    }
   {% endif %}
 {% else %}{# maybe something wrong: backward op has more attrs than the forward one#}
   grad_op->SetAttr("{{attr_name}}", {{process_default_value(attr)}});
diff --git a/paddle/fluid/operators/top_k_v2_op.cc b/paddle/fluid/operators/top_k_v2_op.cc
deleted file mode 100644
index fdfda2e4029be..0000000000000
--- a/paddle/fluid/operators/top_k_v2_op.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/infermeta/unary.h"
-
-namespace paddle {
-namespace operators {
-
-class TopkV2Op : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context(),
-        layout_,
-        library_);
-  }
-};
-
-class TopkV2OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "(Tensor) The input of Topk op");
-    AddInput("K",
-             "(Tensor) Number of top elements to look for along "
-             "the last dimension (along each row for matrices).")
-        .AsDispensable();
-    AddOutput("Out", "(Tensor) The output tensor of Topk op");
-    AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
-    AddComment(R"DOC(
-Top K operator
-
-If the input is a vector (1d tensor), this operator finds the k largest
-entries in the vector and outputs their values and indices as vectors.
-Thus values[j] is the j-th largest entry in input, and its index is indices[j].
-
-For matrices, this operator computes the top k entries in each row. )DOC");
-    AddAttr<int>("k",
-                 "(int, default 1) Number of top elements to look for along "
-                 "the tensor).")
-        .SetDefault(1);
-    AddAttr<int>("axis",
-                 "the axis to sort and get the k indices, value."
- "if not set, will get k value in last axis.") - .SetDefault(-1); - AddAttr("largest", - "control flag whether to return largest or smallest") - .SetDefault(true); - AddAttr("sorted", - "control flag whether to return elements in sorted order") - .SetDefault(true); - } -}; - -class TopkV2OpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), - true, - platform::errors::InvalidArgument("Input(X) should be not null")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Indices"), - true, - platform::errors::InvalidArgument("Input(Indices) should be not null")); - PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), - true, - platform::errors::InvalidArgument( - "Grad Input(Out) should be not null")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput(framework::GradVarName("X")), - true, - platform::errors::InvalidArgument("Grad Output(X) should be not null")); - - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); - } -}; - -template -class TopkV2GradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("top_k_v2_grad"); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput("X", this->Input("X")); - op->SetInput("Indices", this->Output("Indices")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(top_k_v2, - TopKInferShapeFunctor, - PD_INFER_META(phi::TopKInferMeta)); -REGISTER_OPERATOR(top_k_v2, - ops::TopkV2Op, - ops::TopkV2OpMaker, - ops::TopkV2GradOpMaker, - ops::TopkV2GradOpMaker, - TopKInferShapeFunctor); - -REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 2d333805b5aa0..be6a12194a4be 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -226,6 +226,16 @@ func : cosh_grad inplace : (out_grad -> x_grad) +- backward_op : crop_grad + forward : crop (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray offsets) + output : Tensor(x_grad) + infer_meta : + func : CropGradInferMeta + kernel : + func : crop_grad + data_type : x + - backward_op : cross_grad forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis) @@ -1071,6 +1081,17 @@ func : thresholded_relu_grad inplace : (out_grad -> x_grad) +- backward_op : topk_grad + forward : topk (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, Scalar k, int axis, bool largest, bool sorted) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : topk_grad + data_type : out_grad + - backward_op : trace_grad forward : trace (Tensor 
x, int offset, int axis1, int axis2) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index b0ce57461685e..51ad68c993811 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -325,16 +325,6 @@ kernel : func : conv3d_transpose_grad -- backward_op : crop_grad - forward : crop (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out) - args : (Tensor x, Tensor out_grad, IntArray offsets) - output : Tensor(x_grad) - infer_meta : - func : CropGradInferMeta - kernel : - func : crop_grad - data_type : x - - backward_op : cross_entropy_with_softmax_grad forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss) args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) @@ -1705,16 +1695,6 @@ no_need_buffer : x backward : tile_double_grad -- backward_op : topk_grad - forward : topk (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) -> Tensor(out), Tensor(indices) - args : (Tensor x, Tensor indices, Tensor out_grad, Scalar k = -1, int axis = -1, bool largest = true, bool sorted = true) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : topk_grad - - backward_op : transpose_double_grad forward : transpose_grad (Tensor grad_out, int[] perm) -> Tensor(grad_x) args : (Tensor grad_x_grad, int[] perm) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 0811dd4b6bff1..1ac85210497d0 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -464,16 +464,6 @@ output : Tensor(out) invoke : copy_to_impl(x, place, blocking) -- op : crop - args : (Tensor x, IntArray shape, IntArray offsets) - output : Tensor(out) - infer_meta : - func : CropInferMeta - kernel : - func : crop - data_type : x - backward : crop_grad - # Part of python API paddle.nn.functional.cross_entropy - op : cross_entropy_with_softmax args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) @@ -2086,15 +2076,6 @@ func : tile backward : tile_grad -- op : topk - args : (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) - output : Tensor(out), Tensor(indices) - infer_meta : - func : TopKInferMeta - kernel : - func : topk - backward : topk_grad - - op : transpose args : (Tensor x, int[] perm) output : Tensor diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d2c28bc32eabe..2064aee605d24 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -252,6 +252,22 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] +- op : crop (crop_tensor) + backward : crop_grad (crop_tensor_grad) + inputs : + x : X + outputs : + out : Out + int_array: + shape : + data_type : int + tensor_name : Shape + tensors_name : ShapeTensor + offsets : + data_type : int + tensor_name : Offsets + tensors_name : OffsetsTensor + - op : cross inputs : {x : X, y : Y} @@ -815,6 +831,10 @@ outputs : out : Out support_tensor : num_samples + scalar : + num_samples : + data_type : int + support_tensor : true - op : multiply (elementwise_mul) backward : 
multiply_grad (elementwise_mul_grad) @@ -1161,6 +1181,22 @@ outputs : out : Out +- op : topk (top_k_v2) + backward : topk_grad (top_k_v2_grad) + inputs : + x : X + outputs : + {out : Out, indices : Indices} + scalar : + k : + data_type : int + tensor_name : K + # int_array : + # shape : + # data_type : int64_t + # tensor_name : Shape + # tensors_name : ShapeList + - op : trace inputs : x : Input diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e8653fc79e6e3..29785c9c161b8 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -179,6 +179,16 @@ func : cosh backward : cosh_grad +- op : crop + args : (Tensor x, IntArray shape = {}, IntArray offsets = {}) + output : Tensor(out) + infer_meta : + func : CropInferMeta + kernel : + func : crop + data_type : x + backward : crop_grad + - op : cross args : (Tensor x, Tensor y, int axis = 9) output : Tensor @@ -886,6 +896,16 @@ func : thresholded_relu backward : thresholded_relu_grad +- op : topk + args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true) + output : Tensor(out), Tensor(indices) + infer_meta : + func : TopKInferMeta + kernel : + func : topk + data_type : x + backward : topk_grad + - op : trace args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1) output : Tensor diff --git a/paddle/phi/ops/compat/crop_tensor_sig.cc b/paddle/phi/ops/compat/crop_tensor_sig.cc deleted file mode 100644 index 8cf4ddab336bb..0000000000000 --- a/paddle/phi/ops/compat/crop_tensor_sig.cc +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature CropTensorOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.InputSize("ShapeTensor") > 0) { - if (ctx.InputSize("OffsetsTensor") > 0) { - return KernelSignature( - "crop", {"X"}, {"ShapeTensor", "OffsetsTensor"}, {"Out"}); - } else if (ctx.HasInput("Offsets")) { - return KernelSignature( - "crop", {"X"}, {"ShapeTensor", "Offsets"}, {"Out"}); - } else { - return KernelSignature( - "crop", {"X"}, {"ShapeTensor", "offsets"}, {"Out"}); - } - } else if (ctx.HasInput("Shape")) { - if (ctx.InputSize("OffsetsTensor") > 0) { - return KernelSignature( - "crop", {"X"}, {"Shape", "OffsetsTensor"}, {"Out"}); - } else if (ctx.HasInput("Offsets")) { - return KernelSignature("crop", {"X"}, {"Shape", "Offsets"}, {"Out"}); - } else { - return KernelSignature("crop", {"X"}, {"Shape", "offsets"}, {"Out"}); - } - } else { - if (ctx.InputSize("OffsetsTensor") > 0) { - return KernelSignature( - "crop", {"X"}, {"shape", "OffsetsTensor"}, {"Out"}); - } else if (ctx.HasInput("Offsets")) { - return KernelSignature("crop", {"X"}, {"shape", "Offsets"}, {"Out"}); - } else { - return KernelSignature("crop", {"X"}, {"shape", "offsets"}, {"Out"}); - } - } -} - -KernelSignature CropTensorGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.InputSize("OffsetsTensor") > 0) { - return KernelSignature( - "crop_grad", {"X", "Out@GRAD"}, {"OffsetsTensor"}, {"X@GRAD"}); - } else if (ctx.HasInput("Offsets")) { - return KernelSignature( - "crop_grad", {"X", "Out@GRAD"}, {"Offsets"}, {"X@GRAD"}); - } else { - return KernelSignature( - "crop_grad", {"X", "Out@GRAD"}, {"offsets"}, {"X@GRAD"}); - } -} - -} // namespace phi - -PD_REGISTER_BASE_KERNEL_NAME(crop_tensor, crop); -PD_REGISTER_BASE_KERNEL_NAME(crop_tensor_grad, crop_grad); - -PD_REGISTER_ARG_MAPPING_FN(crop_tensor, phi::CropTensorOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(crop_tensor_grad, - phi::CropTensorGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/top_k_sig.cc b/paddle/phi/ops/compat/top_k_sig.cc deleted file mode 100644 index 0f3a5c1c0b5f9..0000000000000 --- a/paddle/phi/ops/compat/top_k_sig.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.HasInput("K")) {
-    return KernelSignature(
-        "topk", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"});
-
-  } else {
-    return KernelSignature(
-        "topk", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"});
-  }
-}
-
-KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("topk_grad",
-                         {"X", "Indices", "Out@GRAD"},
-                         {"k", "axis", "largest", "sorted"},
-                         {"X@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, topk);
-PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, topk_grad);
-PD_REGISTER_ARG_MAPPING_FN(top_k_v2, phi::TopkOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(top_k_v2_grad, phi::TopkGradOpArgumentMapping);

From 887986e9606af87215da816b830d626734edccbd Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Tue, 6 Dec 2022 12:03:54 +0000
Subject: [PATCH 3/4] polish code

---
 paddle/phi/api/yaml/op_compat.yaml | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index eb3b08ba53e81..f12da5c97f7d0 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -844,7 +844,6 @@
     {x : X}
   outputs :
     out : Out
-  support_tensor : num_samples
   scalar :
     num_samples :
       data_type : int
@@ -1223,11 +1222,6 @@
     k :
       data_type : int
      tensor_name : K
-  # int_array :
-  #   shape :
-  #     data_type : int64_t
-  #     tensor_name : Shape
-  #     tensors_name : ShapeList

 - op : trace
   inputs :

From b64472ef1b154e1e0cb80f8cda7098372243178f Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Wed, 7 Dec 2022 04:13:33 +0000
Subject: [PATCH 4/4] fix bug of data_type

---
 paddle/fluid/operators/generator/generate_op.py     | 12 +++++++++++-
 paddle/fluid/operators/generator/parse_utils.py     |  3 +++
 .../generator/templates/operator_utils.c.j2         |  5 ++++-
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py
index dad1da1b32379..d7121a2aeb567 100644
--- a/paddle/fluid/operators/generator/generate_op.py
+++ b/paddle/fluid/operators/generator/generate_op.py
@@ -101,6 +101,11 @@ def process_scalar(op_item, scalar_configs):
                         else scalar_map[attr_type]
                     )
                 else:
+                    attr_item['data_type'] = (
+                        scalar_config['data_type']
+                        if 'data_type' in scalar_config
+                        else scalar_map[attr_type]
+                    )
                     attr_item['tensor_name'] = scalar_config['tensor_name']


@@ -126,11 +131,16 @@ def process_int_array(op_item, int_array_configs):
                 )
                 if attr_item['is_support_tensor']:
                     attr_item['typename'] = (
-                        int_array_config['data_type']
+                        data_type_map[int_array_config['data_type']]
                         if 'data_type' in int_array_config
                         else 'std::vector<int64_t>'
                     )
                 else:
+                    attr_item['data_type'] = (
+                        data_type_map[int_array_config['data_type']]
+                        if 'data_type' in int_array_config
+                        else 'std::vector<int64_t>'
+                    )
                     attr_item['manual_flag'] = True
diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py
index fb7940ddfe608..419f6bc3c921b 100644
--- a/paddle/fluid/operators/generator/parse_utils.py
+++ b/paddle/fluid/operators/generator/parse_utils.py
@@ -17,6 +17,7 @@
 from typing import Any, Dict, List, Tuple

 from tests import is_attr, is_input, is_output, is_vec
+from type_mapping import opmaker_attr_types_map


 def to_named_dict(items: List[Dict]) ->
Dict[str, Dict]: @@ -97,6 +98,8 @@ def parse_input_and_attr( ), f"{op_name}: Arguments with default value should not precede those without default value" elif "default_value" in item: met_attr_with_default_value = True + if typename.startswith('Scalar') or typename == 'IntArray': + item['data_type'] = opmaker_attr_types_map[typename] attrs.append(item) else: raise KeyError(f"{op_name}: Invalid argument type {typename}.") diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 9fa313ffb3391..0b49721afcc9e 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -63,6 +63,7 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name {% if typename is scalar %} AddInput("{{attr | to_scalar_tensor_name}}", "attribute {{i}} for {{op_name}} op from 0D Tensor.") .AsDispensable(); +AddAttr<{{attr["data_type"]}}>("{{name}}", "({{attr["data_type"]}}), attribute {{i}} for {{op_name}} op.") {% elif typename == "IntArray" %}{# the type has been renamed #} {% if 'tensor_name' in attr or 'manual_flag' not in attr %} AddInput("{{attr | to_int_array_tensor_name}}", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.") @@ -73,8 +74,10 @@ AddInput("{{attr | to_int_array_tensors_name}}", "attribute {{i}} for {{op_name} .AsDuplicable() .AsDispensable(); {% endif %} - {% endif %} +AddAttr<{{attr["data_type"]}}>("{{name}}", "({{attr["data_type"]}}), attribute {{i}} for {{op_name}} op.") + {% else %} AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_type}}), attribute {{i}} for {{op_name}} op.") + {% endif %} {% if "default_value" in attr %} .SetDefault({{process_default_value(attr)}}) {%- endif %}
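
Note on the resulting configuration surface (an illustrative sketch, not part of the patch series): after these four patches, an operator drops its handwritten OpMaker/ArgumentMapping files and instead describes its Scalar and IntArray attributes in op_compat.yaml. The entry below uses invented names (my_op, my_op_v2, scale, MyScale, sizes) purely to show the shape of such an entry; the real entries added by this series are the multinomial, crop (crop_tensor), and topk (top_k_v2) blocks above.

- op : my_op (my_op_v2)
  inputs :
    x : X
  outputs :
    out : Out
  scalar :
    scale :
      data_type : float
      tensor_name : MyScale
  int_array :
    sizes :
      data_type : int64_t
      support_tensor : true

Under process_scalar/process_int_array in generate_op.py, scale keeps its legacy auxiliary input: the generated OpMaker emits AddInput("MyScale").AsDispensable() plus a float attribute. sizes instead sets is_support_tensor, so data_type_map turns its typename into std::vector<int64_t> and the generated attribute is declared with .SupportTensor() rather than extra Tensor/TensorList inputs.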