Skip to content

Commit

Permalink
Update ops.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
xingmingyyj authored Jan 12, 2024
1 parent e59bf47 commit 99b23cc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1479,15 +1479,6 @@
optional: dropout1_seed, dropout2_seed, linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, ln2_mean, ln2_variance, ln1_mean, ln1_variance, ln1_out
backward: fused_feedforward_grad

- op: match_matrix_tensor
args: (Tensor x, Tensor y, Tensor w, int dim_t=1)
output: Tensor(out), Tensor(tmp)
infer_meta:
func: MatchMatrixTensorInferMeta
kernel:
func: match_matrix_tensor
backward: match_matrix_tensor_grad

- op: lars_momentum
args: (Tensor param, Tensor velocity, Tensor grad, Tensor learning_rate, Tensor master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005}, float epsilon=0, bool multi_precision=false, float rescale_grad=1.0f)
output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
Expand All @@ -1500,6 +1491,15 @@
data_type: param
optional: master_param, master_param_out

- op: match_matrix_tensor
args: (Tensor x, Tensor y, Tensor w, int dim_t=1)
output: Tensor(out), Tensor(tmp)
infer_meta:
func: MatchMatrixTensorInferMeta
kernel:
func: match_matrix_tensor
backward: match_matrix_tensor_grad

- op: nce
args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
Expand Down

0 comments on commit 99b23cc

Please sign in to comment.