From 3f4a1b859bafc4d4b120d24b60eb4fa62322a4c7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 9 Oct 2023 10:23:57 +0100 Subject: [PATCH 1/2] update Signed-off-by: Matteo Bettini --- torchrl/collectors/collectors.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0d5443b22b4..d66c5e3d7b1 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -398,6 +398,7 @@ class SyncDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. + If provided, it will be rounded to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. @@ -599,13 +600,26 @@ def __init__( self.total_frames = total_frames self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = init_random_frames + if ( + init_random_frames is not None + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.postproc = postproc if self.postproc is not None and hasattr(self.postproc, "to"): self.postproc.to(self.storing_device) if frames_per_batch % self.n_env != 0 and RL_WARNINGS: warnings.warn( - f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " - f" this results in more frames_per_batch per iteration that requested." + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env})." "To silence this message, set the environment variable RL_WARNINGS to False." ) self.requested_frames_per_batch = frames_per_batch @@ -1026,6 +1040,7 @@ class _MultiDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. + If provided, it will be rounded to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. From a569bc3eb3dc67104233e6ff48fb6a7bcd7069b5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 9 Oct 2023 10:27:54 +0100 Subject: [PATCH 2/2] typo Signed-off-by: Matteo Bettini --- torchrl/collectors/collectors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d66c5e3d7b1..afd8ae61765 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -398,7 +398,7 @@ class SyncDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - If provided, it will be rounded to the closest multiple of frames_per_batch. + If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. @@ -1040,7 +1040,7 @@ class _MultiDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - If provided, it will be rounded to the closest multiple of frames_per_batch. + If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection.