【PIR OpTest Fix No.28】 fix test_fused_adam_op #62770

Merged: 4 commits, Mar 20, 2024
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -143,6 +143,7 @@
'dpsgd',
'embedding_grad_sparse',
'ftrl',
'fused_adam',
Suggested change (Contributor):
'fused_adam',
'fused_adam_',

'fused_batch_norm_act_',
'fused_bn_add_activation_',
'fused_elemwise_add_activation',
5 changes: 2 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -741,16 +741,15 @@
data_type : dtype
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fused_adam_
- op : fused_adam
Suggested change (Contributor):
- op : fused_adam
- op : fused_adam_

Here we want to stay consistent with the dynamic-graph operator. Also, this operator only has an inplace version, and in PIR every op whose name ends with an underscore is an inplace op (see the sketch after this diff).

args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
infer_meta :
func : FusedAdamInferMeta
kernel :
func : fused_adam
data_type : params
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
optional : skip_update, master_params, master_params_out

- op : fused_batch_norm_act
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str act_type)
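
Per the reviewer's note, the ops.yaml entry after applying the suggested rename would read roughly as below. This is a sketch assembled from the diff above plus the suggestion, not the merged file; the trailing underscore marks the op as inplace-only under the PIR naming convention.

# Sketch: fused_adam entry in ops.yaml with the reviewer's rename applied;
# the trailing underscore marks the PIR inplace variant.
- op : fused_adam_
  args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
  output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
  infer_meta :
    func : FusedAdamInferMeta
  kernel :
    func : fused_adam
    data_type : params
  inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
  optional : skip_update, master_params, master_params_out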
5 changes: 2 additions & 3 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -560,16 +560,15 @@
func : full_with_tensor
data_type : dtype

- op : fused_adam_
- op : fused_adam
args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
infer_meta :
func : FusedAdamInferMeta
kernel :
func : fused_adam
data_type : params
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
optional : skip_update, master_params, master_params_out

- op : fused_batch_norm_act
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str act_type)
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1254,6 +1254,15 @@
data_type : float
support_tensor : true

- op : fused_adam
Suggested change (Contributor):
- op : fused_adam
- op : fused_adam_(fused_adam)

This will map the old IR's fused_adam operator to the PIR fused_adam_ operator (see the sketch after this diff).

inputs :
{params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1,
moments2 : Moments2, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams,
skip_update : SkipUpdate}
outputs :
{params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out,
beta1_pows_out : Beta1PowsOut, beta2_pows_out : Beta2PowsOut, master_params_out : MasterParamsOut}

- op : fused_attention
backward: fused_attention_grad
inputs:
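
Assuming the reviewer's suggestion is applied, the op_compat.yaml entry would look roughly like the sketch below: the new_name(old_name) form tells the compatibility layer to translate the legacy fused_adam op into the PIR fused_adam_ op, while the inputs/outputs maps carry over the old-IR argument names. This is a sketch based on the diff and the suggestion, not the merged file.

# Sketch: op_compat.yaml entry with the reviewer's suggestion applied;
# "fused_adam_(fused_adam)" maps the legacy-IR op name to the PIR inplace op.
- op : fused_adam_(fused_adam)
  inputs :
    {params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1,
     moments2 : Moments2, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams,
     skip_update : SkipUpdate}
  outputs :
    {params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out,
     beta1_pows_out : Beta1PowsOut, beta2_pows_out : Beta2PowsOut, master_params_out : MasterParamsOut}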
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -109,6 +109,7 @@ test_fold_op
test_frame_op
test_ftrl_op
test_full_like_op
test_fused_adam_op
test_fused_attention_op
test_fused_attention_op_api
test_fused_bias_dropout_residual_layer_norm_op