
Commit

[hybrid] static model parallel dropout support deterministic RandomSeedGenerator (#36228)
wangxicoding committed Oct 25, 2021
1 parent 17b2d71 commit aa33bc3
Showing 13 changed files with 354 additions and 32 deletions.
37 changes: 37 additions & 0 deletions paddle/fluid/framework/generator.cc
@@ -63,6 +63,43 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}

using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

static RNGMap& GetRandomSeedGeneratorMap() {
static auto random_seed_generator_map = RNGMap();
return random_seed_generator_map;
}

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
platform::errors::AlreadyExists(
"%s RandomSeedGenerator is already exist", name));

auto generator = std::make_shared<Generator>(seed);
bool emplace_success = rng_map.emplace(name, generator).second;
PADDLE_ENFORCE_EQ(
emplace_success, true,
platform::errors::PermissionDenied(
"SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
name));
return rng_map[name];
}

const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter != rng_map.end(), true,
platform::errors::NotFound(
"%s RandomSeedGenerator is not found, please "
"use `set_random_seed_generator` to set rng first",
name));
return iter->second;
}

std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
6 changes: 6 additions & 0 deletions paddle/fluid/framework/generator.h
@@ -126,5 +126,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed);

const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name);

} // namespace framework
} // namespace paddle
10 changes: 3 additions & 7 deletions paddle/fluid/operators/dropout_impl_util.h
@@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if ((seed) && platform::is_gpu_place(seed->place())) {
if (seed) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
@@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
*seed_data = seed_offset.first;
*increment = seed_offset.second;
} else {
if (seed) {
*seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
}
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
*increment = offset;
}
}
11 changes: 11 additions & 0 deletions paddle/fluid/operators/seed_op.cc
@@ -39,6 +39,17 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddOutput("Out", "The output of seed op.");
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<bool>("deterministic",
"(bool, default false) Whether to use deterministic "
"RandomSeedGenerator which "
"generate by `set_random_seed_generator`")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"rng_name",
"use deterministic RandomSeedGenerator which name is `rng_name`")
.SetDefault("")
.AsExtra();
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
11 changes: 2 additions & 9 deletions paddle/fluid/operators/seed_op.cu
@@ -23,16 +23,9 @@ class GPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<Tensor>("Out");
int user_seed = context.Attr<int>("seed");
auto force_cpu = context.Attr<bool>("force_cpu");
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
seed = rnd();
}
int seed = get_seed(context);

auto force_cpu = context.Attr<bool>("force_cpu");
bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace();
if (cpu_place) {
platform::DeviceContextPool &pool =
34 changes: 24 additions & 10 deletions paddle/fluid/operators/seed_op.h
@@ -13,31 +13,45 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
int user_seed = context.Attr<int>("seed");
static int get_seed(const framework::ExecutionContext& context) {
int user_seed = context.Attr<int>("seed");
bool deterministic = context.Attr<bool>("deterministic");

int seed = 0;
if (!deterministic) {
// NOTE: a fixed seed should only be used in unit tests or for debugging.
// A random seed is guaranteed to be used in training.
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
std::random_device rnd;
seed = rnd();
}
out_data[0] = seed;
} else {
std::string name = context.Attr<std::string>("rng_name");
auto rng = framework::GetRandomSeedGenerator(name);
do { // NOTE(wangxi): cpu dropout will use random seed if seed == 0
seed = static_cast<int>(rng->Random64());
} while (seed == 0);
}
return seed;
}

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
out_data[0] = get_seed(context);
}
};

2 changes: 2 additions & 0 deletions paddle/fluid/pybind/generator_py.cc
@@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) {
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
}
} // namespace pybind
} // namespace paddle
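
With these bindings in place, the named generators become reachable from Python. Below is a minimal usage sketch (not part of this diff); it assumes the bindings are exposed on paddle.fluid.core alongside the existing default_cpu_generator / default_cuda_generator bindings:

from paddle.fluid import core

# Register a named, seeded generator once per process (e.g. per pipeline stage).
core.set_random_seed_generator("mp_dropout_rng", 1024)

# Later look-ups by name return the same shared Generator instance.
gen = core.get_random_seed_generator("mp_dropout_rng")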
@@ -15,6 +15,11 @@
import paddle
import contextlib
import numpy as np
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper

__all__ = []

@@ -93,3 +98,135 @@ def model_parallel_random_seed(seed=None):
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)


def determinate_seed(rng_name):
assert rng_name is not None and rng_name != ""
helper = LayerHelper('seed', **locals())
out = helper.create_variable_for_type_inference(dtype=paddle.int32)
# set force_cpu to avoid the extra CPU->GPU->CPU sync copy and reduce pipeline hangs
helper.append_op(
type='seed',
outputs={'Out': out},
attrs={'deterministic': True,
'rng_name': rng_name,
'force_cpu': True})
return out
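

For context, here is a hedged sketch of how determinate_seed is meant to be driven in a static-graph program; the generator name and program setup below are illustrative assumptions, not part of the diff:

import paddle
from paddle.fluid import core

paddle.enable_static()
# Assumed setup: register the named generator before building the program.
core.set_random_seed_generator("mp_rng", 2021)

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Appends a 'seed' op with deterministic=True and rng_name='mp_rng'; the
    # returned int32 variable is kept on CPU (force_cpu=True) and is what the
    # dropout op below consumes as its 'Seed' input.
    seed_var = determinate_seed("mp_rng")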


def dropout(x,
p=0.5,
axis=None,
rng_name=None,
training=True,
mode="upscale_in_train",
name=None):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaptation during training. The dropout operator randomly sets the
outputs of some units to zero, while upscaling the others according to the given
dropout probability.
Args:
x (Tensor): The input tensor. The data type is float32 or float64.
p (float|int): Probability of setting units to zero. Default 0.5.
axis (int|list|tuple): The axis along which the dropout is performed. Default None.
rng_name (str): The name of the random seed generator, which is used to obtain deterministic results.
training (bool): A flag indicating whether it is in the training phase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - dropout_prob )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout, has same shape and data type as `x` .
Examples:
We use ``p=0.5`` in the following description for simplicity.
1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly.
.. code-block:: text
Let's see a simple case when x is a 2d tensor with shape 2*3:
[[1 2 3]
[4 5 6]]
we generate mask with the same shape as x, which is 2*3. The value of mask is
sampled from a Bernoulli distribution randomly. For example, we may get such mask:
[[0 1 0]
[1 0 1]]
So the output is obtained from elementwise multiply of x and mask:
[[0 2 0]
[4 0 6]]
Using default setting, i.e. ``mode='upscale_in_train'`` ,
if in training phase, the final upscale output is:
[[0 4 0 ]
[8 0 12]]
if in test phase, the output is the same as input:
[[1 2 3]
[4 5 6]]
we can also set ``mode='downscale_in_infer'`` , then
if in training phase, the final output is:
[[0 2 0]
[4 0 6]]
if in test phase, the scaled output is:
[[0.5 1. 1.5]
[2. 2.5 3. ]]
"""
if rng_name is None:
return paddle.nn.functional.dropout(x, p, axis, training, mode, name)

# fast return for p == 0
if p == 0: return x

assert isinstance(p, (float, int)), \
TypeError("p argument should be a number")
assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
assert mode in ('downscale_in_infer', 'upscale_in_train'), \
ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")

assert axis is None, \
TypeError("axis is not supported when using the random seed generator")

mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # semantic transfer

# dygraph uses the rng tracker, so it doesn't need a determinate seed
if in_dygraph_mode():
out, mask = _C_ops.dropout(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', False, 'seed', 0,
'dropout_implementation', mode)
return out

seed = determinate_seed(rng_name)

helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')

out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)

helper.append_op(
type='dropout',
inputs={'X': [x],
'Seed': seed},
outputs={'Out': [out],
'Mask': [mask]},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
})
return out
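
A hedged end-to-end usage sketch of the dropout wrapper above follows; the generator name, tensor shape, and the assumption that every model-parallel rank registers the same seed are illustrative, not part of the diff:

import paddle
from paddle.fluid import core

paddle.enable_static()
# Every model-parallel rank registers the same named generator with the same
# seed, so the dropout masks agree across ranks.
core.set_random_seed_generator("local_seed", 1234)

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # dropout is the wrapper defined above (its import path is not shown in
    # this excerpt); with rng_name set, the mask comes from the shared
    # "local_seed" generator instead of a per-op random seed.
    y = dropout(x, p=0.5, rng_name="local_seed", training=True)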
6 changes: 5 additions & 1 deletion python/paddle/fluid/backward.py
@@ -174,11 +174,15 @@ def modify_forward_desc_for_recompute(self):
return

op_idx = 0
while (op_idx < len(self.ops)):
while op_idx < len(self.ops):
op = self.ops[op_idx]
if op.desc.type() != "dropout":
op_idx += 1
continue
# a seed op has already been inserted before this dropout
if op.input('Seed') is not None and len(op.input('Seed')) == 1:
op_idx += 1
continue
# add a seed op so that the two dropout ops can generate the same output
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(".".join(
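In plain terms, the new guard skips dropout ops whose Seed input was already wired up by the deterministic dropout wrapper, so recompute does not insert a second seed op for them. A standalone sketch of the same check (hypothetical helper, not part of the diff):

def has_seed_input(op):
    # A dropout op rewritten by the deterministic dropout wrapper carries
    # exactly one 'Seed' input; such ops are left untouched by recompute.
    seed_inputs = op.input('Seed')
    return seed_inputs is not None and len(seed_inputs) == 1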

1 comment on commit aa33bc3

@paddle-bot-old

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
