From 0e0a5516682f65cfd947055597f0d3ed7834a129 Mon Sep 17 00:00:00 2001 From: MrChengmo Date: Fri, 5 Feb 2021 17:33:17 +0800 Subject: [PATCH] add truncated gaussian random --- .../distributed/table/common_dense_table.cc | 2 + .../distributed/table/depends/initializers.h | 37 +++++++++++++++++++ .../table/depends/large_scale_kv.h | 3 ++ 3 files changed, 42 insertions(+) diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc index 45f8eed353dc7..4063e4f501d01 100644 --- a/paddle/fluid/distributed/table/common_dense_table.cc +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -29,6 +29,8 @@ void CommonDenseTable::create_initializer(const std::string& attr, initializers_[name] = new FillConstantInitializer(slices); } else if (slices[0] == "uniform_random") { initializers_[name] = new UniformInitializer(slices); + } else if (slices[0] == "truncated_gaussian_random") { + initializers_[name] = new TruncatedGaussianInitializer(slices); } else { PADDLE_THROW( platform::errors::InvalidArgument("%s can not be supported", name)); diff --git a/paddle/fluid/distributed/table/depends/initializers.h b/paddle/fluid/distributed/table/depends/initializers.h index e8857ed51560d..f46e659a88bab 100644 --- a/paddle/fluid/distributed/table/depends/initializers.h +++ b/paddle/fluid/distributed/table/depends/initializers.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,8 @@ #include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/operators/truncated_gaussian_random_op.h" + namespace paddle { namespace distributed { @@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer { std::normal_distribution dist_; }; +class TruncatedGaussianInitializer : public Initializer { + public: + explicit TruncatedGaussianInitializer(const std::vector &attrs) { + name_ = attrs[0]; + seed_ = static_cast(std::stoi(attrs[1])); + mean_ = std::stof(attrs[2]); + std_ = std::stof(attrs[3]); + + std::uniform_real_distribution dist_( + std::numeric_limits::min(), 1.0); + random_engine_ = framework::GetCPURandomEngine(seed_); + } + + float GetValue() override { + paddle::operators::TruncatedNormal truncated_normal(mean_, std_); + float value = truncated_normal(dist_(*random_engine_)); + return value; + } + + void GetValue(float *value, int numel) { + paddle::operators::TruncatedNormal truncated_normal(mean_, std_); + for (int x = 0; x < numel; ++x) { + value[x] = truncated_normal(dist_(*random_engine_)); + } + } + + private: + float std_; + float mean_; + + std::shared_ptr random_engine_; + std::uniform_real_distribution dist_; +}; + class FillConstantInitializer : public Initializer { public: explicit FillConstantInitializer(const std::vector &attrs) { diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h index 9ab3711fe2ea0..55f8489b08cba 100644 --- a/paddle/fluid/distributed/table/depends/large_scale_kv.h +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -134,6 +134,9 @@ class ValueBlock { } else if (slices[0] == "uniform_random") { initializers_.emplace_back( std::make_shared(slices)); + } else if (slices[0] == "truncated_gaussian_random") { + initializers_.emplace_back( + std::make_shared(slices)); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s can not be supported", attr));