[Feature] Warning for init_random_frames rounding in collectors #1616

Merged
merged 2 commits into from
Oct 9, 2023
19 changes: 17 additions & 2 deletions torchrl/collectors/collectors.py
@@ -398,6 +398,7 @@ class SyncDataCollector(DataCollectorBase):
policy is ignored before it is called. This feature is mainly
intended to be used in offline/model-based settings, where a
batch of random trajectories can be used to initialize training.
If provided, it will be rounded up to the closest multiple of frames_per_batch.
Defaults to ``None`` (i.e. no random frames).
reset_at_each_iter (bool, optional): Whether environments should be reset
at the beginning of a batch collection.
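The rounding described in this docstring can be illustrated with the same ceiling-division idiom that the PR uses in its warning message. This is a minimal sketch, not torchrl code: `round_up_to_multiple` is a hypothetical helper name introduced here for illustration only.

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= `value`.

    Hypothetical helper, not part of torchrl. Assumes both arguments
    are positive integers.
    """
    # -(-a // b) is ceiling division for positive integers.
    return -(-value // multiple) * multiple


# 1000 requested random frames with batches of 300 become 4 batches.
print(round_up_to_multiple(1000, 300))  # 1200
# An exact multiple is left unchanged.
print(round_up_to_multiple(900, 300))  # 900
```

With `init_random_frames=1000` and `frames_per_batch=300`, the effective number of random frames under this rule is 1200, which is the value the new warning reports.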
@@ -599,13 +600,26 @@ def __init__(
         self.total_frames = total_frames
         self.reset_at_each_iter = reset_at_each_iter
         self.init_random_frames = init_random_frames
+        if (
+            init_random_frames is not None
+            and init_random_frames % frames_per_batch != 0
+            and RL_WARNINGS
+        ):
+            warnings.warn(
+                f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
+                f"this results in more init_random_frames than requested"
+                f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch}). "
+                "To silence this message, set the environment variable RL_WARNINGS to False."
+            )
+
         self.postproc = postproc
         if self.postproc is not None and hasattr(self.postproc, "to"):
             self.postproc.to(self.storing_device)
         if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
             warnings.warn(
-                f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, "
-                f" this results in more frames_per_batch per iteration that requested."
+                f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
+                f"this results in more frames_per_batch per iteration than requested"
+                f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
                 "To silence this message, set the environment variable RL_WARNINGS to False."
             )
         self.requested_frames_per_batch = frames_per_batch
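The check added in this hunk can be reproduced in isolation to see when the warning fires. The sketch below is a standalone approximation, not the collector's code: `check_init_random_frames` and its `warn` flag are hypothetical names, and the `warn` argument stands in for the module-level `RL_WARNINGS` flag, whose actual parsing in torchrl is not shown in this diff.

```python
import warnings
from typing import Optional


def check_init_random_frames(
    init_random_frames: Optional[int],
    frames_per_batch: int,
    warn: bool = True,  # stand-in for RL_WARNINGS (assumption)
) -> Optional[int]:
    """Return the effective number of random frames, warning on a mismatch.

    Hypothetical helper that mirrors the condition in the diff above.
    """
    if init_random_frames is None:
        return None
    # Ceiling division: round up to the next multiple of frames_per_batch.
    effective = -(-init_random_frames // frames_per_batch) * frames_per_batch
    if init_random_frames % frames_per_batch != 0 and warn:
        warnings.warn(
            f"init_random_frames ({init_random_frames}) is not exactly a "
            f"multiple of frames_per_batch ({frames_per_batch}), this results "
            f"in more init_random_frames than requested ({effective})."
        )
    return effective


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_init_random_frames(1000, 300)  # not a multiple: warns
    check_init_random_frames(900, 300)   # exact multiple: silent
    print(len(caught))  # 1
```

Returning the effective value is a choice made for this sketch so the result can be inspected; the code in the diff only emits the warning and leaves `self.init_random_frames` unchanged.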
@@ -1026,6 +1040,7 @@ class _MultiDataCollector(DataCollectorBase):
policy is ignored before it is called. This feature is mainly
intended to be used in offline/model-based settings, where a
batch of random trajectories can be used to initialize training.
If provided, it will be rounded up to the closest multiple of frames_per_batch.
Defaults to ``None`` (i.e. no random frames).
reset_at_each_iter (bool, optional): Whether environments should be reset
at the beginning of a batch collection.
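One way to see why the value is rounded up rather than honored exactly is to simulate a collection loop in which the choice between the random policy and the trained policy is made once per batch. This is a simplified model under that stated assumption, not the collectors' actual loop; `simulate_random_phase` is a hypothetical function written for illustration.

```python
def simulate_random_phase(
    init_random_frames: int, frames_per_batch: int, total_frames: int
) -> int:
    """Count frames gathered with the random policy.

    Assumption: the random-vs-policy decision is taken once per batch,
    based on the number of frames collected before that batch starts.
    """
    frames = 0
    random_frames = 0
    while frames < total_frames:
        # A batch that starts below the threshold is entirely random,
        # even if it ends past init_random_frames.
        if frames < init_random_frames:
            random_frames += frames_per_batch
        frames += frames_per_batch
    return random_frames


print(simulate_random_phase(1000, 300, 3000))  # 1200
```

Under this model the fourth batch starts at frame 900, which is still below the requested 1000, so all 300 of its frames are random and the total reaches 1200.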