diff --git a/rllib/examples/two_trainer_workflow.py b/rllib/examples/two_trainer_workflow.py
index a9a525809598f..58ae217898d94 100644
--- a/rllib/examples/two_trainer_workflow.py
+++ b/rllib/examples/two_trainer_workflow.py
@@ -173,14 +173,20 @@ def training_step(self) -> ResultDict:
             # Provide entire AlgorithmConfig object, not just an override.
             PPOConfig()
             .training(num_sgd_iter=10, sgd_minibatch_size=128)
-            .framework("torch" if args.torch or args.mixed_torch_tf else "tf"),
+            .framework("torch" if args.torch or args.mixed_torch_tf else "tf")
+            .training(_enable_learner_api=False)
+            .rl_module(_enable_rl_module_api=False),
         ),
         "dqn_policy": (
             DQNTorchPolicy if args.torch else DQNTFPolicy,
             None,
             None,
             # Provide entire AlgorithmConfig object, not just an override.
-            DQNConfig().training(target_network_update_freq=500).framework("tf"),
+            DQNConfig()
+            .training(target_network_update_freq=500)
+            .framework("tf")
+            .training(_enable_learner_api=False)
+            .rl_module(_enable_rl_module_api=False),
         ),
     }
 
@@ -199,6 +205,9 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):
         # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
         .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
         .reporting(metrics_num_episodes_for_smoothing=30)
+        # TODO (Kourosh): Migrate this to the new RLModule / Learner API.
+        .training(_enable_learner_api=False)
+        .rl_module(_enable_rl_module_api=False)
     )
 
     stop = {
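
Note (not part of the patch): the added lines pin both algorithms to the old policy/trainer stack by disabling the new RLModule / Learner APIs on their AlgorithmConfig objects. A minimal standalone sketch of the same opt-out pattern, assuming an RLlib 2.x version where these private flags still exist:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(num_sgd_iter=10, sgd_minibatch_size=128)
    .framework("torch")
    # Opt out of the new stack so legacy Policy classes keep working.
    .training(_enable_learner_api=False)
    .rl_module(_enable_rl_module_api=False)
)
algo = config.build()  # builds a PPO Algorithm on the old stack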