Skip to content

Commit

Permalink
[RLlib] Update autoregressive actions example. (ray-project#47829)
Browse files Browse the repository at this point in the history
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent 4260777 commit 9b55915
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
50 changes: 50 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3263,6 +3263,56 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=-0.012", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_tf2",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_torch",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4", "--use-prev-action", "--use-prev-reward"]
)

#@OldAPIStack
py_test(
name = "examples/centralized_critic_tf",
Expand Down
4 changes: 3 additions & 1 deletion rllib/examples/autoregressive_action_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def get_cli_args():
config = (
get_trainable_cls(args.run)
.get_default_config()
.environment(CorrelatedActionsEnv)
# Batch-norm models have not been migrated to the RL Module API yet.
.api_stack(enable_rl_module_and_learner=False)
.environment(AutoRegressiveActionEnv)
.framework(args.framework)
.training(gamma=0.5)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
Expand Down
12 changes: 7 additions & 5 deletions rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import abc
from abc import abstractmethod
from typing import Any, Dict
from typing import Dict

from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
Expand Down Expand Up @@ -233,10 +234,11 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
return outs

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, TensorType], embeddings=None):
# Encoder forward pass to get `embeddings`, if necessary.
if embeddings is None:
embeddings = self.encoder(batch)[ENCODER_OUT]
def compute_values(self, batch: Dict[str, TensorType]):

# Encoder forward pass.
encoder_outs = self.encoder(batch)[ENCODER_OUT]

# Value head forward pass.
vf_out = self.vf(embeddings)
# Squeeze out last dimension (single node value head).
Expand Down

0 comments on commit 9b55915

Please sign in to comment.