Skip to content

Commit

Permalink
【PIR Dist Op Reg No.9】 reg partial_recv (#62412)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
enkilee authored Mar 11, 2024
1 parent 937decf commit 6865ec3
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
'push_sparse_v2',
'push_sparse_v2_',
'partial_send',
'partial_recv',
'nop',
'nop_',
]
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,15 @@
backward : pad_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : partial_recv
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
output : Tensor(out)
infer_meta :
func: PartialRecvInferMeta
kernel :
func : partial_recv
data_type : dtype

- op : pool2d
args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out)
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ const std::unordered_set<std::string> LegacyOpList = {
CReduceMinOp::name(),
CReduceProdOp::name(),
PushSparseV2Op::name(),
PartialSendOp::name()};
PartialSendOp::name(),
PartialRecvOp::name()};

enum class AttrType {
UNDEFINED = 0,
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,10 @@
extra :
attrs : [bool use_mkldnn = false]

- op : partial_recv
outputs :
out : Out

- op : partial_sum
backward : partial_sum_grad
extra :
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ void GaussianInferMeta(const IntArray& shape,
out->set_layout(DataLayout::NCHW);
}

void PartialRecvInferMeta(int ring_id,
int peer,
DataType dtype,
const std::vector<int>& out_shape,
bool use_calc_stream,
int num,
int id,
MetaTensor* out) {
out->set_dims(common::make_ddim(out_shape));
out->set_dtype(dtype);
}

void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
out->set_dims(common::make_ddim({n}));
out->set_dtype(dtype);
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/nullary.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out);
void RandintInferMeta(
int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out);

void PartialRecvInferMeta(int ring_id,
int peer,
DataType dtype,
const std::vector<int>& out_shape,
bool use_calc_stream,
int num,
int id,
MetaTensor* out);

void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out);

void PRecvArrayInferMeta(int peer,
Expand Down
1 change: 1 addition & 0 deletions test/ir/pir/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_distributed_lookup_table_translate)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_send_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_recv_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_max_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_prod_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_random_routing_translator)
Expand Down
52 changes: 52 additions & 0 deletions test/ir/pir/translator/test_partial_recv_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import test_op_translator

import paddle
from paddle.base.framework import (
convert_np_dtype_to_dtype_,
)
from paddle.base.layer_helper import LayerHelper


class TestPartialRecvOpTranslator(test_op_translator.TestOpTranslator):
def append_op(self):
self.op_type = "partial_recv"
out = paddle.ones(shape=(1, 1), dtype='float32')
attrs = {
'ring_id': 0,
'peer': 0,
'dtype': convert_np_dtype_to_dtype_(np.float32),
'out_shape': out.shape,
'use_calc_stream': False,
'num': 1,
'id': 0,
}
helper = LayerHelper(self.op_type)
helper.append_op(
type=self.op_type,
outputs={"Out": out},
attrs=attrs,
)

def test_translator(self):
self.check()


if __name__ == "__main__":
unittest.main()

0 comments on commit 6865ec3

Please sign in to comment.