Skip to content

Commit

Permalink
fix negative collector time (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Mar 26, 2022
1 parent 2a9c928 commit 6ab9860
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_extras_require() -> str:
"dev": [
"sphinx<4",
"sphinx_rtd_theme",
"jinja2<3.1", # temporary fix
"sphinxcontrib-bibtex",
"flake8",
"flake8-bugbear",
Expand Down
36 changes: 19 additions & 17 deletions tianshou/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_episode(

def gather_info(
start_time: float,
train_c: Optional[Collector],
test_c: Optional[Collector],
train_collector: Optional[Collector],
test_collector: Optional[Collector],
best_reward: float,
best_reward_std: float,
) -> Dict[str, Union[float, str]]:
Expand All @@ -57,38 +57,40 @@ def gather_info(
* ``best_reward`` the best reward over the test results;
* ``duration`` the total elapsed time.
"""
duration = time.time() - start_time
duration = max(0, time.time() - start_time)
model_time = duration
result: Dict[str, Union[float, str]] = {
"duration": f"{duration:.2f}s",
"train_time/model": f"{model_time:.2f}s",
}
if test_c is not None:
model_time = duration - test_c.collect_time
test_speed = test_c.collect_step / test_c.collect_time
if test_collector is not None:
model_time = max(0, duration - test_collector.collect_time)
test_speed = test_collector.collect_step / test_collector.collect_time
result.update(
{
"test_step": test_c.collect_step,
"test_episode": test_c.collect_episode,
"test_time": f"{test_c.collect_time:.2f}s",
"test_step": test_collector.collect_step,
"test_episode": test_collector.collect_episode,
"test_time": f"{test_collector.collect_time:.2f}s",
"test_speed": f"{test_speed:.2f} step/s",
"best_reward": best_reward,
"best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
"duration": f"{duration:.2f}s",
"train_time/model": f"{model_time:.2f}s",
}
)
if train_c is not None:
model_time -= train_c.collect_time
if test_c is not None:
train_speed = train_c.collect_step / (duration - test_c.collect_time)
if train_collector is not None:
model_time = max(0, model_time - train_collector.collect_time)
if test_collector is not None:
train_speed = train_collector.collect_step / (
duration - test_collector.collect_time
)
else:
train_speed = train_c.collect_step / duration
train_speed = train_collector.collect_step / duration
result.update(
{
"train_step": train_c.collect_step,
"train_episode": train_c.collect_episode,
"train_time/collector": f"{train_c.collect_time:.2f}s",
"train_step": train_collector.collect_step,
"train_episode": train_collector.collect_episode,
"train_time/collector": f"{train_collector.collect_time:.2f}s",
"train_time/model": f"{model_time:.2f}s",
"train_speed": f"{train_speed:.2f} step/s",
}
Expand Down

0 comments on commit 6ab9860

Please sign in to comment.