From 645f4d1546cdabf2913b261de2613cbd5c452a3a Mon Sep 17 00:00:00 2001
From: xiayanming <41795079@qq.com>
Date: Tue, 28 Sep 2021 13:54:46 +0800
Subject: [PATCH] [hybrid] seed and dropout op support force-cpu (#35820)

* [HIP] fix op not support AMD GPU bug, the flag PADDLE_WITH_ROCM is invalid

* [HIP] fix op not support AMD GPU bug, the flag PADDLE_WITH_ROCM is invalid

* [HIP] fix op not support AMD GPU bug

* [hybrid] seed and dropout op support force-cpu

* [hybrid] seed and dropout op support force-cpu

* [hybrid] seed and dropout op support force-cpu

* [hybrid] seed and dropout op support force-cpu

* [hybrid] seed and dropout op support force-cpu

* [hybrid] fix seed ci failed issue

* add AsExtra for force_cpu of seed op
---
 paddle/fluid/operators/dropout_impl.cu.h      |  3 +
 paddle/fluid/operators/dropout_op.cc          | 13 ++++
 paddle/fluid/operators/seed_op.cc             | 18 +++++
 paddle/fluid/operators/seed_op.cu             | 30 +++++---
 paddle/fluid/operators/seed_op.h              |  1 +
 python/paddle/fluid/backward.py               |  9 ++-
 .../fluid/tests/unittests/test_dropout_op.py  | 69 +++++++++++++++++++
 .../fluid/tests/unittests/test_seed_op.py     |  4 +-
 8 files changed, 135 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 4261a5f2534c8..7a93d2db0dd1c 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -205,6 +205,9 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
       TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
       seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
       increment = offset;
+    } else if (seed && platform::is_cpu_place(seed->place())) {
+      seed_data = *(seed->data<int>());
+      increment = offset;
     } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
       auto seed_offset = gen_cuda->IncrementOffset(offset);
       seed_data = seed_offset.first;
diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 9700b9a2f7a1c..cbfb795d6a23e 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -42,6 +42,19 @@ class DropoutOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "Seed") {
+      VLOG(10) << "var_name:" << var_name
+               << " does not need to transform in dropout op";
+      return expected_kernel_type;
+    }
+
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
 };
 
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc
index 2f3e4c9ba88c3..32daa8c3934ae 100644
--- a/paddle/fluid/operators/seed_op.cc
+++ b/paddle/fluid/operators/seed_op.cc
@@ -39,6 +39,12 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddOutput("Out", "The output of seed op.");
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddAttr<bool>("force_cpu",
+                  "(bool, default false) Force fill output variable to CPU "
+                  "memory. Otherwise, fill output variable to the running "
+                  "device.")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 Seed Operator.
)DOC"); @@ -55,3 +61,15 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( seed, ops::CPUSeedKernel); + +/* ========================== register checkpoint ===========================*/ +REGISTER_OP_VERSION(seed) + .AddCheckpoint( + R"ROC( + Upgrade seed add a new attribute [force_cpu])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "force_cpu", + "If true, Force fill output variable to cpu." + "memory. Otherwise, fill output variable to the running " + "device", + false)); diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu index c84407ba52dfd..4593b88019621 100644 --- a/paddle/fluid/operators/seed_op.cu +++ b/paddle/fluid/operators/seed_op.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/seed_op.h" namespace paddle { @@ -20,10 +21,10 @@ namespace operators { template class GPUSeedKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out = context.Output("Out"); - auto* out_data = out->mutable_data(context.GetPlace()); + void Compute(const framework::ExecutionContext &context) const override { + auto *out = context.Output("Out"); int user_seed = context.Attr("seed"); + auto force_cpu = context.Attr("force_cpu"); std::random_device rnd; int seed; if (user_seed != 0) { @@ -31,11 +32,24 @@ class GPUSeedKernel : public framework::OpKernel { } else { seed = rnd(); } - auto target_gpu_place = - BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); - auto stream = context.cuda_device_context().stream(); - memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed, - sizeof(int), stream); + + bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace(); + if (cpu_place) { + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(context.GetPlace()); + out->mutable_data(platform::CPUPlace()); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + out, static_cast(seed)); + } else { + auto *out_data = out->mutable_data(context.GetPlace()); + auto target_gpu_place = + BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); + auto stream = context.cuda_device_context().stream(); + memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed, + sizeof(int), stream); + } } }; diff --git a/paddle/fluid/operators/seed_op.h b/paddle/fluid/operators/seed_op.h index f8b513fca4824..671f397d4eaff 100644 --- a/paddle/fluid/operators/seed_op.h +++ b/paddle/fluid/operators/seed_op.h @@ -14,6 +14,7 @@ #pragma once #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 8bf27f6d2fd98..7aa3c888f2ad1 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -197,13 +197,18 @@ def modify_forward_desc_for_recompute(self): if op.desc.has_attr(op_device_attr_name): op_device = op.desc.attr(op_device_attr_name) + # Setting the force_cpu of seed to true will make the output of seed in cpu memory, + # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang added_op = self.block._insert_op( index=op.idx, type='seed', inputs={}, outputs={'Out': [added_var]}, - attrs={'seed': 
-                       'op_device': op_device})
+                attrs={
+                    'seed': seed,
+                    'op_device': op_device,
+                    'force_cpu': True
+                })
             self.ops.insert(op_idx, added_op)
             # modify dropout op desc so that it accept a seed var as input
             op.desc.set_input("Seed", [var_unique_name])
diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py
index 89755d0365f2c..396d55b3d0a8b 100644
--- a/python/paddle/fluid/tests/unittests/test_dropout_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -232,6 +232,75 @@ def init_test_case(self):
         self.fix_seed = False
 
 
+class TestDropoutOpWithSeedOnCPUPlace(unittest.TestCase):
+    def test_seed_cpu_place(self):
+        paddle.enable_static()
+        main_program = Program()
+        with program_guard(main_program):
+            seed_input_name = "tensor@SeedInput"
+            x_var_name = "tensor@X"
+            x_out_var = "tensor@XOut"
+
+            mask_var_name = "tensor@Mask"
+            seed_input_var = main_program.global_block().create_var(
+                name=seed_input_name,
+                shape=[1],
+                dtype='int32',
+                persistable=False,
+                stop_gradient=True)
+            x_out_var = main_program.global_block().create_var(
+                name=x_out_var,
+                shape=[40, 40],
+                dtype='float32',
+                persistable=False,
+                stop_gradient=True)
+            x_var = main_program.global_block().create_var(
+                name=x_var_name,
+                shape=[40, 40],
+                dtype='float32',
+                persistable=False,
+                stop_gradient=True)
+            mask_var = main_program.global_block().create_var(
+                name=mask_var_name,
+                shape=[1],
+                dtype='int',
+                persistable=False,
+                stop_gradient=True)
+
+            main_program.global_block().append_op(
+                type="fill_constant",
+                outputs={"Out": x_var_name},
+                attrs={
+                    "shape": [40, 40],
+                    "dtype": x_var.dtype,
+                    "value": 1.0,
+                    "place_type": 0
+                })
+            main_program.global_block().append_op(
+                type='seed',
+                inputs={},
+                outputs={'Out': seed_input_var},
+                attrs={'seed': 1,
+                       'force_cpu': True})
+            main_program.global_block().append_op(
+                type='dropout',
+                inputs={'X': x_var,
+                        'Seed': seed_input_var},
+                attrs={'dropout_prob': 0.},
+                outputs={'Out': x_out_var,
+                         'Mask': mask_var})
+            place = fluid.CPUPlace()
+            if core.is_compiled_with_cuda():
+                place = fluid.CUDAPlace(0)
+            exe = fluid.Executor(place)
+            x_out, mask_out = exe.run(
+                main_program,
+                feed={},
+                fetch_list=[x_out_var.name, mask_var.name])
+            x_in_np = np.ones([40, 40]).astype("float32")
+            self.assertTrue(np.allclose(x_out, x_in_np))
+
+
 class TestDropoutOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
diff --git a/python/paddle/fluid/tests/unittests/test_seed_op.py b/python/paddle/fluid/tests/unittests/test_seed_op.py
index 7d6705f72569b..08478d7140d43 100644
--- a/python/paddle/fluid/tests/unittests/test_seed_op.py
+++ b/python/paddle/fluid/tests/unittests/test_seed_op.py
@@ -25,7 +25,7 @@ def setUp(self):
         self.op_type = "seed"
         self.inputs = {}
         self.attrs = {"seed": 123}
-        self.outputs = {"Out": np.asarray((123)).astype('int32')}
+        self.outputs = {"Out": np.asarray((123)).astype('int')}
 
     def test_check_output(self):
         self.check_output()
@@ -36,7 +36,7 @@ def setUp(self):
         self.op_type = "seed"
         self.inputs = {}
         self.attrs = {"seed": 0}
-        self.outputs = {"Out": np.asarray((123)).astype('int32')}
+        self.outputs = {"Out": np.asarray((123)).astype('int')}
 
     def test_check_output(self):
         self.check_output(no_check_set=["Out"])
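
Usage note: below is a minimal sketch, distilled from the TestDropoutOpWithSeedOnCPUPlace test above, of the pattern this patch enables: a seed op with force_cpu=True feeds the Seed input of dropout, so the generated seed stays in CPU memory and the dropout GPU kernel reads it in place (the new is_cpu_place branch in dropout_impl.cu.h) instead of issuing a synchronous device-to-host copy. It assumes a Paddle build that includes this patch (static-graph fluid API); variable names such as "seed_out" and "dropout_out" are illustrative only.

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Program, program_guard

paddle.enable_static()

main_program = Program()
with program_guard(main_program):
    block = main_program.global_block()
    # force_cpu=True pins the generated seed in CPU memory even when the
    # program runs on a CUDAPlace.
    seed_var = block.create_var(name="seed_out", shape=[1], dtype="int32")
    x = fluid.data(name="x", shape=[40, 40], dtype="float32")
    out = block.create_var(name="dropout_out", shape=[40, 40], dtype="float32")
    mask = block.create_var(name="dropout_mask", shape=[40, 40], dtype="uint8")

    block.append_op(
        type="seed",
        inputs={},
        outputs={"Out": seed_var},
        attrs={"seed": 1,
               "force_cpu": True})
    # With a CPU-resident Seed input, the dropout GPU kernel dereferences the
    # seed directly rather than doing a TensorCopySync from device to host.
    block.append_op(
        type="dropout",
        inputs={"X": x,
                "Seed": seed_var},
        outputs={"Out": out,
                 "Mask": mask},
        attrs={"dropout_prob": 0.})

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
out_np, = exe.run(main_program,
                  feed={"x": np.ones([40, 40], dtype="float32")},
                  fetch_list=[out.name])
# dropout_prob=0. keeps every element, so the input passes through unchanged.
assert np.allclose(out_np, np.ones([40, 40], dtype="float32"))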