Skip to content

Commit

Permalink
rollout better benchmarking and use skip_checks with native threading
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Dec 5, 2024
1 parent 166640c commit 1fa1727
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions python/mujoco/rollout_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import concurrent.futures
import threading
import time
import timeit

import mujoco
Expand Down Expand Up @@ -81,12 +82,33 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'

print('running benchmark')
for nroll_i in nroll:
print('roll', nroll_i)
for nstep_i in nstep:
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10)
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10)
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i)
number = max(20, number)

# Generate plots
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number)
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number)
nt_res /= number
pt_res /= number

# times = [time.time()]
# for i in range(number):
# rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i)
# times.append(time.time())
# dt = np.diff(times)
# nt_res = np.mean(dt)

# times = [time.time()]
# for i in range(number):
# pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i)
# times.append(time.time())
# dt = np.diff(times)
# pt_res = np.mean(dt)

print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number))

# TODO generate plots

if __name__ == '__main__':
benchmark_rollout()

0 comments on commit 1fa1727

Please sign in to comment.