Skip to content

Commit

Permalink
rollout prototype C++ threadpool for benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Dec 3, 2024
1 parent a793a34 commit 36f3321
Showing 1 changed file with 154 additions and 7 deletions.
161 changes: 154 additions & 7 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,116 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#include <absl/base/attributes.h>

namespace mujoco::python {

namespace {

namespace py = ::pybind11;

// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.h
// (local changes: counter accessors take the mutex, WaitCount is sign-safe)
// ThreadPool class
// A fixed set of worker threads that drain a shared FIFO queue of tasks.
// A counter, incremented by the workers, lets a caller block until a given
// number of queue entries has been processed.
class ThreadPool {
 public:
  // constructor: starts num_threads worker threads
  explicit ThreadPool(int num_threads);

  // destructor: queues one stop marker per worker, then joins every worker
  ~ThreadPool();

  // the pool owns threads, a mutex and condition variables: not copyable
  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;

  // number of worker threads owned by this pool
  int NumThreads() const { return static_cast<int>(threads_.size()); }

  // returns an ID between 0 and NumThreads() - 1. must be called within
  // worker thread (returns -1 if not).
  static int WorkerId() { return worker_id_; }

  // ----- methods ----- //
  // set task for threadpool
  void Schedule(std::function<void()> task);

  // return number of tasks completed
  // (reads the counter under the mutex because workers modify it under the
  // same mutex)
  std::uint64_t GetCount() {
    std::lock_guard<std::mutex> lock(m_);
    return ctr_;
  }

  // reset count to zero
  void ResetCount() {
    std::lock_guard<std::mutex> lock(m_);
    ctr_ = 0;
  }

  // wait for count, then return
  // blocks until the counter is at least `value`; a non-positive `value`
  // returns immediately instead of being converted to a huge unsigned target
  void WaitCount(int value) {
    const std::uint64_t target =
        value > 0 ? static_cast<std::uint64_t>(value) : 0;
    std::unique_lock<std::mutex> lock(m_);
    // read ctr_ directly: m_ is already held here and is not recursive
    cv_ext_.wait(lock, [&]() { return ctr_ >= target; });
  }

 private:
  // ----- methods ----- //
  // execute task with available thread
  void WorkerThread(int i);

  // index of the current worker thread, -1 outside of worker threads
  ABSL_CONST_INIT static thread_local int worker_id_;

