Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Oct 25, 2024
1 parent 0c10106 commit 0bdc1d3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,12 @@
kernel :
func : mode_grad

- backward_op : mp_allreduce_sum_grad
forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out)
args : (Tensor out_grad, int ring_id = 0)
output : Tensor(x_grad)
invoke : c_identity(out_grad, ring_id, false, false)

- backward_op : multi_dot_grad
forward : multi_dot (Tensor[] x) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad)
Expand Down
6 changes: 0 additions & 6 deletions paddle/phi/ops/yaml/inconsistent/static_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,6 @@
func : minimum_grad
composite : minimum_grad(x, y, out_grad, x_grad, y_grad)

- backward_op : mp_allreduce_sum_grad
forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out)
args : (Tensor out_grad, int ring_id = 0)
output : Tensor(x_grad)
invoke : c_identity(out_grad, ring_id, false, false)

- backward_op : multiply_double_grad
forward : multiply_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
Expand Down
12 changes: 0 additions & 12 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -643,18 +643,6 @@
backward : minimum_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : mp_allreduce_sum
args : (Tensor x, int ring_id = 0)
output : Tensor(out)
infer_meta :
func : AllReduceInferMeta
param: [x]
kernel :
func : mp_allreduce_sum
param: [x]
backward: mp_allreduce_sum_grad
inplace: (x -> out)

- op : multiply
args : (Tensor x, Tensor y)
output : Tensor
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,18 @@
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
traits : pir::SideEffectTrait, paddle::dialect::ForwardOnlyTrait

- op : mp_allreduce_sum
args : (Tensor x, int ring_id = 0)
output : Tensor(out)
infer_meta :
func : AllReduceInferMeta
param: [x]
kernel :
func : mp_allreduce_sum
param: [x]
backward: mp_allreduce_sum_grad
inplace: (x -> out)

- op : multi_dot
args : (Tensor[] x)
output : Tensor
Expand Down

0 comments on commit 0bdc1d3

Please sign in to comment.