8585 ValueModelConfig ,
8686)
8787from torchrl .trainers .algorithms .configs .objectives import (
88+ GAEConfig ,
8889 HardUpdateConfig ,
8990 LossConfig ,
9091 PPOLossConfig ,
126127 InitTrackerConfig ,
127128 KLRewardTransformConfig ,
128129 LineariseRewardsConfig ,
130+ ModuleTransformConfig ,
129131 MultiActionConfig ,
130132 MultiStepTransformConfig ,
131133 NoopResetEnvConfig ,
179181 SGDConfig ,
180182 SparseAdamConfig ,
181183)
184+ from torchrl .trainers .algorithms .configs .weight_sync_schemes import (
185+ DistributedWeightSyncSchemeConfig ,
186+ MultiProcessWeightSyncSchemeConfig ,
187+ NoWeightSyncSchemeConfig ,
188+ RayModuleTransformSchemeConfig ,
189+ RayWeightSyncSchemeConfig ,
190+ RPCWeightSyncSchemeConfig ,
191+ SharedMemWeightSyncSchemeConfig ,
192+ VLLMDoubleBufferSyncSchemeConfig ,
193+ VLLMWeightSyncSchemeConfig ,
194+ WeightSyncSchemeConfig ,
195+ )
182196from torchrl .trainers .algorithms .configs .weight_update import (
183197 DistributedWeightUpdaterConfig ,
184198 MultiProcessedWeightUpdaterConfig ,
273287 "InitTrackerConfig" ,
274288 "KLRewardTransformConfig" ,
275289 "LineariseRewardsConfig" ,
290+ "ModuleTransformConfig" ,
276291 "MultiActionConfig" ,
277292 "MultiStepTransformConfig" ,
278293 "NoopResetEnvConfig" ,
330345 "LossConfig" ,
331346 "PPOLossConfig" ,
332347 "SACLossConfig" ,
348+ # Value functions
349+ "GAEConfig" ,
333350 # Trainers
334351 "PPOTrainerConfig" ,
335352 "SACTrainerConfig" ,
348365 "RPCWeightUpdaterConfig" ,
349366 "DistributedWeightUpdaterConfig" ,
350367 "vLLMUpdaterConfig" ,
368+ # Weight Sync Schemes
369+ "WeightSyncSchemeConfig" ,
370+ "MultiProcessWeightSyncSchemeConfig" ,
371+ "SharedMemWeightSyncSchemeConfig" ,
372+ "NoWeightSyncSchemeConfig" ,
373+ "RayWeightSyncSchemeConfig" ,
374+ "RayModuleTransformSchemeConfig" ,
375+ "RPCWeightSyncSchemeConfig" ,
376+ "DistributedWeightSyncSchemeConfig" ,
377+ "VLLMWeightSyncSchemeConfig" ,
378+ "VLLMDoubleBufferSyncSchemeConfig" ,
351379]
352380
353381
@@ -356,6 +384,10 @@ def _register_configs():
356384
357385 Registration is wrapped in a function to avoid GlobalHydra initialization issues
358386 during testing. This module calls it at import time on Python 3.10+; call it explicitly otherwise.
387+
388+ To add a new config:
389+ - Write the config class in the appropriate file (e.g. torchrl/trainers/algorithms/configs/transforms.py), then import it in torchrl/trainers/algorithms/configs/__init__.py and add its name to the __all__ list there
390+ - Register the config in the appropriate group inside this function, e.g. cs.store(group="transform", name="new_transform", node=NewTransformConfig)
359391 """
360392 cs = ConfigStore .instance ()
361393
@@ -461,6 +493,7 @@ def _register_configs():
461493 cs .store (group = "transform" , name = "action_discretizer" , node = ActionDiscretizerConfig )
462494 cs .store (group = "transform" , name = "traj_counter" , node = TrajCounterConfig )
463495 cs .store (group = "transform" , name = "linearise_rewards" , node = LineariseRewardsConfig )
496+ cs .store (group = "transform" , name = "module" , node = ModuleTransformConfig )
464497 cs .store (group = "transform" , name = "conditional_skip" , node = ConditionalSkipConfig )
465498 cs .store (group = "transform" , name = "multi_action" , node = MultiActionConfig )
466499 cs .store (group = "transform" , name = "timer" , node = TimerConfig )
@@ -487,6 +520,7 @@ def _register_configs():
# Last of the transform registrations. The "module" entry (ModuleTransformConfig)
# is stored once, right after "linearise_rewards"; storing the same group/name
# again here would only overwrite that entry with the identical node.
cs.store(group="transform", name="vip", node=VIPTransformConfig)
cs.store(group="transform", name="vip_reward", node=VIPRewardTransformConfig)
cs.store(group="transform", name="vec_norm_v2", node=VecNormV2Config)
490524
491525 # =============================================================================
492526 # Loss Configurations
@@ -496,6 +530,16 @@ def _register_configs():
496530 cs .store (group = "loss" , name = "ppo" , node = PPOLossConfig )
497531 cs .store (group = "loss" , name = "sac" , node = SACLossConfig )
498532
533+ # =============================================================================
534+ # Value Function Configurations
535+ # =============================================================================
536+
537+ cs .store (group = "value" , name = "gae" , node = GAEConfig )
538+
539+ # =============================================================================
540+ # Target Net Updater Configurations
541+ # =============================================================================
542+
499543 cs .store (group = "target_net_updater" , name = "soft" , node = SoftUpdateConfig )
500544 cs .store (group = "target_net_updater" , name = "hard" , node = HardUpdateConfig )
501545
@@ -595,6 +639,41 @@ def _register_configs():
595639 )
596640 cs .store (group = "weight_updater" , name = "vllm" , node = vLLMUpdaterConfig )
597641
642+ # =============================================================================
643+ # Weight Sync Scheme Configurations
644+ # =============================================================================
645+
646+ cs .store (group = "weight_sync_scheme" , name = "base" , node = WeightSyncSchemeConfig )
647+ cs .store (
648+ group = "weight_sync_scheme" ,
649+ name = "multiprocess" ,
650+ node = MultiProcessWeightSyncSchemeConfig ,
651+ )
652+ cs .store (
653+ group = "weight_sync_scheme" ,
654+ name = "shared_mem" ,
655+ node = SharedMemWeightSyncSchemeConfig ,
656+ )
657+ cs .store (group = "weight_sync_scheme" , name = "no_sync" , node = NoWeightSyncSchemeConfig )
658+ cs .store (group = "weight_sync_scheme" , name = "ray" , node = RayWeightSyncSchemeConfig )
659+ cs .store (
660+ group = "weight_sync_scheme" ,
661+ name = "ray_module_transform" ,
662+ node = RayModuleTransformSchemeConfig ,
663+ )
664+ cs .store (group = "weight_sync_scheme" , name = "rpc" , node = RPCWeightSyncSchemeConfig )
665+ cs .store (
666+ group = "weight_sync_scheme" ,
667+ name = "distributed" ,
668+ node = DistributedWeightSyncSchemeConfig ,
669+ )
670+ cs .store (group = "weight_sync_scheme" , name = "vllm" , node = VLLMWeightSyncSchemeConfig )
671+ cs .store (
672+ group = "weight_sync_scheme" ,
673+ name = "vllm_double_buffer" ,
674+ node = VLLMDoubleBufferSyncSchemeConfig ,
675+ )
676+
598677
# Populate Hydra's ConfigStore as soon as this module is imported, but only on Python 3.10 or newer.
# NOTE(review): the `not ... <` spelling with `# noqa` looks deliberate (presumably to stop
# linters/upgraders from rewriting the version check) — confirm before simplifying.
599678if not sys .version_info < (3 , 10 ): # type: ignore # noqa
600679 _register_configs ()
0 commit comments