Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove auto to_pascal_case for args in op generator #44350

Merged
merged 2 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- id: sort-simple-yaml
files: (api|backward)\.yaml$
files: (api|backward|api_[a-z_]+)\.yaml$
- repo: local
hooks:
- id: clang-format
Expand Down
82 changes: 71 additions & 11 deletions paddle/phi/api/yaml/api_compat.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
- api : atan2
inputs :
x : X1
y : X2
{x : X1, y : X2}
outputs :
out : Out

- api : bernoulli
inputs :
x : X
outputs :
out : Out

- api : cholesky
inputs :
x : X
outputs :
out : Out

- api : cholesky_solve
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]

- api : cross
inputs : {x : X, y : Y}
inputs :
{x : X, y : Y}
attrs :
axis : dim
outputs :
Expand All @@ -26,17 +53,50 @@
outputs :
out : Out

- api : digamma
inputs :
x : X
outputs :
out : Out

- api : dist
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : dot
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : erf
inputs :
x : X
outputs :
out : Out

- api : mv
inputs :
{x : X, vec : Vec}
outputs :
out : Out

- api : poisson
inputs :
x : X
outputs :
out : Out

- api : trace
inputs :
x : Input
outputs :
out : Out

- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : trunc
inputs :
x : X
outputs :
out : Out
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/generator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def to_sr_output_type(s):
# -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s):
if s.endswith("_grad"):
return 'GradVarName("{}")'.format(to_pascal_case(s[:-5]))
return 'GradVarName("{}")'.format(s[:-5])
else:
return '"{}"'.format(to_pascal_case(s))
return '"{}"'.format(s)


def to_opmaker_name_cstr(s):
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig | to_pascal_case}}")
Input("{{name_in_forward_orig}}")
{%- elif name in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name | to_pascal_case}}")
Output("{{name}}")
{%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
OutputGrad("{{name_in_forward_orig}}")
{%- endif %}
{%- endif %}
{%- endmacro %}
Expand All @@ -376,11 +376,11 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
{% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name[:-5] | to_pascal_case}}")
InputGrad("{{name[:-5]}}")
{%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name | to_input_name | to_pascal_case}}")
InputGrad("{{name | to_input_name}}")
{%- endif %}
{%- endmacro %}

Expand Down