Using the `TensorDict` class simplifies passing data across processes and designing general classes that are oblivious to the keys used by a specific algorithm (e.g. whether or not an `action_log_prob` / `hidden_state` key should be expected).
However, introducing new classes can prevent users from copy-pasting and re-using modules (see here). We should make sure that `TensorDict` is used only when absolutely necessary. These cases include situations where all the contents of a dictionary are treated in the same way:
- indexing
- reshaping
- sending from worker to worker, device to device
- concatenation / stacking
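The operations listed above share one property: the same transformation is applied to every entry. As a rough illustration (a sketch, not `TensorDict`'s actual implementation), this is what that bookkeeping looks like with a plain Python `dict` of PyTorch tensors. The helper names `index_all`, `reshape_batch`, `to_device` and `stack_all` are hypothetical, and every entry is assumed to share the same leading batch dimensions.

```python
from typing import Dict, List, Sequence

import torch

TensorMap = Dict[str, torch.Tensor]


def index_all(batch: TensorMap, idx) -> TensorMap:
    # Apply the same index to every entry (all entries share the leading batch dims).
    return {key: value[idx] for key, value in batch.items()}


def reshape_batch(batch: TensorMap, num_batch_dims: int, new_batch_shape: Sequence[int]) -> TensorMap:
    # Reshape only the leading batch dims; trailing feature dims are kept as-is.
    return {
        key: value.reshape(*new_batch_shape, *value.shape[num_batch_dims:])
        for key, value in batch.items()
    }


def to_device(batch: TensorMap, device) -> TensorMap:
    # Move every entry to the same device.
    return {key: value.to(device) for key, value in batch.items()}


def stack_all(batches: List[TensorMap], dim: int = 0) -> TensorMap:
    # Stack matching keys across several dictionaries along a new dimension.
    keys = batches[0].keys()
    return {key: torch.stack([b[key] for b in batches], dim=dim) for key in keys}


# Usage: entries laid out as [batch, time, features].
batch = {
    "observation": torch.randn(4, 10, 3),
    "reward": torch.randn(4, 10, 1),
}
flat = reshape_batch(batch, num_batch_dims=2, new_batch_shape=(40,))  # "observation" -> (40, 3)
first = index_all(batch, 0)                                           # "observation" -> (10, 3)
stacked = stack_all([batch, batch], dim=0)                            # "reward" -> (2, 4, 10, 1)
on_cpu = to_device(batch, "cpu")
```

Each helper is a one-line comprehension over the whole dictionary, which is the kind of uniform treatment that justifies a dedicated container in high-level classes.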
In general, `TensorDict` should be used for high-level classes: `Agent`, `DataCollector`, and possibly probabilistic operator modules.
Objectives should not require TensorDicts in general.
However, in some cases they may need to check the trajectory length (1st dimension), the batch size (0th dimension), or even the device. One option in those cases would be to infer these from a specific tensor in the dictionary (e.g. `reward`?).
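A minimal sketch of that option, assuming entries are laid out as `[batch, time, ...]` (batch size in dimension 0, trajectory length in dimension 1, as described above) and that a `"reward"` key is always present. `infer_batch_info` is a hypothetical helper, not an existing torchrl function; the consistency check over the other entries is optional.

```python
from typing import Mapping, Tuple

import torch


def infer_batch_info(
    batch: Mapping[str, torch.Tensor], ref_key: str = "reward"
) -> Tuple[int, int, torch.device]:
    # Hypothetical helper: read batch size (dim 0), trajectory length (dim 1)
    # and device from one reference tensor instead of requiring a TensorDict.
    ref = batch[ref_key]
    if ref.dim() < 2:
        raise ValueError(
            f"expected '{ref_key}' to have at least 2 dims [batch, time, ...], "
            f"got shape {tuple(ref.shape)}"
        )
    # Optional sanity check: every entry should share the same leading dims.
    for key, value in batch.items():
        if value.shape[:2] != ref.shape[:2]:
            raise ValueError(
                f"entry '{key}' has leading dims {tuple(value.shape[:2])}, "
                f"expected {tuple(ref.shape[:2])}"
            )
    return ref.shape[0], ref.shape[1], ref.device


# Usage with a plain dict.
batch = {"observation": torch.randn(4, 10, 3), "reward": torch.zeros(4, 10, 1)}
batch_size, traj_len, device = infer_batch_info(batch)  # 4, 10, cpu device
```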
Plan
- Test and fix modules so that they all accept a dictionary as input.
- Update the type hints accordingly.
- Currently, modules return a `TensorDict`, but they could just as well return a regular `dict`.
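A sketch of what the plan could look like for a single objective, under stated assumptions: `DictValueLoss`, its input keys (`"observation"`, `"value_target"`) and the output key `"loss_value"` are hypothetical and not part of torchrl. The module only performs key lookups on its input, so a plain `dict` works, and it returns a plain `dict` of losses. The `Mapping[str, torch.Tensor]` annotation is one possible way to express the typing change; a `TensorDict` should also work at runtime since only string-key lookup is used.

```python
from typing import Dict, Mapping

import torch
import torch.nn.functional as F
from torch import nn


class DictValueLoss(nn.Module):
    """Hypothetical objective: reads inputs from any string-keyed mapping
    (e.g. a plain dict) and returns its losses as a plain dict."""

    def __init__(
        self,
        value_network: nn.Module,
        observation_key: str = "observation",
        target_key: str = "value_target",
    ) -> None:
        super().__init__()
        self.value_network = value_network
        self.observation_key = observation_key
        self.target_key = target_key

    def forward(self, batch: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Only key lookup is used, so no TensorDict-specific method is required.
        prediction = self.value_network(batch[self.observation_key])
        loss = F.mse_loss(prediction, batch[self.target_key])
        return {"loss_value": loss}


# Usage with a plain dict input and a plain dict output.
loss_module = DictValueLoss(nn.Linear(3, 1))
batch = {"observation": torch.randn(8, 3), "value_target": torch.randn(8, 1)}
losses = loss_module(batch)
losses["loss_value"].backward()
```

Keeping the key names as constructor arguments preserves the "oblivious to the keys" property without tying the objective to a particular container class.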
Benjamin-eecs changed the title from "Make objective modules compatible with dictionaries" to "[Feature Request] Make objective modules compatible with dictionaries" on Jul 21, 2022.