Skip to content

Commit

Permalink
[hybrid] static model parallel dropout support deterministic RandomSe…
Browse files Browse the repository at this point in the history
…edGenerator (PaddlePaddle#36228)
  • Loading branch information
wangxicoding committed Oct 25, 2021
1 parent 17b2d71 commit aa33bc3
Show file tree
Hide file tree
Showing 13 changed files with 354 additions and 32 deletions.
37 changes: 37 additions & 0 deletions paddle/fluid/framework/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,43 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}

using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

static RNGMap& GetRandomSeedGeneratorMap() {
static auto random_seed_generator_map = RNGMap();
return random_seed_generator_map;
}

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
platform::errors::AlreadyExists(
"%s RandomSeedGenerator is already exist", name));

auto generator = std::make_shared<Generator>(seed);
bool emplace_success = rng_map.emplace(name, generator).second;
PADDLE_ENFORCE_EQ(
emplace_success, true,
platform::errors::PermissionDenied(
"SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
name));
return rng_map[name];
}

const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter != rng_map.end(), true,
platform::errors::NotFound(
"%s RandomSeedGenerator is not found, please "
"use `set_random_seed_generator` to set rng first",
name));
return iter->second;
}

std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed);

const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name);

} // namespace framework
} // namespace paddle
10 changes: 3 additions & 7 deletions paddle/fluid/operators/dropout_impl_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if ((seed) && platform::is_gpu_place(seed->place())) {
if (seed) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
Expand All @@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
*seed_data = seed_offset.first;
*increment = seed_offset.second;
} else {
if (seed) {
*seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
}
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
*increment = offset;
}
}
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/seed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddOutput("Out", "The output of seed op.");
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<bool>("deterministic",
"(bool, default false) Whether to use deterministic "
"RandomSeedGenerator which "
"generate by `set_random_seed_generator`")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"rng_name",
"use deterministic RandomSeedGenerator which name is `rng_name`")
.SetDefault("")
.AsExtra();
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
Expand Down
11 changes: 2 additions & 9 deletions paddle/fluid/operators/seed_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@ class GPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<Tensor>("Out");
int user_seed = context.Attr<int>("seed");
auto force_cpu = context.Attr<bool>("force_cpu");
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
seed = rnd();
}
int seed = get_seed(context);

auto force_cpu = context.Attr<bool>("force_cpu");
bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace();
if (cpu_place) {
platform::DeviceContextPool &pool =
Expand Down
34 changes: 24 additions & 10 deletions paddle/fluid/operators/seed_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,45 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
int user_seed = context.Attr<int>("seed");
static int get_seed(const framework::ExecutionContext& context) {
int user_seed = context.Attr<int>("seed");
bool deterministic = context.Attr<bool>("deterministic");

int seed = 0;
if (!deterministic) {
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
std::random_device rnd;
seed = rnd();
}
out_data[0] = seed;
} else {
std::string name = context.Attr<std::string>("rng_name");
auto rng = framework::GetRandomSeedGenerator(name);
do { // NOTE(wangxi): cpu dropout will use random seed if seed == 0
seed = static_cast<int>(rng->Random64());
} while (seed == 0);
}
return seed;
}

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
out_data[0] = get_seed(context);
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/generator_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) {
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
}
} // namespace pybind
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import paddle
import contextlib
import numpy as np
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper

__all__ = []

Expand Down Expand Up @@ -93,3 +98,135 @@ def model_parallel_random_seed(seed=None):
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)


def determinate_seed(rng_name):
assert rng_name is not None and rng_name != ""
helper = LayerHelper('seed', **locals())
out = helper.create_variable_for_type_inference(dtype=paddle.int32)
# set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
helper.append_op(
type='seed',
outputs={'Out': out},
attrs={'deterministic': True,
'rng_name': rng_name,
'force_cpu': True})
return out


def dropout(x,
p=0.5,
axis=None,
rng_name=None,
training=True,
mode="upscale_in_train",
name=None):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaption during training. The dropout operator randomly sets the
outputs of some units to zero, while upscale others according to the given
dropout probability.
Args:
x (Tensor): The input tensor. The data type is float32 or float64.
p (float|int): Probability of setting units to zero. Default 0.5.
axis (int|list|tuple): The axis along which the dropout is performed. Default None.
rng_name (str): The random seed generator name, which used to obtain deterministic results.
training (bool): A flag indicating whether it is in train phrase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - dropout_prob )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout, has same shape and data type as `x` .
Examples:
We use ``p=0.5`` in the following description for simplicity.
1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly.
.. code-block:: text
Let's see a simple case when x is a 2d tensor with shape 2*3:
[[1 2 3]
[4 5 6]]
we generate mask with the same shape as x, which is 2*3. The value of mask is
sampled from a Bernoulli distribution randomly. For example, we may get such mask:
[[0 1 0]
[1 0 1]]
So the output is obtained from elementwise multiply of x and mask:
[[0 2 0]
[4 0 6]]
Using default setting, i.e. ``mode='upscale_in_train'`` ,
if in training phase, the final upscale output is:
[[0 4 0 ]
[8 0 12]]
if in test phase, the output is the same as input:
[[1 2 3]
[4 5 6]]
we can also set ``mode='downscale_in_infer'`` , then
if in training phase, the final output is:
[[0 2 0]
[4 0 6]]
if in test phase, the scale output is:
[[0.5 1. 1.5]
[2. 2.5 3. ]]
"""
if rng_name is None:
return paddle.nn.functional.dropout(x, p, axis, training, mode, name)

# fast return for p == 0
if p == 0: return x

assert isinstance(p, (float, int)), \
TypeError("p argument should be a number")
assert 0 <= p <= 1, ValueError("p argument should between 0 and 1")
assert mode in ('downscale_in_infer', 'upscale_in_train'), \
ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")

assert axis is None, \
TypeError("unsupport axis when using random seed generator")

mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer

# dygraph using tracker, doesn't need determinate seed
if in_dygraph_mode():
out, mask = _C_ops.dropout(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', False, 'seed', 0,
'dropout_implementation', mode)
return out

seed = determinate_seed(rng_name)

helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')

out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)

helper.append_op(
type='dropout',
inputs={'X': [x],
'Seed': seed},
outputs={'Out': [out],
'Mask': [mask]},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
})
return out
6 changes: 5 additions & 1 deletion python/paddle/fluid/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,15 @@ def modify_forward_desc_for_recompute(self):
return

op_idx = 0
while (op_idx < len(self.ops)):
while op_idx < len(self.ops):
op = self.ops[op_idx]
if op.desc.type() != "dropout":
op_idx += 1
continue
# already insert seed op before dropout
if op.input('Seed') is not None and len(op.input('Seed')) == 1:
op_idx += 1
continue
# add a seed op so that the two dropout op can generate same output
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(".".join(
Expand Down
Loading

0 comments on commit aa33bc3

Please sign in to comment.