  // ----- members ----- //
  std::vector<std::thread> threads_;          // worker threads
  std::mutex m_;                              // guards queue_ and ctr_
  std::condition_variable cv_in_;             // signalled when queue_ changes
  std::condition_variable cv_ext_;            // signalled when ctr_ changes
  std::queue<std::function<void()>> queue_;   // pending tasks (FIFO)
  std::uint64_t ctr_;                         // processed-entry counter
};

// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.cc
// Per-thread worker index. Every thread starts with -1; WorkerThread()
// overwrites it with the worker's own index, so -1 identifies a thread that is
// not one of the pool's workers.
ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1;
// ThreadPool constructor
// Starts num_threads workers; worker k runs WorkerThread(k). The counter of
// processed queue entries starts at zero.
ThreadPool::ThreadPool(int num_threads) : ctr_(0) {
  int next_id = 0;
  while (next_id < num_threads) {
    // the thread is constructed in place and receives its own copy of the id
    threads_.emplace_back(&ThreadPool::WorkerThread, this, next_id);
    ++next_id;
  }
}
// ThreadPool destructor
// Queues one empty task per worker as a stop marker, wakes all workers, and
// waits for each of them to exit.
ThreadPool::~ThreadPool() {
  {
    std::unique_lock<std::mutex> guard(m_);
    // one stop marker per worker thread
    for (auto remaining = threads_.size(); remaining > 0; --remaining) {
      queue_.push(nullptr);
    }
    cv_in_.notify_all();
  }
  for (std::thread& worker : threads_) {
    worker.join();
  }
}
// ThreadPool scheduler
// Appends `task` to the back of the queue under the pool mutex and wakes one
// waiting worker.
void ThreadPool::Schedule(std::function<void()> task) {
  std::lock_guard<std::mutex> guard(m_);
  queue_.emplace(std::move(task));
  cv_in_.notify_one();
}
// ThreadPool worker
// Body of worker thread `i`: records `i` as this thread's worker id, then
// repeatedly takes the front queue entry and runs it. An empty entry is the
// stop marker. The counter is incremented (and one waiter on cv_ext_ woken)
// after every entry, including the stop marker.
void ThreadPool::WorkerThread(int i) {
  worker_id_ = i;
  for (;;) {
    std::function<void()> job;
    {
      std::unique_lock<std::mutex> lock(m_);
      // sleep until there is something to take from the queue
      while (queue_.empty()) {
        cv_in_.wait(lock);
      }
      job = std::move(queue_.front());
      queue_.pop();
      // pass the wake-up on to another worker
      cv_in_.notify_one();
    }

    const bool stop_requested = (job == nullptr);
    if (!stop_requested) {
      // run the task without holding the mutex
      job();
    }

    {
      std::unique_lock<std::mutex> lock(m_);
      ++ctr_;
      cv_ext_.notify_one();
    }

    if (stop_requested) {
      return;
    }
  }
}

// NOLINTBEGIN(whitespace/line_length)

const auto rollout_doc = R"(
Expand All @@ -54,7 +158,7 @@ Roll out open-loop trajectories from initial states, get resulting states and se
// C-style rollout function, assumes all arguments are valid
// all input fields of d are initialised, contents at call time do not matter
// after returning, d will contain the last step of the last rollout
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int nstep, unsigned int control_spec,
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
mjtNum* state, mjtNum* sensordata) {
// sizes
Expand All @@ -75,7 +179,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}

// loop over rollouts
for (int r = 0; r < nroll; r++) {
for (int r = start_roll; r < end_roll; r++) {
// clear user inputs if unspecified
if (!(control_spec & mjSTATE_MOCAP_POS)) {
for (int i = 0; i < nbody; i++) {
Expand Down Expand Up @@ -158,6 +262,35 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}
}

// C-style threaded version of _unsafe_rollout
// Splits rollouts [0, nroll) into consecutive jobs of chunk_size rollouts (the
// last job takes the remainder), runs each job through _unsafe_rollout on a
// temporary ThreadPool with nthread workers, and returns once every job has
// finished.
//   m, nstep, control_spec, state0, warmstart0, control, state, sensordata:
//     forwarded unchanged to _unsafe_rollout
//   d: one mjData per worker; each job uses d[ThreadPool::WorkerId()], so d
//     must hold at least nthread entries
//   nthread: number of worker threads; must be positive, otherwise scheduled
//     jobs are never executed
//   chunk_size: rollouts per job; values below 1 are treated as 1
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
                              int nroll, int nstep, unsigned int control_spec,
                              const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
                              mjtNum* state, mjtNum* sensordata,
                              int nthread, int chunk_size) {
  // nothing to roll out: do not spawn any threads
  if (nroll <= 0) {
    return;
  }

  // guard against division by zero below
  const int chunk = chunk_size > 0 ? chunk_size : 1;

  const int nfulljobs = nroll / chunk;
  const int chunk_remainder = nroll % chunk;
  const int njobs = nfulljobs + (chunk_remainder > 0 ? 1 : 0);

  ThreadPool pool(nthread);
  for (int j = 0; j < njobs; j++) {
    const int start_roll = j * chunk;
    // the trailing partial job (if any) stops at nroll
    const int end_roll = (j < nfulljobs) ? start_roll + chunk : nroll;
    // m and d outlive the pool, which is joined before this function returns
    auto task = [=, &m, &d](void) {
      _unsafe_rollout(m, d[ThreadPool::WorkerId()], start_roll, end_roll,
                      nstep, control_spec, state0, warmstart0, control, state, sensordata);
    };
    pool.Schedule(task);
  }

  // the pool counter advances once per finished job, not once per rollout, so
  // wait for the number of jobs that were scheduled
  pool.WaitCount(njobs);
}

// NOLINTEND(whitespace/line_length)

// check size of optional argument to rollout(), return raw pointer
Expand Down Expand Up @@ -190,7 +323,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
// get subsequent states and corresponding sensor values
pymodule.def(
"rollout",
[](py::list m, MjDataWrapper& d,
[](py::list m, py::list d,
int nstep, unsigned int control_spec,
const PyCArray state0,
std::optional<const PyCArray> warmstart0,
Expand All @@ -204,7 +337,12 @@ PYBIND11_MODULE(_rollout, pymodule) {
for (int r = 0; r < nroll; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}
raw::MjData* data = d.get();

int nthread = py::len(d);
std::vector<raw::MjData*> data_ptrs(nthread);
for (int t = 0; t < nthread; t++) {
data_ptrs[t] = d[t].cast<MjDataWrapper*>()->get();
}

// check that some steps need to be taken, return if not
if (nstep < 1) {
Expand All @@ -230,9 +368,18 @@ PYBIND11_MODULE(_rollout, pymodule) {
py::gil_scoped_release no_gil;

// call unsafe rollout function
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
if (nthread > 0) {
int chunk_size = std::max(1, nroll/(10 * nthread));
InterceptMjErrors(_unsafe_rollout_threaded)(
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
nthread, chunk_size);
}
else {
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
}
}
},
py::arg("model"),
Expand Down

0 comments on commit 36f3321

Please sign in to comment.