Skip to content

Commit

Permalink
【PIR Dist Op Reg No.1】 reg push_sparse_v2 (#60473)
Browse files Browse the repository at this point in the history
* code reg push_sparse_v2
  • Loading branch information
enkilee authored Jan 11, 2024
1 parent a576356 commit c62a554
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 1 deletion.
3 changes: 3 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def insert_new_mutable_attributes(
"atol_tensor": "TolTensor",
"out": "Out",
}
op_arg_name_mappings['push_sparse_v2'].update(
{"out_grad_in": "Out@GRAD", "out_grad_out": "Out@GRAD"}
)

op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@
'uniform_random_batch_size_like',
'c_reduce_min',
'c_reduce_min_',
'push_sparse_v2',
'push_sparse_v2_',
]


Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,17 @@
func : prod
backward : prod_grad

- op : push_sparse_v2
args : (Tensor[] ids, Tensor[] w, Tensor[] out_grad_in, int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, str[] inputnames = {}, bool is_distributed = true)
output : Tensor[](out_grad_out){out_grad_in.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [out_grad_in]
kernel :
func : push_sparse_v2
data_type : out_grad_in
inplace: (out_grad_in -> out_grad_out)

- op : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
output : Tensor(out)
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ const std::unordered_set<std::string> LegacyOpList = {
paddle::onednn::dialect::LrnOp::name(),
paddle::onednn::dialect::LrnGradOp::name(),
#endif
CReduceMinOp::name()};
CReduceMinOp::name(),
PushSparseV2Op::name()};

enum class AttrType {
UNDEFINED = 0,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,14 @@
outputs :
out : Out

- op : push_sparse_v2
inputs :
{ x : Ids, W : w}
outputs :
out : Out
extra :
attrs : [int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, 'str[] inputnames = {}', bool is_distributed = true]

- op : put_along_axis
backward : put_along_axis_grad
inputs :
Expand Down

0 comments on commit c62a554

Please sign in to comment.