diff --git a/paddle/fluid/operators/generator/type_mapping.py b/paddle/fluid/operators/generator/type_mapping.py
index 8d3a4933c3bd0a..56e01a997e61b7 100644
--- a/paddle/fluid/operators/generator/type_mapping.py
+++ b/paddle/fluid/operators/generator/type_mapping.py
@@ -48,7 +48,7 @@
     'int64_t[]': 'const std::vector<int64_t>&',
     'float[]': 'const std::vector<float>&',
     'double[]': 'const std::vector<double>&',
-    'str[]': 'const std::vector<<std::string>&',
+    'str[]': 'const std::vector<std::string>&',
 }
 
 opmaker_attr_types_map = {
@@ -86,8 +86,8 @@
 }
 
 optional_output_type_map = {
-    'Tensor': 'const paddle::optional<Tensor>&',
-    'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
+    'Tensor': 'const paddle::optional<Tensor>',
+    'Tensor[]': 'const paddle::optional<std::vector<Tensor>>',
 }
 
 # ------------------------------ phi attr ------------------------------
diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py
index 9f51351f6ea044..c336dc7b61be18 100644
--- a/paddle/fluid/pir/dialect/op_generator/api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -197,7 +197,9 @@ def _is_optional_input(self, op_info, input_name):
             return True
         return False
 
-    def _is_optinonal_output(self, op_info, output_name):
+    def _is_optional_output(self, op_info, op_name, output_name):
+        if op_name.endswith(('_grad', '_grad_')):
+            return False
         inplace_map = op_info.inplace_map
         input_optional_list = op_info.input_optional_list
         input_name_list = op_info.input_name_list
@@ -271,7 +273,7 @@ def _gen_api_args(
         )
         return (inputs + ', ' + attrs).strip(', ')
 
-    def _gen_ret_type(self, op_info):
+    def _gen_ret_type(self, op_info, op_name):
         name_list = op_info.output_name_list
         type_list = op_info.output_type_list
         intermediate_list = op_info.output_intermediate_list
@@ -285,7 +287,7 @@ def _gen_ret_type(self, op_info):
             ):
                 if intermediate == 'true':
                     continue
-                if self._is_optinonal_output(op_info, name):
+                if self._is_optional_output(op_info, op_name, name):
                     ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
                 else:
                     ret.append(OUTPUT_TYPE_MAP[type])
@@ -293,7 +295,7 @@ def _gen_ret_type(self, op_info):
         elif output_num == 1:
             index = intermediate_list.index('false')
             name = name_list[index]
-            if self._is_optinonal_output(op_info, name):
+            if self._is_optional_output(op_info, op_name, name):
                 return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
             else:
                 return OUTPUT_TYPE_MAP[type_list[index]]
@@ -304,7 +306,7 @@ def _gen_one_declare(
         self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
     ):
         return API_DECLARE_TEMPLATE.format(
-            ret_type=self._gen_ret_type(op_info),
+            ret_type=self._gen_ret_type(op_info, op_name),
             api_name=op_name,
             args=self._gen_api_args(
                 op_info, True, is_mutable_attr, is_vector_mutable_attr
@@ -367,7 +369,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
         ):
             if intermediate == 'true':
                 continue
-            if self._is_optinonal_output(op_info, name):
+            if self._is_optional_output(op_info, op_name, name):
                 if VECTOR_TYPE in type:
                     ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
                         name=name,
@@ -461,7 +463,7 @@ def _gen_compute_op(
             op_inst_name,
         )
 
-    def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
+    def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
         name_list = op_info.output_name_list
         type_list = op_info.output_type_list
         intermediate_list = op_info.output_intermediate_list
@@ -480,7 +482,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
         ):
             if intermediate == 'true':
                 continue
-            if self._is_optinonal_output(op_info, name):
+            if self._is_optional_output(op_info, op_name, name):
                 ret_list.append(f'optional_{name}')
             elif VECTOR_TYPE in type:
                 split_op_name = f'{name}_split_op'
@@ -503,7 +505,7 @@ def _gen_return_result(self, ret_list):
     def _gen_one_impl(
         self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
    ):
-        ret_type = self._gen_ret_type(op_info)
+        ret_type = self._gen_ret_type(op_info, op_name)
         in_combine, in_combine_op_list = self._gen_in_combine(
             op_info, is_mutable_attr, is_vector_mutable_attr
         )
@@ -514,7 +516,7 @@ def _gen_one_impl(
         compute_op += f' (void){op_inst_name};'
 
         out_split, ret_list = self._gen_out_split_and_ret_list(
-            op_info, op_inst_name
+            op_info, op_name, op_inst_name
         )
 
         ret = API_IMPL_TEMPLATE.format(
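Illustrative note (reviewer sketch, not part of the patch): with the `_is_optional_output` guard above, only a forward op whose output is inplace-mapped to an optional input gets an optional result type, while `*_grad` ops always keep plain results. Assuming the usual `pir` output type maps, a hypothetical generated declaration would look like:

    // Hypothetical op "foo" whose output `residual_out` is inplace-mapped
    // to the optional input `residual`:
    std::tuple<pir::OpResult, paddle::optional<pir::OpResult>> foo(
        const pir::Value& x, const paddle::optional<pir::Value>& residual);
    // For "foo_grad", the new early return keeps every result a plain
    // pir::OpResult even when the same inplace mapping exists.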
diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
index e24902c712c1a7..ba78e7d7dc722d 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -144,6 +144,7 @@ def GenBuildInputArgsStr(
         'int': 'phi::DataType::INT32',
         'int64_t': 'phi::DataType::INT64',
         'float': 'phi::DataType::FLOAT32',
+        'double': 'phi::DataType::FLOAT64',
         'std::vector<int64_t>': 'phi::DataType::INT64',
         'const std::vector<int64_t>&': 'phi::DataType::INT64',
         'bool': 'phi::DataType::BOOL',
diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py
index 167b950ee95e7c..d9dd1cc879a23e 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -477,7 +477,7 @@ def parse_mutable_attribute(self):
                 if (self.op_compat_item['op'] == "isclose") or (
                     self.op_compat_item['op'] == "allclose"
                 ):
-                    data_type = "float"
+                    data_type = "double"
                 mutable_attribute_type_list.append(
                     [
                         "paddle::dialect::ScalarAttribute",
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
index eb5acbf2388ea8..be652e48263301 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -100,5 +100,24 @@ pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
       out_grad_combine_op.out(), axis);
   return split_grad_op.result(0);
 }
+
+pir::OpResult ones(const std::vector<int64_t>& shape,
+                   phi::DataType dtype,
+                   const Place& place) {
+  return paddle::dialect::full(shape, 1, dtype, place);
+}
+
+pir::OpResult ones_like(pir::Value x_,
+                        phi::DataType dtype,
+                        const Place& place) {
+  return paddle::dialect::full_like(x_, 1, dtype, place);
+}
+
+pir::OpResult zeros(const std::vector<int64_t>& shape,
+                    phi::DataType dtype,
+                    const Place& place) {
+  return paddle::dialect::full(shape, 0, dtype, place);
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
index fe579295ad5a09..a9df64a905b24d 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -47,5 +47,18 @@ pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
 
 pir::OpResult split_with_num_grad(const std::vector<pir::Value>& out_grad,
                                   const pir::Value& axis);
+
+pir::OpResult ones(const std::vector<int64_t>& shape,
+                   phi::DataType dtype = phi::DataType::FLOAT32,
+                   const Place& place = phi::CPUPlace());
+
+pir::OpResult ones_like(pir::Value x_,
+                        phi::DataType dtype = phi::DataType::UNDEFINED,
+                        const Place& place = {});
+
+pir::OpResult zeros(const std::vector<int64_t>& shape,
+                    phi::DataType dtype = phi::DataType::FLOAT32,
+                    const Place& place = phi::CPUPlace());
+
 }  // namespace dialect
 }  // namespace paddle
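Usage sketch (reviewer illustration, not from this patch): the new wrappers simply defer to `full`/`full_like`, so they assume the same active PIR program/insertion point that those builders require.

    // Hypothetical call sites while building a PIR program:
    auto a = paddle::dialect::ones({2, 3});             // FLOAT32 on CPUPlace by default
    auto b = paddle::dialect::zeros({2, 3}, phi::DataType::FLOAT64);
    auto c = paddle::dialect::ones_like(a);             // UNDEFINED dtype: expected to follow `a`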
diff --git a/paddle/fluid/primitive/backend/manual/manual_backend.h b/paddle/fluid/primitive/backend/manual/manual_backend.h
index 3c9340164ac012..4faabab79f6852 100644
--- a/paddle/fluid/primitive/backend/manual/manual_backend.h
+++ b/paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -24,7 +24,7 @@ namespace primitive {
 namespace backend {
 
 using Tensor = paddle::Tensor;
-using Scalar = paddle::experimental::Scalar;
+using Scalar = phi::Scalar;
 using IntArray = paddle::experimental::IntArray;
 using DataType = phi::DataType;
 
@@ -32,6 +32,13 @@ template <typename T>
 std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
                                const Tensor& out_grad);
 
+template <typename T>
+Tensor embedding_grad(const Tensor& x,
+                      const Tensor& weight,
+                      const Tensor& out_grad,
+                      int64_t padding_idx = -1,
+                      bool sparse = false);
+
 }  // namespace backend
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
index 7b33200336d000..b115e6a0210974 100644
--- a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
+++ b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -45,6 +45,23 @@ std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
   return x_grad;
 }
 
+template <>
+Tensor embedding_grad<LazyTensor>(const Tensor& x,
+                                  const Tensor& weight,
+                                  const Tensor& out_grad,
+                                  int64_t padding_idx,
+                                  bool sparse) {
+  pir::Value x_res = std::static_pointer_cast<LazyTensor>(x.impl())->value();
+  pir::Value weight_res =
+      std::static_pointer_cast<LazyTensor>(weight.impl())->value();
+  pir::Value out_grad_res =
+      std::static_pointer_cast<LazyTensor>(out_grad.impl())->value();
+  auto op_res = paddle::dialect::embedding_grad(
+      x_res, weight_res, out_grad_res, padding_idx, sparse);
+  Tensor out(std::make_shared<LazyTensor>(op_res));
+  return out;
+}
+
 }  // namespace backend
 }  // namespace primitive
 }  // namespace paddle
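Usage sketch (assumed context, not from this patch): the specialization above is selected when a static composite/vjp rule calls the templated backend API with `LazyTensor`.

    // Hypothetical call inside a static (PIR) composite rule, where x,
    // weight and out_grad are LazyTensor-backed paddle::Tensor values:
    using paddle::primitive::LazyTensor;
    paddle::Tensor weight_grad =
        paddle::primitive::backend::embedding_grad<LazyTensor>(
            x, weight, out_grad, /*padding_idx=*/-1, /*sparse=*/false);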
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 88f0209eb59d67..96630e50bd5cf6 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -37,104 +37,13 @@
 # fmt: on
 
-VJPS = [
-    'where_grad',
-    'tril_grad',
-    'triu_grad',
-    'tile_grad',
-    'tanh_grad',
-    'mean_grad',
-    'add_grad',
-    'divide_grad',
-    'sum_grad',
-    'concat_grad',
-    'split_grad',
-    'split_with_num_grad',
-    'gelu_grad',
-    'softmax_grad',
-    'silu_grad',
-    'multiply_grad',
-    'subtract_grad',
-    'erf_grad',
-    'expand_grad',
-    'exp_grad',
-    'expm1_grad',
-    'elementwise_pow_grad',
-    'fused_softmax_mask_upper_triangle_grad',
-    'matmul_grad',
-    'pow_grad',
-    'rsqrt_grad',
-    'slice_grad',
-    'transpose_grad',
-    'square_grad',
-    'dropout_grad',
-    'cast_grad',
-    'slice_double_grad',
-    'layer_norm_grad',
-    'embedding_grad',
-    'scale_grad',
-    'gather_nd_grad',
-    'stack_grad',
-    'squeeze_grad',
-    'unsqueeze_grad',
-    'poisson_grad',
-    'gumbel_softmax_grad',
-    'conv2d_grad',
-    'depthwise_conv2d_grad',
-    'sqrt_grad',
-    'flatten_grad',
-    'relu_grad',
-    'abs_grad',
-    'log_grad',
-    'clip_grad',
-    'ceil_grad',
-    'frobenius_norm_grad',
-    'p_norm_grad',
-    'maximum_grad',
-    'argsort_grad',
-    'min_grad',
-    'batch_norm_grad',
-    'max_pool2d_with_index_grad',
-    'pool2d_grad',
-    'minimum_grad',
-    'prod_grad',
-    'round_grad',
-    'sin_grad',
-    'cos_grad',
-    'dot_grad',
-    'floor_grad',
-    'topk_grad',
-    'square_grad',
-    'gather_grad',
-    'label_smooth_grad',
-    'cross_entropy_with_softmax_grad',
-    'mean_all_grad',
-    'cumsum_grad',
-    'linear_interp_grad',
-    'bilinear_interp_grad',
-    'trilinear_interp_grad',
-    'nearest_interp_grad',
-    'bicubic_interp_grad',
-    'assign_grad',
-    'assign_out__grad',
-    'real_grad',
-    'flip_grad',
-    'softmax_grad',
-    'expand_grad',
-    'conv2d_transpose_grad',
-    'depthwise_conv2d_transpose_grad',
-    'sigmoid_grad',
-    'pad_grad',
-    'pad3d_grad',
-    'einsum_grad',
-    'leaky_relu_grad',
-    'log10_grad',
-    'conv3d_grad',
-    'solve_grad',
-    'diag_grad',
-    'trace_grad',
+VJPS_BLACK_LIST = [
+    'reshape_grad',
+    'add_n_grad',
 ]
 
+BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"]
+
 
 PRIM_VJP = [
     'divide_grad',
@@ -156,147 +65,6 @@
 ]  # custom vjp list of composite op
 VJP_COMPS = PRIM_VJP + CUSTOM_VJP
 
-BACKENDS = [
-    'where_grad',
-    'tril_grad',
-    'triu_grad',
-    'tile_grad',
-    'add_n',
-    'mean',
-    'sum',
-    'divide',
-    'full',
-    'tanh',
-    'tanh_grad',
-    'mean_grad',
-    'concat',
-    'add',
-    'multiply',
-    'elementwise_pow',
-    'scale',
-    'reshape',
-    'expand',
-    'tile',
-    'add_grad',
-    'divide_grad',
-    'sum_grad',
-    'concat_grad',
-    'split_grad',
-    'split_with_num_grad',
-    'gelu_grad',
-    'softmax_grad',
-    'silu_grad',
-    'multiply_grad',
-    'subtract_grad',
-    'erf_grad',
-    'expand_grad',
-    'exp_grad',
-    'expm1_grad',
-    'multiply',
-    'exp',
-    'erf',
-    'cast',
-    'elementwise_pow_grad',
-    'fused_softmax_mask_upper_triangle_grad',
-    'matmul_grad',
-    'pow_grad',
-    'reshape_grad',
-    'rsqrt_grad',
-    'slice_grad',
-    'transpose_grad',
-    'subtract',
-    'assign',
-    'equal',
-    'greater_equal',
-    'greater_than',
-    'less_equal',
-    'less_than',
-    'matmul',
-    'max',
-    'maximum',
-    'minimum',
-    'not_equal',
-    'abs',
-    'bitwise_and',
-    'bitwise_not',
-    'bitwise_or',
-    'bitwise_xor',
-    'floor',
-    'gather_nd',
-    'log',
-    'roll',
-    'scatter',
-    'scatter_nd_add',
-    'square_grad',
-    'dropout_grad',
-    'slice',
-    'layer_norm_grad',
-    'embedding_grad',
-    'sqrt',
-    'uniform',
-    'poisson_grad',
-    'gumbel_softmax_grad',
-    'split',
-    'transpose',
-    'gather_nd_grad',
-    'stack_grad',
-    'squeeze_grad',
-    'unsqueeze_grad',
-    'conv2d_grad',
-    'depthwise_conv2d_grad',
-    'sqrt_grad',
-    'flatten_grad',
-    'relu_grad',
-    'abs_grad',
-    'log_grad',
-    'clip_grad',
-    'ceil_grad',
-    'frobenius_norm_grad',
-    'p_norm_grad',
-    'maximum_grad',
-    'argsort_grad',
-    'min_grad',
-    'batch_norm_grad',
-    'max_pool2d_with_index_grad',
-    'pool2d_grad',
-    'minimum_grad',
-    'prod_grad',
-    'round_grad',
-    'sin_grad',
-    'cos_grad',
-    'dot_grad',
-    'floor_grad',
-    'topk_grad',
-    'square_grad',
-    'gather_grad',
-    'label_smooth_grad',
-    'cross_entropy_with_softmax_grad',
-    'mean_all_grad',
-    'cumsum_grad',
-    'linear_interp_grad',
-    'bilinear_interp_grad',
-    'trilinear_interp_grad',
-    'nearest_interp_grad',
-    'bicubic_interp_grad',
-    'assign_out__grad',
-    'real_grad',
-    'softmax_grad',
-    'conv2d_transpose_grad',
-    'depthwise_conv2d_transpose_grad',
-    'sigmoid_grad',
-    'pad_grad',
-    'pad3d_grad',
-    'einsum_grad',
-    'leaky_relu_grad',
-    'log10_grad',
-    'conv3d_grad',
-    'solve_grad',
-    'diag_grad',
-    'trace_grad',
-    'flip',
-    'sign',
-]
-
 
 def load(path: pathlib.Path):
     """Load config from yaml file.
@@ -347,6 +115,7 @@ def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs):
             'datatype': op_gen_tests.is_datatype,
             'exist_mutable_attribute': op_gen_tests.exist_mutable_attribute,
             'mutable_attribute': op_gen_tests.is_mutable_attribute,
+            'only_composite_op': op_gen_tests.is_only_composite_op,
         }
     )
     for tpl in env.list_templates(
@@ -497,6 +266,22 @@ def process_backward_invoke_info(apis):
             api['invoke']['args'] = ', '.join(args)
 
 
+def process_optional_output_info(apis):
+    for api in apis:
+        if not api['is_fwd']:
+            continue
+        inputs_dict = to_named_dict(api['inputs'])
+        for output in api['outputs']:
+            if (
+                api.get("inplace", None)
+                and output['name'] in api['inplace']
+                and inputs_dict[api['inplace'][output['name']]]['optional']
+            ):
+                output['optional'] = True
+            else:
+                output['optional'] = False
+
+
 def gen(
     prim_path: pathlib.Path,
     fwd_path: pathlib.Path,
@@ -545,12 +330,13 @@ def gen(
     apis = extend_compat_info(apis, compats)
     apis = apis + get_inplace_api(apis)
     process_backward_invoke_info(apis)
+    process_optional_output_info(apis)
 
     render(
         templates_dir,
         destination_dir,
         apis=apis,
-        backend_white_list=BACKENDS,
-        vjp_white_list=VJPS,
+        backend_black_list=BACKENDS_BLACK_LIST,
+        vjp_black_list=VJPS_BLACK_LIST,
         vjp_comp_white_list=VJP_COMPS,
     )
diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2
index 25443f52fe8af7..863bbb7de633fb 100644
--- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2
+++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2
@@ -15,24 +15,22 @@ namespace primitive {
 namespace backend {
 
 using Tensor = paddle::Tensor;
-using Scalar = paddle::experimental::Scalar;
+using Scalar = phi::Scalar;
 using IntArray = paddle::experimental::IntArray;
 using DataType = phi::DataType;
 
 {% for api in apis %}
-  {%- if api.name in backend_white_list -%}
-    {% set inplace_map = {} %}
-    {% if 'inplace' in api and api.inplace != None %}
-      {% for source, target in api.inplace.items() %}
-        {% do inplace_map.update({source: target}) %}
-      {% endfor %}
-    {% endif %}
-    {% if api.attrs is exist_mutable_attribute %}
-{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, True, True)}};
+  {%- if api is only_composite_op -%}{#- render nothing -#}
+  {%- elif api.name not in backend_black_list -%}
+    {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd -%}
+      {% if api.attrs is exist_mutable_attribute %}
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, True, True)}};
 
-    {% endif %}
-{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, False, True)}};
+      {% endif %}
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, False, True)}};
+    {% endif %}
+  {% else %}{#- render nothing -#}
   {% endif %}
 {% endfor %}
 } // namespace backend
diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2
index 34e427f0c2e03b..3b9a94993eaa4e 100644
--- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2
+++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2
@@ -16,9 +16,9 @@ namespace backend {
 {{common.sequence('', '', ', ', attrs)}}
 {%- endmacro -%}
 
-{%- macro sig(name, inputs, attrs, outputs, inplace_map) -%}
+{%- macro sig(name, inputs, attrs, outputs) -%}
 template <>
-{{common.ret(outputs, inplace_map)}} {{name}}({{common.params(inputs, attrs, False)}})
+{{common.ret(outputs)}} {{name}}({{common.params(inputs, attrs, False)}})
 {%- endmacro -%}
 
 {% macro body(name, inputs, attrs, outputs) %}
@@ -27,21 +27,15 @@ template <>
 {%- set attr_names = [] -%}
 {%- for i in attrs -%} {%- do attr_names.append(i.name) -%} {%-endfor-%}
 {% filter indent(2, True) %}
-VLOG(4) << "Eager Prim API {name}_ad_func call";
+VLOG(4) << "Eager Prim API {{name}}_ad_func call";
 return ::{{name}}_ad_func({{common.args(input_names, attr_names)}});
 {% endfilter %}
 {% endmacro %}
 
 {% for api in apis %}
-  {%- if api.is_prim and api.name in backend_white_list -%}
-    {% set inplace_map = {} %}
-    {% if 'inplace' in api and api.inplace != None %}
-      {% for source, target in api.inplace.items() %}
-        {% do inplace_map.update({source: target}) %}
-      {% endfor %}
-    {% endif %}
-{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate, inplace_map)}} {
+  {%- if api.is_prim and api.name not in backend_black_list and api.name[-1] != '_' -%}
+{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}} {
 {{body(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}}
 }
diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2
index 152cd241ad8333..97b150b0d2dfcc 100644
--- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2
+++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2
@@ -12,9 +12,9 @@ namespace backend {
 
 using LazyTensor = paddle::primitive::LazyTensor;
 
-{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False) -%}
+{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) -%}
 template <>
-{{common.ret(outputs, inplace_map)}} {{name}}({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
+{{common.ret(outputs)}} {{name}}({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
 {%- endmacro -%}
 
 {%- macro prepare_ir_api_inputs(inputs)-%}
@@ -48,13 +48,13 @@ if({{input.name}}) {
 
 {%- macro get_static_backend_outputs(outputs)-%}
   {%- if outputs|length == 1 -%}
-    {%- if outputs[0].typename == 'Tensor' and not outputs[0].optional-%}
+    {%- if outputs[0].typename == 'Tensor' and not outputs[0].optional -%}
 Tensor {{outputs[0].name}}(std::make_shared<LazyTensor>(op_res));
 return {{outputs[0].name}};
     {%- elif outputs[0].typename == 'Tensor' and outputs[0].optional -%}
 paddle::optional<Tensor> {{outputs[0].name}};
 if(op_res){
-  {{outputs[0].name}} = paddle::make_optional<Tensor>(Tensor(std::make_shared<LazyTensor>(op_res.get()));
+  {{outputs[0].name}} = paddle::make_optional<Tensor>(Tensor(std::make_shared<LazyTensor>(op_res.get())));
 }
 return {{outputs[0].name}};
     {%- elif outputs[0].typename == 'Tensor[]' and not outputs[0].optional -%}
@@ -80,7 +80,7 @@ return {{outputs[0].name}};
 auto op_res_{{i}} = std::get<{{i}}>(op_res);
     {% if outputs[i].typename == 'Tensor' and not outputs[i].optional %}
 Tensor {{outputs[i].name}}(std::make_shared<LazyTensor>(op_res_{{i}}));
-   {% elif outputs[i].typename == 'Tensor' and outputs[i].optional %}
+    {% elif outputs[i].typename == 'Tensor' and outputs[i].optional %}
 paddle::optional<Tensor> {{outputs[i].name}};
 if(op_res_{{i}}){
   {{outputs[i].name}} = paddle::make_optional<Tensor>(Tensor(std::make_shared<LazyTensor>(op_res_{{i}}.get())));
@@ -139,28 +139,26 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}}
 
 {% for api in apis %}
-{% if api.name in backend_white_list %}
+{%- if api is only_composite_op -%}{#- render nothing -#}
+{% elif api.name not in backend_black_list %}
+  {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd-%}
   {% set api_outputs = api.outputs | trip_intermediate %}
-  {% set inplace_map = {} %}
-  {% if 'inplace' in api and api.inplace != None %}
-    {% for source, target in api.inplace.items() %}
-      {% do inplace_map.update({source: target}) %}
-    {% endfor %}
-  {% endif %}
-{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map)}} {
+{{sig(api.name, api.inputs, api_outputs, api.attrs)}} {
   {% filter indent(2, True) %}
   {{body(api.name, api.inputs, api_outputs, api.attrs)}}
   {% endfilter %}
 }
-  {% if api.attrs is exist_mutable_attribute %}
-{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map, True)}} {
+    {% if api.attrs is exist_mutable_attribute %}
+{{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} {
   {% filter indent(2, True) %}
  {{body(api.name, api.inputs, api_outputs, api.attrs, True)}}
   {% endfilter %}
 }
+    {% endif %}
   {% endif %}
+{% else %}{#- render nothing -#}
 {% endif %}
 {% endfor %}
diff --git a/paddle/fluid/primitive/codegen/templates/common.j2 b/paddle/fluid/primitive/codegen/templates/common.j2
index 6ac639e8ceeaef..5f7148017ab23b 100644
--- a/paddle/fluid/primitive/codegen/templates/common.j2
+++ b/paddle/fluid/primitive/codegen/templates/common.j2
@@ -1,6 +1,6 @@
-{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False, default=False) -%}
+{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False, default=False) -%}
 template <typename T>
-{{ret(outputs, inplace_map)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}})
+{{ret(outputs)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}})
 {%- endmacro %}
 
 
@@ -40,9 +40,9 @@ template <typename T>
 {%- endmacro -%}
 
 
-{%- macro ret(outputs, inplace_map) -%}
+{%- macro ret(outputs) -%}
   {%- set names = [] -%}
-  {%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type(i.name in inplace_map and i.optional)) -%} {%- endfor -%}
+  {%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type(i.optional)) -%} {%- endfor -%}
   {%- if names|length > 1 -%}
 std::tuple<{{sequence('', '', ', ', names)}}>
   {%- else -%}
@@ -73,5 +73,9 @@ std::tuple<{{sequence('', '', ', ', names)}}>
 
 {%- macro scalar2ir(name, data_type) -%}
+  {%- if data_type == 'std::vector<int64_t>' -%}
+{{name}}
+  {%- else -%}
 {{name}}.to<{{data_type}}>()
+  {%- endif -%}
 {%- endmacro -%}
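Illustrative note (hypothetical expansions; assumes the new branch exists to pass IntArray-backed vector attributes through without a Scalar conversion):

    // scalar2ir("axis",  "int64_t")              ->  axis.to<int64_t>()
    // scalar2ir("shape", "std::vector<int64_t>") ->  shape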
diff --git a/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2 b/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2
index 5cf6807470f2bf..90c8d4ce5d89fa 100644
--- a/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2
+++ b/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2
@@ -13,18 +13,12 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArray;
 
 {% for api in apis %}
-{%- if api.is_prim and api.name in backend_white_list and api.name[-1] != '_' -%}
+{%- if api.is_prim and api.name not in backend_black_list and api.name[-1] != '_' -%}
 {%- set input_names = [] -%}
 {%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
 {%- set attr_names = [] -%}
 {%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {% endfor %}
-  {% set inplace_map = {} %}
-  {% if 'inplace' in api and api.inplace != None %}
-    {% for source, target in api.inplace.items() %}
-      {% do inplace_map.update({source: target}) %}
-    {% endfor %}
-  {% endif %}
-{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, inplace_map, False, True)}} {
+{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, False, True)}} {
   return backend::{{api.name}}({{common.args(input_names, attr_names)}});
 }
diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
index 50a0c5d86fc318..02e6c58f97af63 100644
--- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
+++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
@@ -120,8 +120,10 @@ details::{{api.composite.func_name}}({{api.composite.func_args}});
 {%- set api_map = {} -%}
 {%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%}
 {%- for api in apis %}
-  {%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%}
+  {%- if api.backward and api.backward in api_map and api.backward not in vjp_black_list -%}
     {%- set backward_api = api_map[api.backward] %}
+    {%- if backward_api is only_composite_op -%}{#- render nothing -#}
+    {%- else -%}
{{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} {
 {% filter indent(2, True) %}
 {{body(backward_api)}}
 {% endfilter %}
 }
 
   {% endif %}
+  {% endif %}
 {% endfor %}
diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2
index 7f403661fea05e..a4209fb5e81748 100644
--- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2
+++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2
@@ -20,11 +20,14 @@ std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs
 {%- set api_map = {} -%}
 {%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%}
 {% for api in apis %}
-  {%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%}
+  {%- if api.backward and api.backward in api_map and api.backward not in vjp_black_list -%}
     {%- set backward_api = api_map[api.backward] -%}
+    {%- if backward_api is only_composite_op -%}{#- render nothing -#}
+    {%- else -%}
{{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}}
 
   {% endif %}
+  {% endif %}
 {% endfor %}
 
 }  // namespace primitive
 }  // namespace paddle