From 0f112d440d93fb38dbfcc9d61f5ad28c4fdf941a Mon Sep 17 00:00:00 2001
From: Ramin Azarmehr
Date: Wed, 5 Oct 2022 19:05:25 -0400
Subject: [PATCH] Add MPSGenerator to enable custom random number generators on MPS backend (#131)

This patch adds support for creating a torch.Generator for the MPS device, and enables its functions such as manual_seed, get_state, and set_state.
---
 aten/src/ATen/Context.h                         |  11 +-
 aten/src/ATen/detail/MPSHooksInterface.cpp      |   5 +-
 aten/src/ATen/detail/MPSHooksInterface.h        |   7 +-
 aten/src/ATen/mps/MPSGeneratorImpl.h            |  52 +++++++++
 aten/src/ATen/mps/MPSGeneratorImpl.mm           | 100 ++++++++++++++++++
 aten/src/ATen/mps/MPSHooks.cpp                  |  16 ++-
 aten/src/ATen/mps/MPSHooks.h                    |   3 +
 aten/src/ATen/native/mps/OperationUtils.h       |  18 ----
 aten/src/ATen/native/mps/OperationUtils.mm      |  56 +---------
 .../native/mps/operations/Distributions.mm      |  64 ++++++-----
 test/test_mps.py                                |  22 ++++
 torch/csrc/Generator.cpp                        |  18 +++-
 12 files changed, 247 insertions(+), 125 deletions(-)
 create mode 100644 aten/src/ATen/mps/MPSGeneratorImpl.h
 create mode 100644 aten/src/ATen/mps/MPSGeneratorImpl.mm

diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 7f23503c36bcb..9f1c571b66968 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -8,8 +8,8 @@
 #include
 #include
 #include
-#include
 #include
+#include
 #include
 #include
 #include
@@ -38,6 +38,8 @@ class TORCH_API Context {
       return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
+    } else if (device_type == at::kMPS) {
+      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
@@ -421,6 +423,13 @@ static inline void manual_seed(uint64_t seed) {
      }
    }
  }
+
+  if (hasMPS()) {
+    auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS);
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard lock(mps_gen.mutex());
+    mps_gen.set_current_seed(seed);
+  }
 }
 
 // When the global flag `allow_tf32` is set to true, cuBLAS handles are
