Skip to content

Commit

Permalink
Remove auto to_pascal_case for args in op generator (#44350)
Browse files Browse the repository at this point in the history
* remove auto to_pascal_case for args in op generator

* fix yaml config
  • Loading branch information
zyfncg authored Jul 15, 2022
1 parent 270f25e commit 0dafbb0
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- id: sort-simple-yaml
files: (api|backward)\.yaml$
files: (api|backward|api_[a-z_]+)\.yaml$
- repo: local
hooks:
- id: clang-format
Expand Down
82 changes: 71 additions & 11 deletions paddle/phi/api/yaml/api_compat.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
- api : atan2
inputs :
x : X1
y : X2
{x : X1, y : X2}
outputs :
out : Out

- api : bernoulli
inputs :
x : X
outputs :
out : Out

- api : cholesky
inputs :
x : X
outputs :
out : Out

- api : cholesky_solve
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]

- api : cross
inputs : {x : X, y : Y}
inputs :
{x : X, y : Y}
attrs :
axis : dim
outputs :
Expand All @@ -26,17 +53,50 @@
outputs :
out : Out

- api : digamma
inputs :
x : X
outputs :
out : Out

- api : dist
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : dot
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : erf
inputs :
x : X
outputs :
out : Out

- api : mv
inputs :
{x : X, vec : Vec}
outputs :
out : Out

- api : poisson
inputs :
x : X
outputs :
out : Out

- api : trace
inputs :
x : Input
outputs :
out : Out

- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : trunc
inputs :
x : X
outputs :
out : Out
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/generator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def to_sr_output_type(s):
# -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s):
if s.endswith("_grad"):
return 'GradVarName("{}")'.format(to_pascal_case(s[:-5]))
return 'GradVarName("{}")'.format(s[:-5])
else:
return '"{}"'.format(to_pascal_case(s))
return '"{}"'.format(s)


def to_opmaker_name_cstr(s):
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig | to_pascal_case}}")
Input("{{name_in_forward_orig}}")
{%- elif name in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name | to_pascal_case}}")
Output("{{name}}")
{%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
OutputGrad("{{name_in_forward_orig}}")
{%- endif %}
{%- endif %}
{%- endmacro %}
Expand All @@ -376,11 +376,11 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
{% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name[:-5] | to_pascal_case}}")
InputGrad("{{name[:-5]}}")
{%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name | to_input_name | to_pascal_case}}")
InputGrad("{{name | to_input_name}}")
{%- endif %}
{%- endmacro %}

Expand Down

0 comments on commit 0dafbb0

Please sign in to comment.