support lodtensorarray for send/recv (#35279)
* support lodtensorarray
lilong12 authored Sep 3, 2021
1 parent 4e67cd1 commit b6adfd9
Showing 7 changed files with 185 additions and 27 deletions.
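In short: send_v2 and recv_v2 previously handled only a single LoDTensor. This commit teaches both ops to accept a LoDTensorArray as well, transferring each element tensor in turn over the NCCL ring. A minimal sender-side sketch of the new capability, adapted from the unit test added in this commit (the two-rank communication ring is assumed to be set up by the collective test harness; shapes and values are illustrative):

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        data1 = layers.assign(np.array([[0, 1, 2]], dtype='float32'))
        data2 = layers.assign(np.array([[3, 4, 5]], dtype='float32'))
        # Pack two tensors into a LoDTensorArray ...
        tensor_array = layers.create_array(dtype='float32')
        i = layers.fill_constant(shape=[1], dtype='int64', value=0)
        layers.array_write(data1, i, tensor_array)
        layers.array_write(data2, i + 1, tensor_array)
        # ... and ship the whole array to peer rank 1 with one send_v2 op.
        main_prog.global_block().append_op(
            type="send_v2",
            inputs={'X': tensor_array},
            attrs={'ring_id': 0, 'peer': 1, 'use_calc_stream': True})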
33 changes: 19 additions & 14 deletions paddle/fluid/operators/collective/recv_v2_op.cc
@@ -34,21 +34,26 @@ class RecvOpV2 : public framework::OperatorWithKernel {
         ring_id, 0,
         platform::errors::InvalidArgument(
             "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
-    auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
-    PADDLE_ENFORCE_GE(out_shape.size(), 1,
-                      platform::errors::InvalidArgument(
-                          "The size of the output shape must be greater than 0 "
-                          "but the value given is %d.",
-                          out_shape.size()));
-    for (size_t i = 0; i < out_shape.size(); ++i) {
-      PADDLE_ENFORCE_GE(out_shape[i], 1,
-                        platform::errors::InvalidArgument(
-                            "The shape attribute for recv_v2 must be set "
-                            "explicitly, but the %dth element is %d which "
-                            "is less than 1.",
-                            i, out_shape[i]));
-    }
-    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+
+    if (ctx->GetOutputsVarType("Out").front() ==
+        framework::proto::VarType::LOD_TENSOR) {
+      auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
+      PADDLE_ENFORCE_GE(
+          out_shape.size(), 1,
+          platform::errors::InvalidArgument(
+              "The size of the output shape must be greater than 0 "
+              "but the value given is %d.",
+              out_shape.size()));
+      for (size_t i = 0; i < out_shape.size(); ++i) {
+        PADDLE_ENFORCE_GE(out_shape[i], 1,
+                          platform::errors::InvalidArgument(
+                              "The shape attribute for recv_v2 must be set "
+                              "explicitly, but the %dth element is %d which "
+                              "is less than 1.",
+                              i, out_shape[i]));
+      }
+      ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+    }
   }

 protected:
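Note that shape inference now runs only when Out is a plain LOD_TENSOR: a LoDTensorArray holds a runtime-determined number of tensors whose shapes cannot be expressed through the single out_shape attribute, so SetOutputDim is skipped for the array case. For the plain-tensor path the old contract is unchanged: out_shape must still be given explicitly, with every dimension at least 1. A sketch of that path ('out' is a placeholder for a LoDTensor variable created in the same block):

    # Sketch only: the single-LoDTensor path of recv_v2.
    main_prog.global_block().append_op(
        type="recv_v2",
        outputs={'Out': out},            # 'out' is a plain LoDTensor variable
        attrs={
            'peer': 0,
            'ring_id': 0,
            'dtype': out.dtype,          # recv_v2 takes dtype as an attribute
            'out_shape': [1, 3],         # must be explicit; every dim >= 1
            'use_calc_stream': True,
        })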
37 changes: 28 additions & 9 deletions paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -40,13 +40,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "The peer (%d) for recv_v2 op must be non-negative.", peer));

-    auto out = ctx.Output<framework::LoDTensor>("Out");
-    auto out_dims = out->dims();
-    auto numel = out->numel();
-    int data_type = ctx.Attr<int>("dtype");
-    framework::proto::VarType::Type type =
-        framework::proto::VarType::Type(data_type);
-
     gpuStream_t stream = nullptr;
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
@@ -56,14 +49,40 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     } else {
       stream = comm->stream();
     }

     PADDLE_ENFORCE_LT(
         peer, comm->nranks(),
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
-    out->mutable_data<T>(out_dims, place);
+
+    int data_type = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(data_type);
     ncclDataType_t dtype = platform::ToNCCLDataType(type);
+
+    auto *out_var = ctx.OutputVar("Out");
+    if (out_var->IsType<framework::LoDTensorArray>()) {
+      auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
+      for (size_t idx = 0; idx < out_array->size(); ++idx) {
+        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
+        auto out = &out_array->at(idx);
+        auto out_dims = out->dims();
+        out->mutable_data<T>(out_dims, place, 0);
+        auto numel = out->numel();
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+            out->data<T>(), numel, dtype, peer, comm->comm(), stream));
+        VLOG(3) << "rank " << comm->rank() << " recv "
+                << framework::product(out_dims) << " from " << peer;
+      }
+      return;
+    }
+
+    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto out_dims = out->dims();
+    auto numel = out->numel();
+
+    out->mutable_data<T>(out_dims, place);
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
         out->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " recv "
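A practical consequence of this loop: the receiver never learns the element count or shapes from the wire. It iterates over out_array->size() and allocates each element with its existing dims(), so the receiving program must pre-populate the destination array with correctly shaped tensors before the op runs — exactly what the new unit test does (it array_writes matching tensors on rank 1 and lets recv_v2 overwrite them). A receiver-side sketch mirroring the test (placeholder1/placeholder2 stand for illustrative [1, 3] float32 tensors):

    # Sketch only: pre-fill the destination array so the CUDA kernel knows
    # how many elements to receive and how large each one is.
    tensor_array = layers.create_array(dtype='float32')
    i = layers.fill_constant(shape=[1], dtype='int64', value=0)
    layers.array_write(placeholder1, i, tensor_array)
    layers.array_write(placeholder2, i + 1, tensor_array)
    main_prog.global_block().append_op(
        type="recv_v2",
        outputs={'Out': tensor_array},
        attrs={
            'peer': 0,
            'ring_id': 0,
            'dtype': placeholder1.dtype,
            'out_shape': [1, 3],
            'use_calc_stream': True,
        })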
10 changes: 10 additions & 0 deletions paddle/fluid/operators/collective/send_v2_op.cc
@@ -38,6 +38,16 @@ class SendOpV2 : public framework::OperatorWithKernel {
 protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    const framework::Variable* var = ctx.InputVar("X");
+    if (var->IsType<framework::LoDTensorArray>()) {
+      auto t_arr = var->Get<framework::LoDTensorArray>();
+      // NOTE(sandyhouse): support an empty tensor array as input,
+      // and fall back to FP32 as the kernel type in that case.
+      if (t_arr.size() == 0) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       ctx.device_context());
+      }
+    }
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
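This branch exists because kernel dispatch normally reads the dtype of input X via IndicateVarDataType; an empty LoDTensorArray has no element to inspect, so dispatch would fail before the kernel ever ran. Falling back to FP32 gives the dispatcher a valid (if arbitrary) kernel type, and the kernel's per-element loop then simply sends nothing. A sketch of the case this guards against (hypothetical, for illustration only):

    # Sketch only: an empty LoDTensorArray as send_v2 input. Without the
    # FP32 fallback, IndicateVarDataType would have no dtype to report.
    empty_array = layers.create_array(dtype='float32')  # never written to
    main_prog.global_block().append_op(
        type="send_v2",
        inputs={'X': empty_array},
        attrs={'ring_id': 0, 'peer': 1, 'use_calc_stream': True})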
22 changes: 19 additions & 3 deletions paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -28,9 +28,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
     NCCL_VERSION_CODE >= 2703
-    auto x = ctx.Input<framework::LoDTensor>("X");
-    int numel = x->numel();
-
     int rid = ctx.Attr<int>("ring_id");
     PADDLE_ENFORCE_GE(
         rid, 0,
@@ -56,6 +53,25 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
+
+    auto* x_var = ctx.InputVar("X");
+    if (x_var->IsType<framework::LoDTensorArray>()) {
+      auto& x_array = x_var->Get<framework::LoDTensorArray>();
+      for (size_t idx = 0; idx < x_array.size(); idx++) {
+        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
+        auto& x = x_array.at(idx);
+        int numel = x.numel();
+        ncclDataType_t dtype = platform::ToNCCLDataType(x.type());
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+            x.data<T>(), numel, dtype, peer, comm->comm(), stream));
+        VLOG(3) << "rank " << comm->rank() << " send "
+                << framework::product(x.dims()) << " to " << peer;
+      }
+      return;
+    }
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    int numel = x->numel();
+
     ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         x->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " send "
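Symmetrically to the recv kernel, the send loop takes the NCCL dtype per element from x.type() rather than from an attribute, so the sender declares no dtype or out_shape — its metadata comes from the data, while the receiver's comes from the program. The attribute lists of the two ops in the new test below show the contrast (names as in that test):

    # Sender: no dtype/shape attributes; they are read from the array's tensors.
    send_attrs = {'ring_id': ring_id, 'peer': 1, 'use_calc_stream': True}
    # Receiver: must declare dtype and per-element out_shape up front.
    recv_attrs = {'ring_id': ring_id, 'peer': 0, 'dtype': data1.dtype,
                  'out_shape': [1, 3], 'use_calc_stream': True}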
95 changes: 95 additions & 0 deletions python/paddle/fluid/tests/unittests/collective_sendrecv_op_array.py
@@ -0,0 +1,95 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main

paddle.enable_static()


class TestCollectiveSendRecv(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program):
ring_id = self.global_ring_id
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata",
shape=[10, 1000],
dtype='float64',
append_batch_size=False)
if self.rank == 0:
data1 = fluid.layers.assign(
np.array(
[[0, 1, 2]], dtype='float32'))
data2 = fluid.layers.assign(
np.array(
[[3, 4, 5]], dtype='float32'))
elif self.rank == 1:
data1 = fluid.layers.assign(
np.array(
[[3, 4, 5]], dtype='float32'))
data2 = fluid.layers.assign(
np.array(
[[0, 1, 2]], dtype='float32'))
tensor_array = fluid.layers.create_array(dtype='float32')
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
fluid.layers.array_write(data1, i, tensor_array)
fluid.layers.array_write(data2, i + 1, tensor_array)
if self.rank == 0:
main_prog.global_block().append_op(
type="send_v2",
inputs={'X': tensor_array},
attrs={
'ring_id': ring_id,
'peer': 1,
'use_calc_stream': True
})
else:
main_prog.global_block().append_op(
type="recv_v2",
outputs={'Out': tensor_array},
attrs={
'peer': 0,
'ring_id': ring_id,
'dtype': data1.dtype,
'out_shape': [1, 3],
'use_calc_stream': True,
})
return tensor_array


if __name__ == "__main__":
runtime_main(TestCollectiveSendRecv, "sendrecv_array", 0)
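For reference, runtime_main launches this script once per rank and returns the fetched result to the test driver; check_with_place (extended below) then compares rank 1's received array against what rank 0 sent. Under the test's data, the expected observation is (a sketch; tr1_out is rank 1's fetched output, so tr1_out[0] is the LoDTensorArray):

    # Sketch only: expected state on rank 1 for col_type "sendrecv_array".
    need_result1 = np.array([[0, 1, 2]])   # element 0, sent by rank 0
    need_result2 = np.array([[3, 4, 5]])   # element 1, sent by rank 0
    assert np.allclose(tr1_out[0][0], need_result1, rtol=1e-05, atol=1e-05)
    assert np.allclose(tr1_out[0][1], need_result2, rtol=1e-05, atol=1e-05)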
11 changes: 10 additions & 1 deletion python/paddle/fluid/tests/unittests/test_collective_base.py
@@ -215,7 +215,7 @@ def check_with_place(self,
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
-            "GLOG_v": "0",
+            "GLOG_v": "3",
             "NCCL_P2P_DISABLE": "1"
         }
         required_envs.update(need_envs)
@@ -300,5 +300,14 @@ def check_with_place(self,
             self.assertTrue(
                 np.allclose(
                     tr1_out, need_result2, rtol=1e-05, atol=1e-05))
+        elif col_type == "sendrecv_array":
+            need_result1 = np.array([[0, 1, 2]])
+            need_result2 = np.array([[3, 4, 5]])
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0][0], need_result1, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0][1], need_result2, rtol=1e-05, atol=1e-05))
         else:
             pass
4 changes: 4 additions & 0 deletions python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
@@ -29,6 +29,10 @@ def _setup_config(self):
     def test_sendrecv(self):
         self.check_with_place("collective_sendrecv_op.py", "sendrecv")

+    def test_sendrecv_array(self):
+        self.check_with_place("collective_sendrecv_op_array.py",
+                              "sendrecv_array")
+

 if __name__ == '__main__':
     unittest.main()