diff --git a/aten/src/ATen/detail/MPSHooksInterface.cpp b/aten/src/ATen/detail/MPSHooksInterface.cpp
index 87cd26d517985..823b2295b1ace 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.cpp
+++ b/aten/src/ATen/detail/MPSHooksInterface.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 
 namespace at {
 namespace detail {
@@ -7,8 +8,8 @@ namespace detail {
 const MPSHooksInterface& getMPSHooks() {
   static std::unique_ptr mps_hooks;
 #if !defined C10_MOBILE
-  static std::once_flag once;
-  std::call_once(once, [] {
+  static c10::once_flag once;
+  c10::call_once(once, [] {
     mps_hooks = MPSHooksRegistry()->Create("MPSHooks", MPSHooksArgs{});
     if (!mps_hooks) {
       mps_hooks =
diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index 382bcd3255d13..4fff139f27745 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -1,3 +1,5 @@
+// Copyright © 2022 Apple Inc.
+ #pragma once #include @@ -15,7 +17,7 @@ class Context; namespace at { struct TORCH_API MPSHooksInterface { - virtual ~MPSHooksInterface() {} + virtual ~MPSHooksInterface() = default; // Initialize the MPS library state virtual void initMPS() const { @@ -26,8 +28,7 @@ struct TORCH_API MPSHooksInterface { return false; } - virtual const Generator& getDefaultMPSGenerator(DeviceIndex device_index = -1) const { - (void)device_index; // Suppress unused variable warning + virtual const Generator& getDefaultMPSGenerator() const { AT_ERROR("Cannot get default MPS generator without MPS backend."); } diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.h b/aten/src/ATen/mps/MPSGeneratorImpl.h new file mode 100644 index 0000000000000..9695eb719274c --- /dev/null +++ b/aten/src/ATen/mps/MPSGeneratorImpl.h @@ -0,0 +1,52 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace mps { +namespace detail { + +static const uint32_t PHILOX_STATE_N = 7; +struct rng_data_pod { + std::array state{1}; + uint64_t seed = default_rng_seed_val; +}; + +TORCH_API const Generator& getDefaultMPSGenerator(); +TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val); + +} // namespace detail +} // namespace mps + +struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl { + // Constructors + MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val); + ~MPSGeneratorImpl() override = default; + + // MPSGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void update_philox_counters(); + + void set_engine(at::Philox4_32 engine) { engine_ = engine; }; + at::Philox4_32 engine() { return engine_; }; + uint32_t* state_data() { return data_.state.data(); } + static DeviceType device_type() { return DeviceType::MPS; }; + +private: + mps::detail::rng_data_pod data_; + at::Philox4_32 engine_; + + MPSGeneratorImpl* clone_impl() const override; +}; + +} // namespace at diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.mm b/aten/src/ATen/mps/MPSGeneratorImpl.mm new file mode 100644 index 0000000000000..8f2d5168b71b8 --- /dev/null +++ b/aten/src/ATen/mps/MPSGeneratorImpl.mm @@ -0,0 +1,100 @@ +// Copyright © 2022 Apple Inc. 
+ +#include +#include +#include + +namespace at { +namespace mps { +namespace detail { + +const Generator& getDefaultMPSGenerator() { + static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom()); + return default_gen_mps; +} + +Generator createMPSGenerator(uint64_t seed_val) { + auto gen = make_generator(seed_val); + gen.set_current_seed(seed_val); + return gen; +} + +} // namespace detail +} // namespace mps + +MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) + : c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, + data_({.seed = seed_in}), engine_(seed_in, 0, 0) { } + +void MPSGeneratorImpl::set_current_seed(uint64_t seed) { + data_.seed = seed; + data_.state.fill(1); + // the two last state values are the Philox keys + // TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here + data_.state[5] = static_cast(seed); + data_.state[6] = static_cast(seed >> 32); + engine_.reset_state(seed); +} + +uint64_t MPSGeneratorImpl::current_seed() const { + return data_.seed; +} + +uint64_t MPSGeneratorImpl::seed() { + auto random = c10::detail::getNonDeterministicRandom(); + this->set_current_seed(random); + return random; +} + +// See Note [Acquire lock when using random generators] +void MPSGeneratorImpl::update_philox_counters() { + // calling engine_() would call operator() of philox_engine class to + // get each of the four newly generated counter values (see PhiloxRNGEngine.h). + for (int i = 1; i <= 4; i++) { + data_.state[i] = engine_(); + } +} + +c10::intrusive_ptr MPSGeneratorImpl::get_state() const { + static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); + static const size_t seed_size = sizeof(uint64_t); + static const size_t total_size = states_size + seed_size; + + auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto rng_state = state_tensor.data_ptr(); + auto current_seed = this->current_seed(); + memcpy(rng_state, this->data_.state.data(), states_size); + memcpy(rng_state + states_size, ¤t_seed, seed_size); + + return state_tensor.getIntrusivePtr(); +} + +void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); + static const size_t seed_size = sizeof(uint64_t); + static const size_t total_size = states_size + seed_size; + + detail::check_rng_state(new_state); + + auto new_state_size = new_state.numel(); + TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); + + uint64_t input_seed = default_rng_seed_val; + auto new_rng_state = new_state.data(); + memcpy(&input_seed, new_rng_state + states_size, seed_size); + this->set_current_seed(input_seed); + // state.data must be copied after input_seed to not reset the state in set_current_seed() + memcpy(this->state_data(), new_rng_state, states_size); +} + +std::shared_ptr MPSGeneratorImpl::clone() const { + return std::shared_ptr(this->clone_impl()); +} + +MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { + auto gen = new MPSGeneratorImpl(this->data_.seed); + gen->set_current_seed(this->data_.seed); + return gen; +} + +} // namespace at diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.cpp index bbf7234462189..5fde8f3843fe6 100644 --- a/aten/src/ATen/mps/MPSHooks.cpp +++ b/aten/src/ATen/mps/MPSHooks.cpp @@ -1,14 +1,8 @@ -#include +// Copyright © 2022 Apple Inc. 
-#include +#include #include -#include -#include - -#include -#include -#include -#include +#include namespace at { namespace mps { @@ -26,6 +20,10 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const { return at::mps::GetMPSAllocator(); } +const Generator& MPSHooks::getDefaultMPSGenerator() const { + return at::mps::detail::getDefaultMPSGenerator(); +} + using at::MPSHooksRegistry; using at::RegistererMPSHooksRegistry; diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 13647d83c740b..2bef3eac42648 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -1,3 +1,5 @@ +// Copyright © 2022 Apple Inc. + #pragma once #include @@ -12,6 +14,7 @@ struct MPSHooks : public at::MPSHooksInterface { void initMPS() const override; bool hasMPS() const override; Allocator* getMPSDeviceAllocator() const override; + const Generator& getDefaultMPSGenerator() const override; }; }} // at::mps diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 93b0141243397..cc86c4ede4c3b 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -19,24 +19,6 @@ namespace at { namespace native { namespace mps { -struct TORCH_CUDA_CPP_API MPSGeneratorImpl : public c10::GeneratorImpl { - MPSGeneratorImpl(DeviceIndex device_index = -1); - ~MPSGeneratorImpl() = default; - - void set_current_seed(uint64_t seed) override; - uint64_t current_seed() const override; - uint64_t seed() override; - void set_state(const c10::TensorImpl& new_state) override; - c10::intrusive_ptr get_state() const override; - static DeviceType device_type(); - -private: - MPSGeneratorImpl* clone_impl() const override; - uint64_t seed_ = default_rng_seed_val; -}; - -const Generator& getDefaultMPSGenerator(); - struct MPSScalar { id getMTLBuffer() const { return __builtin_bit_cast(id, buffer.get()); } diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index f41484b27b143..6e3ecc3b8e9bf 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -7,60 +7,6 @@ namespace native { namespace mps { -uint64_t MPSGeneratorImpl::seed() { - auto random = c10::detail::getNonDeterministicRandom(true); - this->set_current_seed(random); - return random; -} - -uint64_t MPSGeneratorImpl::current_seed() const { - return seed_; -} - -void MPSGeneratorImpl::set_current_seed(uint64_t seed) { - seed_ = seed; -} - -MPSGeneratorImpl::MPSGeneratorImpl(DeviceIndex device_index) - : c10::GeneratorImpl{Device(DeviceType::MPS, device_index), - DispatchKeySet(c10::DispatchKey::MPS)} { -} - -const Generator& getDefaultMPSGenerator() { - static auto gen = make_generator(0); - gen.seed(); - return gen; -} -DeviceType MPSGeneratorImpl::device_type() { - return DeviceType::MPS; -} -c10::intrusive_ptr MPSGeneratorImpl::get_state() const { - static const size_t seed_size = sizeof(uint64_t); - static const size_t offset_size = sizeof(int64_t); - static const size_t total_size = seed_size + offset_size; - - auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); - - return state_tensor.getIntrusivePtr(); -} - -void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { - static const size_t seed_size = sizeof(uint64_t); - - detail::check_rng_state(new_state); - - uint64_t input_seed; - auto new_rng_state = new_state.data(); - memcpy(&input_seed, new_rng_state, 
seed_size); - this->set_current_seed(input_seed); -} - -MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { - auto gen = new MPSGeneratorImpl(0); - gen->set_current_seed(this->seed_); - return gen; -} - void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) { mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE); } @@ -388,4 +334,4 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override { } } // namespace mps } // namespace native -} // namespace at \ No newline at end of file +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 99d01c6825b35..b527da3925d69 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -3,7 +3,7 @@ #include #include #include -#include +#include namespace at { namespace native { @@ -26,17 +26,6 @@ MPSGraphTensor *stateTensor = nil; // used for Normal distributions only MPSGraphTensor *meanTensor = nil, *stdTensor = nil; - // we initialize and keep the philox's state in the graph. This would - // guarantee producing new random values each time the same graph is reused. - at::Philox4_32 philoxState; - std::array stateValues = {1}; - - void updatePhiloxCounters() { - // calling philoxState() would call operator() of philox_engine class to - // get each of the four newly generated counter values (see PhiloxRNGEngine.h). - for (int i = 1; i <= 4; i++) - stateValues[i] = philoxState(); - } }; typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*); @@ -49,11 +38,13 @@ void updatePhiloxCounters() { const c10::optional& mean_opt, const c10::optional& std_opt, MPSGraphRandomDistribution distribution, + c10::optional gen, std::string op_name, RandomOpBlock randomBlock) { if (self.numel() == 0) { return self; } + auto mps_gen = get_generator_or_default(gen, at::mps::detail::getDefaultMPSGenerator()); MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @@ -68,7 +59,7 @@ void updatePhiloxCounters() { @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new RandomCachedGraph(mpsGraph); - newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]); + newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]); // FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. const MPSDataType inputDataType = [&] { @@ -95,7 +86,7 @@ void updatePhiloxCounters() { desc.standardDeviation = static_cast(val2); } // we don't use the output state tensor from the MPSGraph API as it requires reading back from GPU to CPU. 
- // Instead, we keep the Philox state in the cached graph and use the PyTorch's philox_engine to maintain + // Instead, we keep the Philox state in the MPSGenerator and use the PyTorch's philox_engine to maintain // the counters, and feed them to the graph manually NSArray *resultTensors = [mpsGraph randomTensorWithShape: getMPSShape(self) descriptor: desc @@ -109,12 +100,16 @@ void updatePhiloxCounters() { return newCachedGraph; }); } - // update the Philox state values on each run of the same graph - cachedGraph->updatePhiloxCounters(); // feed the updated state values to the graph - MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]]; + MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]]; MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease]; - [stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil]; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(mps_gen->mutex_); + // update the Philox state values on each run + mps_gen->update_philox_counters(); + [stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil]; + } MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease]; Placeholder meanPlaceholder, stdPlaceholder; @@ -146,6 +141,7 @@ void updatePhiloxCounters() { Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s, const c10::optional& mean_opt, const c10::optional& std_opt, + c10::optional gen, std::string op_name) { const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt)); @@ -177,12 +173,12 @@ void updatePhiloxCounters() { return resultTensor; }; return random_mps_impl(self, mean_s, std_s, mean_opt, std_opt, - MPSGraphRandomDistributionNormal, + MPSGraphRandomDistributionNormal, gen, op_name + getTensorsStringKey({mean_t, std_t}), random_op_block); } -Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, std::string op_name) +Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional gen, std::string op_name) { TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape") @@ -195,7 +191,7 @@ void updatePhiloxCounters() { }; // Bernoulli generates binary output so we use bool type return mps::random_mps_impl(self, 0.0, 1.0, c10::nullopt, prob_t, - MPSGraphRandomDistributionUniform, + MPSGraphRandomDistributionUniform, gen, op_name + getTensorsStringKey({prob_t}), random_op_block); } @@ -215,16 +211,16 @@ void updatePhiloxCounters() { }); return mps::random_mps_impl(self, from, to, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, __func__, nullptr); + MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional gen) { - return mps::normal_mps_impl(self, mean, std, c10::nullopt, c10::nullopt, __func__); + return mps::normal_mps_impl(self, mean, std, c10::nullopt, c10::nullopt, gen, __func__); } Tensor normal_mps(const Tensor& mean, double std, c10::optional gen) { Tensor self = empty_mps(mean.sizes(), mean.scalar_type(), c10::nullopt, kMPS); - return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, __func__); + return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, gen, __func__); } Tensor normal_mps(double mean, const Tensor& std, c10::optional gen) { @@ 
-232,44 +228,44 @@ Tensor normal_mps(double mean, const Tensor& std, c10::optional gen) // when there's no tensor-type mean, we cannot pass scalar mean value due to the order of // multiply/add ops in random computation. So we create a mean tensor instead. Tensor mean_t = at::full_like(self, Scalar(mean)); - return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, gen, __func__); } Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen) { auto shape = at::infer_size(mean.sizes(), std.sizes()); Tensor self = empty_mps(shape, mean.scalar_type(), c10::nullopt, kMPS); - return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, gen, __func__); } Tensor& normal_mps_out(const Tensor& mean, double std, c10::optional gen, Tensor& self) { - return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, __func__); + return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, gen, __func__); } Tensor& normal_mps_out(double mean, const Tensor& std, c10::optional gen, Tensor& self) { // when there's no tensor-type mean, we cannot pass scalar mean value due to the order of // multiply/add ops in random computation. So we create a mean tensor instead. Tensor mean_t = at::full_like(self, Scalar(mean)); - return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, gen, __func__); } Tensor& normal_mps_out(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& self) { TORCH_CHECK(mean.numel() == std.numel(), "normal_mps_out: mean and std must have same number of elements") - return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, gen, __func__); } Tensor& bernoulli_out_mps(const Tensor& p_, c10::optional gen, Tensor& result) { result.resize_(p_.sizes()); - return mps::bernoulli_mps_impl(result, p_, __func__); + return mps::bernoulli_mps_impl(result, p_, gen, __func__); } Tensor& bernoulli_mps_(Tensor& self, double p, c10::optional gen) { TORCH_CHECK(0.0 <= p && p <= 1.0, "bernoulli_mps_ expects p to be in [0, 1], but got p=", p); Tensor prob_t = at::full_like(self, Scalar(p)); - return mps::bernoulli_mps_impl(self, prob_t, __func__); + return mps::bernoulli_mps_impl(self, prob_t, gen, __func__); } Tensor& bernoulli_mps_(Tensor& self, const Tensor& p_, c10::optional gen) { - return mps::bernoulli_mps_impl(self, p_, __func__); + return mps::bernoulli_mps_impl(self, p_, gen, __func__); } // random_.from @@ -321,7 +317,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(self, from, to - 1, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, __func__, nullptr); + MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& random_mps_(Tensor& self, int64_t to, c10::optional gen) { @@ -348,7 +344,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(self, 0.0, 1.0, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, + MPSGraphRandomDistributionUniform, gen, "exponential_mps_:" + std::to_string(lambda), random_op_block); } diff --git a/test/test_mps.py b/test/test_mps.py index ffe2190403a1b..aa3b444bce448 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4887,6 +4887,28 @@ def test_bernoulli(self): mps_out = torch.bernoulli(all_ones) self.assertEqual(mps_out, all_ones) + def test_mps_generator(self): + # explicit manual seeding by 
creating an MPS Generator + g_mps = torch.Generator(device='mps') + g_mps.manual_seed(999) + mps_x = torch.randn(5, device='mps', generator=g_mps) + g_mps.manual_seed(999) + mps_y = torch.randn(5, device='mps', generator=g_mps) + # seed values were the same, so the random tensor contents should match + self.assertEqual(mps_x, mps_y) + # save generator's state to restore it later + g_state = g_mps.get_state() + + # generate random numbers without seeding + mps_x = torch.randn(5, device='mps', generator=g_mps) + # in this case, the random results must differ from the last generated random results + self.assertNotEqual(mps_x, mps_y) + + # restore the previously saved state, and the results should match again + g_mps.set_state(g_state) + mps_x = torch.randn(5, device='mps', generator=g_mps) + self.assertEqual(mps_x, mps_y) + # Test random_.to and random_.from def test_random(self): def helper(shape, low, high, dtype=torch.int32): diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 31dcfefaea8d8..d5939496eff45 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -17,6 +17,10 @@ #include #endif +#ifdef USE_MPS +#include +#endif + using namespace at; using namespace torch; @@ -52,12 +56,20 @@ static PyObject* THPGenerator_pynew( auto device = r.deviceWithDefault(0, at::Device(at::kCPU)); THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0)); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_MPS) if (device.type() == at::kCPU) { self->cdata = make_generator(); - } else if (device.type() == at::kCUDA) { + } +#ifdef USE_CUDA + else if (device.type() == at::kCUDA) { self->cdata = make_generator(device.index()); - } else { + } +#elif USE_MPS + else if (device.type() == at::kMPS) { + self->cdata = make_generator(); + } +#endif + else { AT_ERROR( "Device type ", c10::DeviceTypeName(device.type()),