Skip to content

Commit

Permalink
Merge pull request BVLC#336 from jeffdonahue/fix-rng-segfault
Browse files Browse the repository at this point in the history
Fix RNG segfault related to BVLC#297
  • Loading branch information
shelhamer committed Apr 19, 2014
2 parents b3dd5b9 + 6f3a18f commit eff7fdb
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 35 deletions.
8 changes: 2 additions & 6 deletions include/caffe/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,14 @@ class Caffe {
explicit RNG(unsigned int seed);
explicit RNG(const RNG&);
RNG& operator=(const RNG&);
const void* generator() const;
void set_generator(const void* other_rng);
void* generator();
private:
class Generator;
shared_ptr<Generator> generator_;
};

// Getters for boost rng, curand, and cublas handles
inline static const RNG& rng_stream() {
inline static RNG& rng_stream() {
if (!Get().random_generator_) {
Get().random_generator_.reset(new RNG());
}
Expand All @@ -120,9 +119,6 @@ class Caffe {
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both boost and curand
static void set_random_seed(const unsigned int seed);
// Sets the boost RNG engine from another RNG engine to maintain state across
// variate_generator calls.
static void set_generator(const void* other_rng);
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
static void SetDevice(const int device_id);
Expand Down
8 changes: 2 additions & 6 deletions include/caffe/util/rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ namespace caffe {

typedef boost::mt19937 rng_t;

inline const rng_t& caffe_rng() {
return *static_cast<const caffe::rng_t*>(Caffe::rng_stream().generator());
}

inline void caffe_set_rng(const caffe::rng_t& other) {
Caffe::set_generator(static_cast<const void*>(&other));
inline rng_t* caffe_rng() {
return static_cast<caffe::rng_t*>(Caffe::rng_stream().generator());
}

} // namespace caffe
Expand Down
20 changes: 3 additions & 17 deletions src/caffe/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ void Caffe::set_random_seed(const unsigned int seed) {
Get().random_generator_.reset(new RNG(seed));
}

void Caffe::set_generator(const void* other_rng) {
Get().random_generator_->set_generator(other_rng);
}

void Caffe::SetDevice(const int device_id) {
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
Expand Down Expand Up @@ -123,9 +119,7 @@ class Caffe::RNG::Generator {
public:
Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {}
explicit Generator(unsigned int seed) : rng_(new caffe::rng_t(seed)) {}
explicit Generator(const caffe::rng_t& other) :
rng_(new caffe::rng_t(other)) {}
const caffe::rng_t& rng() const { return *rng_; }
caffe::rng_t* rng() { return rng_.get(); }
private:
shared_ptr<caffe::rng_t> rng_;
};
Expand All @@ -134,21 +128,13 @@ Caffe::RNG::RNG() : generator_(new Generator) { }

Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { }

Caffe::RNG::RNG(const RNG& other) : generator_(new Generator(*other.generator_))
{ }

Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
generator_.reset(other.generator_.get());
return *this;
}

const void* Caffe::RNG::generator() const {
return static_cast<const void*>(&generator_->rng());
}

void Caffe::RNG::set_generator(const void* other_rng) {
const caffe::rng_t& rng = *static_cast<const caffe::rng_t*>(other_rng);
return generator_.reset(new Generator(rng));
void* Caffe::RNG::generator() {
return static_cast<void*>(generator_->rng());
}

const char* cublasGetErrorString(cublasStatus_t error) {
Expand Down
9 changes: 3 additions & 6 deletions src/caffe/util/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,11 @@ void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r) {
CHECK(r);
CHECK_LE(a, b);
boost::uniform_real<Dtype> random_distribution(a, caffe_nextafter<Dtype>(b));
boost::variate_generator<caffe::rng_t, boost::uniform_real<Dtype> >
boost::variate_generator<caffe::rng_t*, boost::uniform_real<Dtype> >
variate_generator(caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
caffe_set_rng(variate_generator.engine());
}

template
Expand All @@ -343,12 +342,11 @@ void caffe_rng_gaussian(const int n, const Dtype a,
CHECK(r);
CHECK_GT(sigma, 0);
boost::normal_distribution<Dtype> random_distribution(a, sigma);
boost::variate_generator<caffe::rng_t, boost::normal_distribution<Dtype> >
boost::variate_generator<caffe::rng_t*, boost::normal_distribution<Dtype> >
variate_generator(caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
caffe_set_rng(variate_generator.engine());
}

template
Expand All @@ -366,12 +364,11 @@ void caffe_rng_bernoulli(const int n, const Dtype p, int* r) {
CHECK_GE(p, 0);
CHECK_LE(p, 1);
boost::bernoulli_distribution<Dtype> random_distribution(p);
boost::variate_generator<caffe::rng_t, boost::bernoulli_distribution<Dtype> >
boost::variate_generator<caffe::rng_t*, boost::bernoulli_distribution<Dtype> >
variate_generator(caffe_rng(), random_distribution);
for (int i = 0; i < n; ++i) {
r[i] = variate_generator();
}
caffe_set_rng(variate_generator.engine());
}

template
Expand Down

0 comments on commit eff7fdb

Please sign in to comment.