
Commit 6bc201a

[Feature] Trainer Algorithms - Configuration System
ghstack-source-id: 5aa0e7e
Pull-Request: #3189
1 parent dc21523

8 files changed: +401 -1 lines changed

torchrl/trainers/algorithms/configs/__init__.py

Lines changed: 79 additions & 0 deletions
@@ -85,6 +85,7 @@
     ValueModelConfig,
 )
 from torchrl.trainers.algorithms.configs.objectives import (
+    GAEConfig,
     HardUpdateConfig,
     LossConfig,
     PPOLossConfig,
@@ -126,6 +127,7 @@
     InitTrackerConfig,
     KLRewardTransformConfig,
     LineariseRewardsConfig,
+    ModuleTransformConfig,
     MultiActionConfig,
     MultiStepTransformConfig,
     NoopResetEnvConfig,
@@ -179,6 +181,18 @@
     SGDConfig,
     SparseAdamConfig,
 )
+from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
+    DistributedWeightSyncSchemeConfig,
+    MultiProcessWeightSyncSchemeConfig,
+    NoWeightSyncSchemeConfig,
+    RayModuleTransformSchemeConfig,
+    RayWeightSyncSchemeConfig,
+    RPCWeightSyncSchemeConfig,
+    SharedMemWeightSyncSchemeConfig,
+    VLLMDoubleBufferSyncSchemeConfig,
+    VLLMWeightSyncSchemeConfig,
+    WeightSyncSchemeConfig,
+)
 from torchrl.trainers.algorithms.configs.weight_update import (
     DistributedWeightUpdaterConfig,
     MultiProcessedWeightUpdaterConfig,
@@ -273,6 +287,7 @@
     "InitTrackerConfig",
     "KLRewardTransformConfig",
     "LineariseRewardsConfig",
+    "ModuleTransformConfig",
     "MultiActionConfig",
     "MultiStepTransformConfig",
     "NoopResetEnvConfig",
@@ -330,6 +345,8 @@
     "LossConfig",
     "PPOLossConfig",
     "SACLossConfig",
+    # Value functions
+    "GAEConfig",
     # Trainers
     "PPOTrainerConfig",
     "SACTrainerConfig",
@@ -348,6 +365,17 @@
     "RPCWeightUpdaterConfig",
     "DistributedWeightUpdaterConfig",
     "vLLMUpdaterConfig",
+    # Weight Sync Schemes
+    "WeightSyncSchemeConfig",
+    "MultiProcessWeightSyncSchemeConfig",
+    "SharedMemWeightSyncSchemeConfig",
+    "NoWeightSyncSchemeConfig",
+    "RayWeightSyncSchemeConfig",
+    "RayModuleTransformSchemeConfig",
+    "RPCWeightSyncSchemeConfig",
+    "DistributedWeightSyncSchemeConfig",
+    "VLLMWeightSyncSchemeConfig",
+    "VLLMDoubleBufferSyncSchemeConfig",
 ]
 
 
@@ -356,6 +384,10 @@ def _register_configs():
 
     This function is called lazily to avoid GlobalHydra initialization issues
     during testing. It should be called explicitly when needed.
+
+    To add a new config:
+    - Write the config class in the appropriate file (e.g. torchrl/trainers/algorithms/configs/transforms.py) and add it to the __all__ list in torchrl/trainers/algorithms/configs/__init__.py
+    - Register the config in the appropriate group, e.g. cs.store(group="transform", name="new_transform", node=NewTransformConfig)
     """
     cs = ConfigStore.instance()
 
@@ -461,6 +493,7 @@ def _register_configs():
     cs.store(group="transform", name="action_discretizer", node=ActionDiscretizerConfig)
     cs.store(group="transform", name="traj_counter", node=TrajCounterConfig)
     cs.store(group="transform", name="linearise_rewards", node=LineariseRewardsConfig)
+    cs.store(group="transform", name="module", node=ModuleTransformConfig)
     cs.store(group="transform", name="conditional_skip", node=ConditionalSkipConfig)
     cs.store(group="transform", name="multi_action", node=MultiActionConfig)
     cs.store(group="transform", name="timer", node=TimerConfig)
@@ -487,6 +520,7 @@ def _register_configs():
     cs.store(group="transform", name="vip", node=VIPTransformConfig)
     cs.store(group="transform", name="vip_reward", node=VIPRewardTransformConfig)
     cs.store(group="transform", name="vec_norm_v2", node=VecNormV2Config)
+    cs.store(group="transform", name="module", node=ModuleTransformConfig)
 
     # =============================================================================
     # Loss Configurations
@@ -496,6 +530,16 @@ def _register_configs():
     cs.store(group="loss", name="ppo", node=PPOLossConfig)
     cs.store(group="loss", name="sac", node=SACLossConfig)
 
+    # =============================================================================
+    # Value Function Configurations
+    # =============================================================================
+
+    cs.store(group="value", name="gae", node=GAEConfig)
+
+    # =============================================================================
+    # Target Net Updater Configurations
+    # =============================================================================
+
     cs.store(group="target_net_updater", name="soft", node=SoftUpdateConfig)
     cs.store(group="target_net_updater", name="hard", node=HardUpdateConfig)
 
@@ -595,6 +639,41 @@ def _register_configs():
     )
     cs.store(group="weight_updater", name="vllm", node=vLLMUpdaterConfig)
 
+    # =============================================================================
+    # Weight Sync Scheme Configurations
+    # =============================================================================
+
+    cs.store(group="weight_sync_scheme", name="base", node=WeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="multiprocess",
+        node=MultiProcessWeightSyncSchemeConfig,
+    )
+    cs.store(
+        group="weight_sync_scheme",
+        name="shared_mem",
+        node=SharedMemWeightSyncSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="no_sync", node=NoWeightSyncSchemeConfig)
+    cs.store(group="weight_sync_scheme", name="ray", node=RayWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="ray_module_transform",
+        node=RayModuleTransformSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="rpc", node=RPCWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="distributed",
+        node=DistributedWeightSyncSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="vllm", node=VLLMWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="vllm_double_buffer",
+        node=VLLMDoubleBufferSyncSchemeConfig,
+    )
+
 
 if not sys.version_info < (3, 10):  # type: ignore # noqa
     _register_configs()
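
Once _register_configs() has run, the new "value" and "weight_sync_scheme" groups are addressable from a Hydra defaults list like any other group in this file. Below is a minimal consumption sketch, not part of this commit: the conf/ layout and entry point are hypothetical, and only the group/name pairs come from the registrations above.

# Hypothetical conf/config.yaml:
#   defaults:
#     - value: gae                      # resolves to GAEConfig
#     - weight_sync_scheme: shared_mem  # resolves to SharedMemWeightSyncSchemeConfig
#     - _self_
#   value:
#     gamma: 0.99
#     lmbda: 0.95

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Each stored node carries a _target_ string; printing the composed
    # tree shows which concrete classes hydra.utils.instantiate would build.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()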

torchrl/trainers/algorithms/configs/collectors.py

Lines changed: 8 additions & 0 deletions
@@ -51,7 +51,9 @@ class SyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.SyncDataCollector"
     _partial_: bool = False
 
@@ -94,7 +96,9 @@ class AsyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.aSyncDataCollector"
     _partial_: bool = False
 
@@ -136,7 +140,9 @@ class MultiSyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.MultiSyncDataCollector"
     _partial_: bool = False
 
@@ -179,7 +185,9 @@ class MultiaSyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.MultiaSyncDataCollector"
     _partial_: bool = False

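Both new fields are plain dataclass attributes, so they compose and override like any other collector option. A short structured-config sketch follows; what local_init_rb enables (initializing the replay buffer locally in the collector) is inferred from its name and the companion shared_init storage flag below, not stated in this diff.

from omegaconf import OmegaConf

from torchrl.trainers.algorithms.configs.collectors import SyncDataCollectorConfig
from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
    SharedMemWeightSyncSchemeConfig,
)

# OmegaConf validates field names against the dataclass, so the new
# fields are reachable from YAML or CLI (e.g. collector.local_init_rb=true).
cfg = OmegaConf.structured(SyncDataCollectorConfig)
cfg.local_init_rb = True
# Shape assumption: a {module_id: scheme_config} mapping mirroring the
# collectors' own weight_sync_schemes argument, with the scheme config
# assumed to construct from defaults.
cfg.weight_sync_schemes = {"policy": SharedMemWeightSyncSchemeConfig()}
print(OmegaConf.to_yaml(cfg))
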
torchrl/trainers/algorithms/configs/data.py

Lines changed: 2 additions & 0 deletions
@@ -254,6 +254,7 @@ class LazyMemmapStorageConfig(StorageConfig):
     device: Any = None
     ndim: int = 1
     compilable: bool = False
+    shared_init: bool = False
 
 
 @dataclass
@@ -265,6 +266,7 @@ class LazyTensorStorageConfig(StorageConfig):
     device: Any = None
     ndim: int = 1
     compilable: bool = False
+    shared_init: bool = False
 
 
 @dataclass
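
A matching sketch for the storage side; the semantics of shared_init (preparing the storage at construction time so it can be shared with collector workers) are an inference from the flag name and the local_init_rb collector flag, not documented in this diff.

from omegaconf import OmegaConf

from torchrl.trainers.algorithms.configs.data import LazyMemmapStorageConfig

# shared_init merges from YAML/CLI like any other storage option, e.g.
# replay_buffer.storage.shared_init=true (the exact path depends on the
# surrounding trainer config).
base = OmegaConf.structured(LazyMemmapStorageConfig)
override = OmegaConf.create({"shared_init": True, "ndim": 2})
cfg = OmegaConf.merge(base, override)
assert cfg.shared_init is True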

torchrl/trainers/algorithms/configs/modules.py

Lines changed: 8 additions & 1 deletion
@@ -202,6 +202,7 @@ class ModelConfig(ConfigBase):
     _partial_: bool = False
     in_keys: Any = None
     out_keys: Any = None
+    shared: bool = False
 
     def __post_init__(self) -> None:
         """Post-initialization hook for model configurations."""
@@ -226,7 +227,7 @@ class TensorDictModuleConfig(ModelConfig):
 
     def __post_init__(self) -> None:
         """Post-initialization hook for TensorDict module configurations."""
-        super().__post_init__()
+        return super().__post_init__()
 
 
 @dataclass
@@ -312,6 +313,7 @@ def _make_tanh_normal_model(*args, **kwargs):
     return_log_prob = kwargs.pop("return_log_prob", False)
     eval_mode = kwargs.pop("eval_mode", False)
     exploration_type = kwargs.pop("exploration_type", "RANDOM")
+    shared = kwargs.pop("shared", False)
 
     # Now instantiate the network
     if hasattr(network, "_target_"):
@@ -328,6 +330,8 @@ def _make_tanh_normal_model(*args, **kwargs):
     )
 
     module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
+    if shared:
+        module = module.share_memory()
 
     # Create ProbabilisticTensorDictModule
     prob_module = ProbabilisticTensorDictModule(
@@ -350,4 +354,7 @@ def _make_value_model(*args, **kwargs):
     from torchrl.modules import ValueOperator
 
     network = kwargs.pop("network")
+    shared = kwargs.pop("shared", False)
+    if shared:
+        network = network.share_memory()
     return ValueOperator(network, **kwargs)
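
The shared flag threads from ModelConfig into both factory helpers, which call share_memory() on the freshly built module. A self-contained sketch of what that call does, using plain torch/tensordict APIs independent of this commit:

import torch
from tensordict.nn import TensorDictModule

# share_memory() (inherited from nn.Module) moves parameters into shared
# memory, so forked/spawned collector workers observe in-place weight updates.
net = torch.nn.Linear(4, 2)
module = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
module = module.share_memory()
assert next(module.parameters()).is_shared()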

torchrl/trainers/algorithms/configs/objectives.py

Lines changed: 27 additions & 0 deletions
@@ -148,3 +148,30 @@ class HardUpdateConfig(TargetNetUpdaterConfig):
 
     _target_: str = "torchrl.objectives.utils.HardUpdate."
     value_network_update_interval: int = 1000
+
+
+@dataclass
+class GAEConfig(LossConfig):
+    """A class to configure a GAELoss."""
+
+    gamma: float | None = None
+    lmbda: float | None = None
+    value_network: Any = None
+    average_gae: bool = True
+    differentiable: bool = False
+    vectorized: bool | None = None
+    skip_existing: bool | None = None
+    advantage_key: str | None = None
+    value_target_key: str | None = None
+    value_key: str | None = None
+    shifted: bool = False
+    device: Any = None
+    time_dim: int | None = None
+    auto_reset_env: bool = False
+    deactivate_vmap: bool = False
+    _target_: str = "torchrl.objectives.value.GAE"
+    _partial_: bool = False
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for GAELoss configurations."""
+        super().__post_init__()
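
GAEConfig mirrors the keyword arguments of torchrl.objectives.value.GAE. A hedged instantiation sketch: the inline value network is a stand-in, and passing it at instantiate() time (rather than via the value_network field or _partial_) is an assumption about how trainers consume this node.

import torch
from hydra.utils import instantiate
from tensordict.nn import TensorDictModule

from torchrl.trainers.algorithms.configs.objectives import GAEConfig

# Stand-in value network mapping "observation" -> "state_value".
value_net = TensorDictModule(
    torch.nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
)

cfg = GAEConfig(gamma=0.99, lmbda=0.95)
# instantiate() follows cfg._target_; the keyword override supplies the
# network since the config's value_network field defaults to None.
gae = instantiate(cfg, value_network=value_net)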
