Skip to content

Commit

Permalink
rollout reuse native threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Dec 5, 2024
1 parent 63a6e82 commit 166640c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
27 changes: 19 additions & 8 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
}

// C-style threaded version of _unsafe_rollout
static ThreadPool* pool = nullptr;
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
int nroll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
Expand All @@ -272,24 +273,34 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
int njobs = nfulljobs;
if (chunk_remainder > 0) njobs++;

ThreadPool pool = ThreadPool(nthread);
if (pool == nullptr) {
pool = new ThreadPool(nthread);
}
else if (pool->NumThreads() != nthread) {
delete pool; // TODO make sure pool is shutdown correctly
pool = new ThreadPool(nthread);
} else {
pool->ResetCount();
}

for (int j = 0; j < nfulljobs; j++) {
auto task = [=, &m, &d, &pool](void) {
_unsafe_rollout(m, d[pool.WorkerId()], j*chunk_size, (j+1)*chunk_size,
auto task = [=, &m, &d](void) {
int id = pool->WorkerId();
_unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool.Schedule(task);
pool->Schedule(task);
}

if (chunk_remainder > 0) {
auto task = [=, &m, &d, &pool](void) {
_unsafe_rollout(m, d[pool.WorkerId()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
auto task = [=, &m, &d](void) {
_unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool.Schedule(task);
pool->Schedule(task);
}

pool.WaitCount(njobs);
pool->WaitCount(njobs);
}

// NOLINTEND(whitespace/line_length)
Expand Down
4 changes: 2 additions & 2 deletions python/mujoco/rollout_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def run(self, model_list, initial_state, nstep):

def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'):
nthread = 24
nroll = [int(1e0), int(1e1), int(1e2)]
nroll = [int(1e0), int(1e1), int(1e2), int(2e2)]
nstep = [int(1e0), int(1e1), int(1e2), int(2e2)]

print('making structures')
Expand All @@ -84,7 +84,7 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'
for nstep_i in nstep:
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10)
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10)
print('{:03d} {:03d} {:0.3f} {:0.3f} {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))

# Generate plots

Expand Down

0 comments on commit 166640c

Please sign in to comment.