diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index 638f13fd729a8..1cde33191adc1 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -173,6 +173,7 @@
     'push_sparse_v2',
     'push_sparse_v2_',
     'partial_send',
+    'partial_recv',
     'nop',
     'nop_',
 ]
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index 6a655d9851ec5..19bb1c8579694 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1119,6 +1119,15 @@
   backward : pad_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : partial_recv
+  args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
+  output : Tensor(out)
+  infer_meta :
+    func: PartialRecvInferMeta
+  kernel :
+    func : partial_recv
+    data_type : dtype
+
 - op : pool2d
   args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
   output : Tensor(out)
diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc
index cca683ed0bbef..fb5f5cf963992 100644
--- a/paddle/fluid/pir/dialect/operator/utils/utils.cc
+++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -94,7 +94,8 @@ const std::unordered_set<std::string> LegacyOpList = {
     CReduceMinOp::name(),
     CReduceProdOp::name(),
     PushSparseV2Op::name(),
-    PartialSendOp::name()};
+    PartialSendOp::name(),
+    PartialRecvOp::name()};
 
 enum class AttrType {
   UNDEFINED = 0,
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 2c6129c30fb81..413d0e9e92c2a 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2459,6 +2459,10 @@
   extra :
     attrs : [bool use_mkldnn = false]
 
+- op : partial_recv
+  outputs :
+    out : Out
+
 - op : partial_sum
   backward : partial_sum_grad
   extra :
diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc
index d1bd204a682d9..5917a7a46b5ca 100644
--- a/paddle/phi/infermeta/nullary.cc
+++ b/paddle/phi/infermeta/nullary.cc
@@ -123,6 +123,18 @@ void GaussianInferMeta(const IntArray& shape,
   out->set_layout(DataLayout::NCHW);
 }
 
+void PartialRecvInferMeta(int ring_id,
+                          int peer,
+                          DataType dtype,
+                          const std::vector<int>& out_shape,
+                          bool use_calc_stream,
+                          int num,
+                          int id,
+                          MetaTensor* out) {
+  out->set_dims(common::make_ddim(out_shape));
+  out->set_dtype(dtype);
+}
+
 void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
   out->set_dims(common::make_ddim({n}));
   out->set_dtype(dtype);
diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h
index 5eda8fc1a8461..b35b37acc7244 100644
--- a/paddle/phi/infermeta/nullary.h
+++ b/paddle/phi/infermeta/nullary.h
@@ -80,6 +80,15 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out);
 void RandintInferMeta(
     int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out);
 
+void PartialRecvInferMeta(int ring_id,
+                          int peer,
+                          DataType dtype,
+                          const std::vector<int>& out_shape,
+                          bool use_calc_stream,
+                          int num,
+                          int id,
+                          MetaTensor* out);
+
 void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out);
 
 void PRecvArrayInferMeta(int peer,
diff --git a/test/ir/pir/translator/CMakeLists.txt b/test/ir/pir/translator/CMakeLists.txt
index b7fd892ea35a5..43181806b14b5 100644
--- a/test/ir/pir/translator/CMakeLists.txt
+++ b/test/ir/pir/translator/CMakeLists.txt
@@ -11,6 +11,7 @@
 list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
      test_distributed_lookup_table_translate)
 list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init)
 list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_send_translator)
+list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_recv_translator)
 list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_max_translator)
 list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_prod_translator)
diff --git a/test/ir/pir/translator/test_partial_recv_translator.py b/test/ir/pir/translator/test_partial_recv_translator.py
new file mode 100644
index 0000000000000..6f06ec4fad073
--- /dev/null
+++ b/test/ir/pir/translator/test_partial_recv_translator.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import test_op_translator
+
+import paddle
+from paddle.base.framework import (
+    convert_np_dtype_to_dtype_,
+)
+from paddle.base.layer_helper import LayerHelper
+
+
+class TestPartialRecvOpTranslator(test_op_translator.TestOpTranslator):
+    def append_op(self):
+        self.op_type = "partial_recv"
+        out = paddle.ones(shape=(1, 1), dtype='float32')
+        attrs = {
+            'ring_id': 0,
+            'peer': 0,
+            'dtype': convert_np_dtype_to_dtype_(np.float32),
+            'out_shape': out.shape,
+            'use_calc_stream': False,
+            'num': 1,
+            'id': 0,
+        }
+        helper = LayerHelper(self.op_type)
+        helper.append_op(
+            type=self.op_type,
+            outputs={"Out": out},
+            attrs=attrs,
+        )
+
+    def test_translator(self):
+        self.check()
+
+
+if __name__ == "__main__":
+    unittest.main()