77
88import re
99import warnings
10+ from collections .abc import MutableSequence
1011
1112from textwrap import indent
12- from typing import Any , Dict , List , Optional
13+ from typing import Any , Dict , List , Optional , OrderedDict , overload
1314
1415import torch
1516
@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621622 log(p(z | x, y))
622623
623624 Args:
624- *modules (sequence of TensorDictModules ): An ordered sequence of
625- :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625+ *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule ): An ordered sequence of
626+ :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626627 to be run sequentially.
628+ The modules can be instances of TensorDictModuleBase or any other callable with the same call signature as a TensorDictModuleBase.
629+ Note that if a callable that is not a TensorDictModuleBase is used, its input and output keys will not be tracked,
630+ and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627631
628632 Keyword Args:
629633 partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -791,6 +795,28 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
791795
792796 """
793797
798+ @overload
799+ def __init__ (
800+ self ,
801+ modules : OrderedDict [str , TensorDictModuleBase | ProbabilisticTensorDictModule ],
802+ partial_tolerant : bool = False ,
803+ return_composite : bool | None = None ,
804+ aggregate_probabilities : bool | None = None ,
805+ include_sum : bool | None = None ,
806+ inplace : bool | None = None ,
807+ ) -> None : ...
808+
809+ @overload
810+ def __init__ (
811+ self ,
812+ modules : List [TensorDictModuleBase | ProbabilisticTensorDictModule ],
813+ partial_tolerant : bool = False ,
814+ return_composite : bool | None = None ,
815+ aggregate_probabilities : bool | None = None ,
816+ include_sum : bool | None = None ,
817+ inplace : bool | None = None ,
818+ ) -> None : ...
819+
794820 def __init__ (
795821 self ,
796822 * modules : TensorDictModuleBase | ProbabilisticTensorDictModule ,
@@ -805,7 +831,14 @@ def __init__(
805831 "ProbabilisticTensorDictSequential must consist of zero or more "
806832 "TensorDictModules followed by a ProbabilisticTensorDictModule"
807833 )
808- if not return_composite and not isinstance (
834+ self ._ordered_dict = False
835+ if len (modules ) == 1 and isinstance (modules [0 ], (OrderedDict , MutableSequence )):
836+ if isinstance (modules [0 ], OrderedDict ):
837+ modules_list = list (modules [0 ].values ())
838+ self ._ordered_dict = True
839+ else :
840+ modules = modules_list = list (modules [0 ])
841+ elif not return_composite and not isinstance (
809842 modules [- 1 ],
810843 (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential ),
811844 ):
@@ -814,13 +847,22 @@ def __init__(
814847 "an instance of ProbabilisticTensorDictModule or another "
815848 "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
816849 )
850+ else :
851+ modules_list = list (modules )
852+
817853 # if the modules not including the final probabilistic module return the sampled
818- # key we wont be sampling it again, in that case
854+ # key we won't be sampling it again, in that case
819855 # ProbabilisticTensorDictSequential is presumably used to return the
820856 # distribution using `get_dist` or to sample log_probabilities
821- _ , out_keys = self ._compute_in_and_out_keys (modules [:- 1 ])
822- self ._requires_sample = modules [- 1 ].out_keys [0 ] not in set (out_keys )
823- self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
857+ _ , out_keys = self ._compute_in_and_out_keys (modules_list [:- 1 ])
858+ self ._requires_sample = modules_list [- 1 ].out_keys [0 ] not in set (out_keys )
859+ if self ._ordered_dict :
860+ self .__dict__ ["_det_part" ] = TensorDictSequential (
861+ OrderedDict (list (modules [0 ].items ())[:- 1 ])
862+ )
863+ else :
864+ self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
865+
824866 super ().__init__ (* modules , partial_tolerant = partial_tolerant )
825867 self .return_composite = return_composite
826868 self .aggregate_probabilities = aggregate_probabilities
@@ -861,7 +903,7 @@ def get_dist_params(
861903 tds = self .det_part
862904 type = interaction_type ()
863905 if type is None :
864- for m in reversed (self .module ):
906+ for m in reversed (list ( self ._module_iter ()) ):
865907 if hasattr (m , "default_interaction_type" ):
866908 type = m .default_interaction_type
867909 break
@@ -873,7 +915,7 @@ def get_dist_params(
873915 @property
874916 def num_samples (self ):
875917 num_samples = ()
876- for tdm in self .module :
918+ for tdm in self ._module_iter () :
877919 if isinstance (
878920 tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
879921 ):
@@ -917,7 +959,7 @@ def get_dist(
917959
918960 td_copy = tensordict .copy ()
919961 dists = {}
920- for i , tdm in enumerate (self .module ):
962+ for i , tdm in enumerate (self ._module_iter () ):
921963 if isinstance (
922964 tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
923965 ):
@@ -957,12 +999,21 @@ def default_interaction_type(self):
957999 encountered is returned. If no such value is found, a default `interaction_type()` is returned.
9581000
9591001 """
960- for m in reversed (self .module ):
1002+ for m in reversed (list ( self ._module_iter ()) ):
9611003 interaction = getattr (m , "default_interaction_type" , None )
9621004 if interaction is not None :
9631005 return interaction
9641006 return interaction_type ()
9651007
1008+ @property
1009+ def _last_module (self ):
1010+ if not self ._ordered_dict :
1011+ return self .module [- 1 ]
1012+ mod = None
1013+ for mod in self ._module_iter (): # noqa: B007
1014+ continue
1015+ return mod
1016+
9661017 def log_prob (
9671018 self ,
9681019 tensordict ,
@@ -1079,7 +1130,7 @@ def log_prob(
10791130 include_sum = include_sum ,
10801131 ** kwargs ,
10811132 )
1082- last_module : ProbabilisticTensorDictModule = self .module [ - 1 ]
1133+ last_module : ProbabilisticTensorDictModule = self ._last_module
10831134 out = last_module .log_prob (tensordict_inp , dist = dist , ** kwargs )
10841135 if is_tensor_collection (out ):
10851136 if tensordict_out is not None :
@@ -1138,7 +1189,7 @@ def forward(
11381189 else :
11391190 tensordict_exec = tensordict
11401191 if self .return_composite :
1141- for m in self .module :
1192+ for m in self ._module_iter () :
11421193 if isinstance (
11431194 m , (ProbabilisticTensorDictModule , ProbabilisticTensorDictModule )
11441195 ):
@@ -1149,7 +1200,7 @@ def forward(
11491200 tensordict_exec = m (tensordict_exec , ** kwargs )
11501201 else :
11511202 tensordict_exec = self .get_dist_params (tensordict_exec , ** kwargs )
1152- tensordict_exec = self .module [ - 1 ] (
1203+ tensordict_exec = self ._last_module (
11531204 tensordict_exec , _requires_sample = self ._requires_sample
11541205 )
11551206 if tensordict_out is not None :
0 commit comments