Skip to content

Commit

Permalink
Add MPSGenerator to enable custom random number generators on MPS bac…
Browse files Browse the repository at this point in the history
…kend (#131)

This patch will will add support for creating torch.Generator for MPS device,
and enables its functions such as manual_seed, get_state, and set_state.
  • Loading branch information
razarmehr authored and kulinseth committed Dec 9, 2022
1 parent bc3ff0a commit 0f112d4
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 125 deletions.
11 changes: 10 additions & 1 deletion aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
Expand Down Expand Up @@ -38,6 +38,8 @@ class TORCH_API Context {
return at::detail::getDefaultCPUGenerator();
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks().getDefaultMPSGenerator();
} else {
AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
}
Expand Down Expand Up @@ -421,6 +423,13 @@ static inline void manual_seed(uint64_t seed) {
}
}
}

if (hasMPS()) {
auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS);
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mps_gen.mutex());
mps_gen.set_current_seed(seed);
}
}

// When the global flag `allow_tf32` is set to true, cuBLAS handles are
Expand Down
5 changes: 3 additions & 2 deletions aten/src/ATen/detail/MPSHooksInterface.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>

namespace at {
namespace detail {

const MPSHooksInterface& getMPSHooks() {
static std::unique_ptr<MPSHooksInterface> mps_hooks;
#if !defined C10_MOBILE
static std::once_flag once;
std::call_once(once, [] {
static c10::once_flag once;
c10::call_once(once, [] {
mps_hooks = MPSHooksRegistry()->Create("MPSHooks", MPSHooksArgs{});
if (!mps_hooks) {
mps_hooks =
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright © 2022 Apple Inc.

#pragma once

#include <c10/core/Allocator.h>
Expand All @@ -15,7 +17,7 @@ class Context;
namespace at {

struct TORCH_API MPSHooksInterface {
virtual ~MPSHooksInterface() {}
virtual ~MPSHooksInterface() = default;

// Initialize the MPS library state
virtual void initMPS() const {
Expand All @@ -26,8 +28,7 @@ struct TORCH_API MPSHooksInterface {
return false;
}

virtual const Generator& getDefaultMPSGenerator(DeviceIndex device_index = -1) const {
(void)device_index; // Suppress unused variable warning
virtual const Generator& getDefaultMPSGenerator() const {
AT_ERROR("Cannot get default MPS generator without MPS backend.");
}

Expand Down
52 changes: 52 additions & 0 deletions aten/src/ATen/mps/MPSGeneratorImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright © 2022 Apple Inc.

#pragma once

#include <ATen/core/Generator.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/util/Optional.h>

namespace at {
namespace mps {
namespace detail {

static const uint32_t PHILOX_STATE_N = 7;
struct rng_data_pod {
std::array<uint32_t, PHILOX_STATE_N> state{1};
uint64_t seed = default_rng_seed_val;
};

TORCH_API const Generator& getDefaultMPSGenerator();
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);

} // namespace detail
} // namespace mps

struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
// Constructors
MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
~MPSGeneratorImpl() override = default;

// MPSGeneratorImpl methods
std::shared_ptr<MPSGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void update_philox_counters();

void set_engine(at::Philox4_32 engine) { engine_ = engine; };
at::Philox4_32 engine() { return engine_; };
uint32_t* state_data() { return data_.state.data(); }
static DeviceType device_type() { return DeviceType::MPS; };

private:
mps::detail::rng_data_pod data_;
at::Philox4_32 engine_;

MPSGeneratorImpl* clone_impl() const override;
};

} // namespace at
100 changes: 100 additions & 0 deletions aten/src/ATen/mps/MPSGeneratorImpl.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright © 2022 Apple Inc.

#include <ATen/Utils.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <algorithm>

namespace at {
namespace mps {
namespace detail {

const Generator& getDefaultMPSGenerator() {
static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom());
return default_gen_mps;
}

Generator createMPSGenerator(uint64_t seed_val) {
auto gen = make_generator<MPSGeneratorImpl>(seed_val);
gen.set_current_seed(seed_val);
return gen;
}

} // namespace detail
} // namespace mps

MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { }

void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
data_.seed = seed;
data_.state.fill(1);
// the two last state values are the Philox keys
// TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here
data_.state[5] = static_cast<uint32_t>(seed);
data_.state[6] = static_cast<uint32_t>(seed >> 32);
engine_.reset_state(seed);
}

uint64_t MPSGeneratorImpl::current_seed() const {
return data_.seed;
}

uint64_t MPSGeneratorImpl::seed() {
auto random = c10::detail::getNonDeterministicRandom();
this->set_current_seed(random);
return random;
}

// See Note [Acquire lock when using random generators]
void MPSGeneratorImpl::update_philox_counters() {
// calling engine_() would call operator() of philox_engine class to
// get each of the four newly generated counter values (see PhiloxRNGEngine.h).
for (int i = 1; i <= 4; i++) {
data_.state[i] = engine_();
}
}

c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t);
static const size_t seed_size = sizeof(uint64_t);
static const size_t total_size = states_size + seed_size;

auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
memcpy(rng_state, this->data_.state.data(), states_size);
memcpy(rng_state + states_size, &current_seed, seed_size);

return state_tensor.getIntrusivePtr();
}

void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t);
static const size_t seed_size = sizeof(uint64_t);
static const size_t total_size = states_size + seed_size;

detail::check_rng_state(new_state);

auto new_state_size = new_state.numel();
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");

uint64_t input_seed = default_rng_seed_val;
auto new_rng_state = new_state.data<uint8_t>();
memcpy(&input_seed, new_rng_state + states_size, seed_size);
this->set_current_seed(input_seed);
// state.data must be copied after input_seed to not reset the state in set_current_seed()
memcpy(this->state_data(), new_rng_state, states_size);
}

std::shared_ptr<MPSGeneratorImpl> MPSGeneratorImpl::clone() const {
return std::shared_ptr<MPSGeneratorImpl>(this->clone_impl());
}

MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const {
auto gen = new MPSGeneratorImpl(this->data_.seed);
gen->set_current_seed(this->data_.seed);
return gen;
}

} // namespace at
16 changes: 7 additions & 9 deletions aten/src/ATen/mps/MPSHooks.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
#include <ATen/mps/MPSHooks.h>
// Copyright © 2022 Apple Inc.

#include <ATen/Context.h>
#include <ATen/mps/MPSHooks.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <c10/util/irange.h>

#include <sstream>
#include <cstddef>
#include <functional>
#include <memory>
#include <ATen/mps/MPSGeneratorImpl.h>

namespace at {
namespace mps {
Expand All @@ -26,6 +20,10 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const {
return at::mps::GetMPSAllocator();
}

const Generator& MPSHooks::getDefaultMPSGenerator() const {
return at::mps::detail::getDefaultMPSGenerator();
}

using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;

Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/mps/MPSHooks.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright © 2022 Apple Inc.

#pragma once

#include <ATen/detail/MPSHooksInterface.h>
Expand All @@ -12,6 +14,7 @@ struct MPSHooks : public at::MPSHooksInterface {
void initMPS() const override;
bool hasMPS() const override;
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
};

}} // at::mps
18 changes: 0 additions & 18 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,6 @@ namespace at {
namespace native {
namespace mps {

struct TORCH_CUDA_CPP_API MPSGeneratorImpl : public c10::GeneratorImpl {
MPSGeneratorImpl(DeviceIndex device_index = -1);
~MPSGeneratorImpl() = default;

void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
static DeviceType device_type();

private:
MPSGeneratorImpl* clone_impl() const override;
uint64_t seed_ = default_rng_seed_val;
};

const Generator& getDefaultMPSGenerator();

struct MPSScalar {
id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }

Expand Down
56 changes: 1 addition & 55 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -7,60 +7,6 @@
namespace native {
namespace mps {

uint64_t MPSGeneratorImpl::seed() {
auto random = c10::detail::getNonDeterministicRandom(true);
this->set_current_seed(random);
return random;
}

uint64_t MPSGeneratorImpl::current_seed() const {
return seed_;
}

void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
seed_ = seed;
}

MPSGeneratorImpl::MPSGeneratorImpl(DeviceIndex device_index)
: c10::GeneratorImpl{Device(DeviceType::MPS, device_index),
DispatchKeySet(c10::DispatchKey::MPS)} {
}

const Generator& getDefaultMPSGenerator() {
static auto gen = make_generator<MPSGeneratorImpl>(0);
gen.seed();
return gen;
}
DeviceType MPSGeneratorImpl::device_type() {
return DeviceType::MPS;
}
c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;

auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);

return state_tensor.getIntrusivePtr();
}

void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);

detail::check_rng_state(new_state);

uint64_t input_seed;
auto new_rng_state = new_state.data<uint8_t>();
memcpy(&input_seed, new_rng_state, seed_size);
this->set_current_seed(input_seed);
}

MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const {
auto gen = new MPSGeneratorImpl(0);
gen->set_current_seed(this->seed_);
return gen;
}

void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}
Expand Down Expand Up @@ -388,4 +334,4 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override { }

} // namespace mps
} // namespace native
} // namespace at
} // namespace at
Loading

0 comments on commit 0f112d4

Please sign in to comment.