Skip to content

Commit

Permalink
Fix tensorboad video slow numpy->torch conversion (#1910)
Browse files Browse the repository at this point in the history
* fixed tb video docs

* updated changelog

* add comment on expected render() output

* Update changelog.rst

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
NickLucche and araffin authored Apr 26, 2024
1 parent e931750 commit 35eccaf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
6 changes: 5 additions & 1 deletion docs/guide/tensorboard.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te
import gymnasium as gym
import torch as th
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
Expand Down Expand Up @@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te
:param _locals: A dictionary containing all local variables of the callback's scope
:param _globals: A dictionary containing all global variables of the callback's scope
"""
# We expect `render()` to return a uint8 array with values in [0, 255] or a float array
# with values in [0, 1], as described in
# https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
screen = self._eval_env.render(mode="rgb_array")
# PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
screens.append(screen.transpose(2, 0, 1))
Expand All @@ -239,7 +243,7 @@ Here is an example of how to render an episode and log the resulting video to Te
)
self.logger.record(
"trajectory/video",
Video(th.ByteTensor([screens]), fps=40),
Video(th.from_numpy(np.asarray([screens])), fps=40),
exclude=("stdout", "log", "json", "csv"),
)
return True
Expand Down
7 changes: 3 additions & 4 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Bug Fixes:

Documentation:
^^^^^^^^^^^^^^
- Added ER-MRL to the project page

- Added ER-MRL to the project page (@corentinlger)
- Updated Tensorboard Logging Videos documentation (@NickLucche)

Release 2.3.1 (2024-04-22)
--------------------------
Expand All @@ -50,7 +50,6 @@ Documentation:
- Updated SBX documentation (CrossQ and deprecated DroQ)
- Updated RL Tips and Tricks section


Release 2.3.0 (2024-03-31)
--------------------------

Expand Down Expand Up @@ -1641,4 +1640,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche

0 comments on commit 35eccaf

Please sign in to comment.