diff --git a/python/mujoco/rollout_benchmark.py b/python/mujoco/rollout_benchmark.py index 92de631a27..74ccac1e70 100644 --- a/python/mujoco/rollout_benchmark.py +++ b/python/mujoco/rollout_benchmark.py @@ -16,6 +16,7 @@ import concurrent.futures import threading +import time import timeit import mujoco @@ -81,12 +82,33 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml' print('running benchmark') for nroll_i in nroll: + print('roll', nroll_i) for nstep_i in nstep: - nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10) - pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10) - print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res)) + number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i) + number = max(20, number) - # Generate plots + nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number) + pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number) + nt_res /= number + pt_res /= number + + # times = [time.time()] + # for i in range(number): + # rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i) + # times.append(time.time()) + # dt = np.diff(times) + # nt_res = np.mean(dt) + + # times = [time.time()] + # for i in range(number): + # pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i) + # times.append(time.time()) + # dt = np.diff(times) + # pt_res = np.mean(dt) + + print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number)) + + # TODO generate plots if __name__ == '__main__': benchmark_rollout()