From 6865ec33965cbc1c2e294bcadaed7217ef5db184 Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Mon, 11 Mar 2024 16:53:26 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20Dist=20Op=20Reg=20No.9=E3=80=91?= =?UTF-8?q?=20reg=20partial=5Frecv=20(#62412)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * fix * fix * fix * fix * fix * fix --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 9 ++++ .../fluid/pir/dialect/operator/utils/utils.cc | 3 +- paddle/phi/api/yaml/op_compat.yaml | 4 ++ paddle/phi/infermeta/nullary.cc | 12 +++++ paddle/phi/infermeta/nullary.h | 9 ++++ test/ir/pir/translator/CMakeLists.txt | 1 + .../test_partial_recv_translator.py | 52 +++++++++++++++++++ 8 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 test/ir/pir/translator/test_partial_recv_translator.py diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index f488e0dfedc6e..37fe8b461095e 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -175,6 +175,7 @@ 'push_sparse_v2', 'push_sparse_v2_', 'partial_send', + 'partial_recv', 'nop', 'nop_', ] diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index bd94df82f17e1..632d9245fe66a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1142,6 +1142,15 @@ backward : pad_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : partial_recv + args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0) + output : Tensor(out) + infer_meta : + func: PartialRecvInferMeta + kernel : + func : partial_recv + data_type : dtype + - op : pool2d args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 32020dc874cf3..73dda0eb79bf6 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -95,7 +95,8 @@ const std::unordered_set LegacyOpList = { CReduceMinOp::name(), CReduceProdOp::name(), PushSparseV2Op::name(), - PartialSendOp::name()}; + PartialSendOp::name(), + PartialRecvOp::name()}; enum class AttrType { UNDEFINED = 0, diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 68c2241ebe266..218fa0488a5e0 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2465,6 +2465,10 @@ extra : attrs : [bool use_mkldnn = false] +- op : partial_recv + outputs : + out : Out + - op : partial_sum backward : partial_sum_grad extra : diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index d1bd204a682d9..5917a7a46b5ca 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -123,6 +123,18 @@ void GaussianInferMeta(const IntArray& shape, out->set_layout(DataLayout::NCHW); } +void PartialRecvInferMeta(int ring_id, + int peer, + DataType dtype, + const std::vector& out_shape, + bool use_calc_stream, + int num, + int id, + MetaTensor* out) { + out->set_dims(common::make_ddim(out_shape)); + out->set_dtype(dtype); +} + void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) { out->set_dims(common::make_ddim({n})); out->set_dtype(dtype); diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 5eda8fc1a8461..b35b37acc7244 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -80,6 +80,15 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out); void RandintInferMeta( int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out); +void PartialRecvInferMeta(int ring_id, + int peer, + DataType dtype, + const std::vector& out_shape, + bool use_calc_stream, + int num, + int id, + MetaTensor* out); + void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out); void PRecvArrayInferMeta(int peer, diff --git a/test/ir/pir/translator/CMakeLists.txt b/test/ir/pir/translator/CMakeLists.txt index 53eb400c3d1b7..cf84e0de9938b 100644 --- a/test/ir/pir/translator/CMakeLists.txt +++ b/test/ir/pir/translator/CMakeLists.txt @@ -12,6 +12,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_lookup_table_translate) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_send_translator) +list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_recv_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_max_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_prod_translator) list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_random_routing_translator) diff --git a/test/ir/pir/translator/test_partial_recv_translator.py b/test/ir/pir/translator/test_partial_recv_translator.py new file mode 100644 index 0000000000000..6f06ec4fad073 --- /dev/null +++ b/test/ir/pir/translator/test_partial_recv_translator.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import test_op_translator + +import paddle +from paddle.base.framework import ( + convert_np_dtype_to_dtype_, +) +from paddle.base.layer_helper import LayerHelper + + +class TestPartialRecvOpTranslator(test_op_translator.TestOpTranslator): + def append_op(self): + self.op_type = "partial_recv" + out = paddle.ones(shape=(1, 1), dtype='float32') + attrs = { + 'ring_id': 0, + 'peer': 0, + 'dtype': convert_np_dtype_to_dtype_(np.float32), + 'out_shape': out.shape, + 'use_calc_stream': False, + 'num': 1, + 'id': 0, + } + helper = LayerHelper(self.op_type) + helper.append_op( + type=self.op_type, + outputs={"Out": out}, + attrs=attrs, + ) + + def test_translator(self): + self.check() + + +if __name__ == "__main__": + unittest.main()