forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MPSGenerator to enable custom random number generators on MPS bac…
…kend (#131) This patch will will add support for creating torch.Generator for MPS device, and enables its functions such as manual_seed, get_state, and set_state.
- Loading branch information
Showing
12 changed files
with
247 additions
and
125 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright © 2022 Apple Inc. | ||
|
||
#pragma once | ||
|
||
#include <ATen/core/Generator.h> | ||
#include <ATen/core/PhiloxRNGEngine.h> | ||
#include <c10/core/GeneratorImpl.h> | ||
#include <c10/util/Optional.h> | ||
|
||
namespace at { | ||
namespace mps { | ||
namespace detail { | ||
|
||
static const uint32_t PHILOX_STATE_N = 7; | ||
struct rng_data_pod { | ||
std::array<uint32_t, PHILOX_STATE_N> state{1}; | ||
uint64_t seed = default_rng_seed_val; | ||
}; | ||
|
||
TORCH_API const Generator& getDefaultMPSGenerator(); | ||
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val); | ||
|
||
} // namespace detail | ||
} // namespace mps | ||
|
||
struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl { | ||
// Constructors | ||
MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val); | ||
~MPSGeneratorImpl() override = default; | ||
|
||
// MPSGeneratorImpl methods | ||
std::shared_ptr<MPSGeneratorImpl> clone() const; | ||
void set_current_seed(uint64_t seed) override; | ||
uint64_t current_seed() const override; | ||
uint64_t seed() override; | ||
void set_state(const c10::TensorImpl& new_state) override; | ||
c10::intrusive_ptr<c10::TensorImpl> get_state() const override; | ||
void update_philox_counters(); | ||
|
||
void set_engine(at::Philox4_32 engine) { engine_ = engine; }; | ||
at::Philox4_32 engine() { return engine_; }; | ||
uint32_t* state_data() { return data_.state.data(); } | ||
static DeviceType device_type() { return DeviceType::MPS; }; | ||
|
||
private: | ||
mps::detail::rng_data_pod data_; | ||
at::Philox4_32 engine_; | ||
|
||
MPSGeneratorImpl* clone_impl() const override; | ||
}; | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright © 2022 Apple Inc. | ||
|
||
#include <ATen/Utils.h> | ||
#include <ATen/mps/MPSGeneratorImpl.h> | ||
#include <algorithm> | ||
|
||
namespace at { | ||
namespace mps { | ||
namespace detail { | ||
|
||
const Generator& getDefaultMPSGenerator() { | ||
static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom()); | ||
return default_gen_mps; | ||
} | ||
|
||
Generator createMPSGenerator(uint64_t seed_val) { | ||
auto gen = make_generator<MPSGeneratorImpl>(seed_val); | ||
gen.set_current_seed(seed_val); | ||
return gen; | ||
} | ||
|
||
} // namespace detail | ||
} // namespace mps | ||
|
||
MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) | ||
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, | ||
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { } | ||
|
||
void MPSGeneratorImpl::set_current_seed(uint64_t seed) { | ||
data_.seed = seed; | ||
data_.state.fill(1); | ||
// the two last state values are the Philox keys | ||
// TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here | ||
data_.state[5] = static_cast<uint32_t>(seed); | ||
data_.state[6] = static_cast<uint32_t>(seed >> 32); | ||
engine_.reset_state(seed); | ||
} | ||
|
||
uint64_t MPSGeneratorImpl::current_seed() const { | ||
return data_.seed; | ||
} | ||
|
||
uint64_t MPSGeneratorImpl::seed() { | ||
auto random = c10::detail::getNonDeterministicRandom(); | ||
this->set_current_seed(random); | ||
return random; | ||
} | ||
|
||
// See Note [Acquire lock when using random generators] | ||
void MPSGeneratorImpl::update_philox_counters() { | ||
// calling engine_() would call operator() of philox_engine class to | ||
// get each of the four newly generated counter values (see PhiloxRNGEngine.h). | ||
for (int i = 1; i <= 4; i++) { | ||
data_.state[i] = engine_(); | ||
} | ||
} | ||
|
||
c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const { | ||
static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); | ||
static const size_t seed_size = sizeof(uint64_t); | ||
static const size_t total_size = states_size + seed_size; | ||
|
||
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); | ||
auto rng_state = state_tensor.data_ptr<uint8_t>(); | ||
auto current_seed = this->current_seed(); | ||
memcpy(rng_state, this->data_.state.data(), states_size); | ||
memcpy(rng_state + states_size, ¤t_seed, seed_size); | ||
|
||
return state_tensor.getIntrusivePtr(); | ||
} | ||
|
||
void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||
static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); | ||
static const size_t seed_size = sizeof(uint64_t); | ||
static const size_t total_size = states_size + seed_size; | ||
|
||
detail::check_rng_state(new_state); | ||
|
||
auto new_state_size = new_state.numel(); | ||
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); | ||
|
||
uint64_t input_seed = default_rng_seed_val; | ||
auto new_rng_state = new_state.data<uint8_t>(); | ||
memcpy(&input_seed, new_rng_state + states_size, seed_size); | ||
this->set_current_seed(input_seed); | ||
// state.data must be copied after input_seed to not reset the state in set_current_seed() | ||
memcpy(this->state_data(), new_rng_state, states_size); | ||
} | ||
|
||
std::shared_ptr<MPSGeneratorImpl> MPSGeneratorImpl::clone() const { | ||
return std::shared_ptr<MPSGeneratorImpl>(this->clone_impl()); | ||
} | ||
|
||
MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { | ||
auto gen = new MPSGeneratorImpl(this->data_.seed); | ||
gen->set_current_seed(this->data_.seed); | ||
return gen; | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.