Skip to content

Commit

Permalink
[RLlib] Added test version of BC algorithm based on RLModules an RLTr…
Browse files Browse the repository at this point in the history
…ainers (ray-project#32471)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
kouroshHakha authored and edoakes committed Mar 22, 2023
1 parent c570cdc commit 7a6ad1f
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 78 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,13 @@ py_test(
srcs = ["core/rl_trainer/torch/tests/test_torch_rl_trainer.py"]
)

py_test(
name = "test_bc_algorithm",
tags = ["team:rllib", "core"],
size = "medium",
srcs = ["core/testing/tests/test_bc_algorithm.py"]
)

# --------------------------------------------------------------------
# Models and Distributions
# rllib/models/
Expand Down
26 changes: 13 additions & 13 deletions rllib/core/rl_module/tests/test_marl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def test_from_config(self):
module1 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)
module2 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

multi_agent_dict = {"module1": module1, "module2": module2}
Expand All @@ -43,11 +43,11 @@ def test_from_multi_agent_config(self):
"modules": {
"module1": SingleAgentRLModuleSpec(
module_class=DiscreteBCTorchModule,
model_config={"hidden_dim": 64},
model_config={"fcnet_hiddens": [64]},
),
"module2": SingleAgentRLModuleSpec(
module_class=DiscreteBCTorchModule,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
),
},
"observation_space": env.observation_space, # this is common
Expand All @@ -68,7 +68,7 @@ def test_as_multi_agent(self):
marl_module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
).as_multi_agent()

self.assertNotIsInstance(marl_module, DiscreteBCTorchModule)
Expand All @@ -87,7 +87,7 @@ def test_get_set_state(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
).as_multi_agent()

state = module.get_state()
Expand All @@ -101,7 +101,7 @@ def test_get_set_state(self):
module2 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
).as_multi_agent()
state2 = module2.get_state()
check(state, state2, false=True)
Expand All @@ -119,15 +119,15 @@ def test_add_remove_modules(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
).as_multi_agent()

module.add_module(
"test",
DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
),
)
self.assertEqual(set(module.keys()), {DEFAULT_POLICY_ID, "test"})
Expand All @@ -142,7 +142,7 @@ def test_add_remove_modules(self):
DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
),
),
)
Expand All @@ -152,7 +152,7 @@ def test_add_remove_modules(self):
DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
),
override=True,
)
Expand Down Expand Up @@ -239,12 +239,12 @@ def test_serialize_deserialize(self):
module1 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)
module2 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

multi_agent_dict = {"module1": module1, "module2": module2}
Expand Down
12 changes: 6 additions & 6 deletions rllib/core/rl_module/tests/test_rl_module_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def build(self):
# this handles all implementation details
config = {
"input_dim": self.observation_space.shape[0],
"hidden_dim": self.model_config["hidden_dim"],
"hidden_dim": self.model_config["fcnet_hiddens"][0],
"output_dim": self.action_space.n,
}
return self.module_class(**config)
Expand All @@ -48,7 +48,7 @@ def test_single_agent_spec(self):
module_class=module_class,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 64},
model_config={"fcnet_hiddens": [64]},
)

module = spec.build()
Expand All @@ -63,7 +63,7 @@ def test_customized_single_agent_spec(self):
module_class=module_class,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 64},
model_config={"fcnet_hiddens": [64]},
)
module = spec.build()
self.assertIsInstance(module, module_class)
Expand All @@ -81,7 +81,7 @@ def test_multi_agent_spec(self):
module_class=module_class,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 32 * (i + 1)},
model_config={"fcnet_hiddens": [32 * (i + 1)]},
)

spec = MultiAgentRLModuleSpec(
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_customized_multi_agent_spec(self):
}
),
action_space=gym.spaces.Discrete(action_dims[0]),
model_config={"hidden_dim": 128},
model_config={"fcnet_hiddens": [128]},
),
"agent_2": SingleAgentRLModuleSpec(
module_class=module_cls,
Expand All @@ -133,7 +133,7 @@ def test_customized_multi_agent_spec(self):
}
),
action_space=gym.spaces.Discrete(action_dims[1]),
model_config={"hidden_dim": 128},
model_config={"fcnet_hiddens": [128]},
),
},
)
Expand Down
12 changes: 6 additions & 6 deletions rllib/core/rl_module/tf/tests/test_tf_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_compilation(self):
module = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

self.assertIsInstance(module, TfRLModule)
Expand All @@ -30,7 +30,7 @@ def test_forward_train(self):
module = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

obs_shape = env.observation_space.shape
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_forward(self):
module = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

obs_shape = env.observation_space.shape
Expand All @@ -78,7 +78,7 @@ def test_get_set_state(self):
module = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

state = module.get_state()
Expand All @@ -87,7 +87,7 @@ def test_get_set_state(self):
module2 = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)
state2 = module2.get_state()
check(state["policy"][0], state2["policy"][0], false=True)
Expand All @@ -101,7 +101,7 @@ def test_serialize_deserialize(self):
module = DiscreteBCTFModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

# create a new module from the old module
Expand Down
12 changes: 6 additions & 6 deletions rllib/core/rl_module/torch/tests/test_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_compilation(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

self.assertIsInstance(module, TorchRLModule)
Expand All @@ -29,7 +29,7 @@ def test_forward_train(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

obs_shape = env.observation_space.shape
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_forward(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

obs_shape = env.observation_space.shape
Expand All @@ -73,7 +73,7 @@ def test_get_set_state(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

state = module.get_state()
Expand All @@ -82,7 +82,7 @@ def test_get_set_state(self):
module2 = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)
state2 = module2.get_state()
check(state, state2, false=True)
Expand All @@ -96,7 +96,7 @@ def test_serialize_deserialize(self):
module = DiscreteBCTorchModule.from_model_config(
env.observation_space,
env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
)

# create a new module from the old module
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/rl_trainer/tests/test_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_trainer() -> RLTrainer:
module_class=DiscreteBCTFModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 32},
model_config={"fcnet_hiddens": [32]},
),
optimizer_config={"lr": 1e-3},
trainer_scaling_config=TrainerScalingConfig(),
Expand Down Expand Up @@ -127,7 +127,7 @@ def set_optimizer_fn(module):
module_class=DiscreteBCTFModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 16},
model_config={"fcnet_hiddens": [16]},
),
set_optimizer_fn=set_optimizer_fn,
)
Expand Down
2 changes: 1 addition & 1 deletion rllib/core/rl_trainer/tests/test_trainer_runner_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_trainer_runner_build_from_algorithm_config(self):
AlgorithmConfig()
.rl_module(rl_module_class=DiscreteBCTFModule)
.training(rl_trainer_class=BCTfRLTrainer)
.training(model={"hidden_dim": 32})
.training(model={"fcnet_hiddens": [32]})
)
config.freeze()
runner_config = config.get_trainer_runner_config(
Expand Down
18 changes: 3 additions & 15 deletions rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,15 @@
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig
from ray.rllib.core.testing.utils import get_rl_trainer


def _get_trainer() -> RLTrainer:
env = gym.make("CartPole-v1")

trainer = BCTorchRLTrainer(
module_spec=SingleAgentRLModuleSpec(
module_class=DiscreteBCTorchModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 32},
),
optimizer_config={"lr": 1e-3},
trainer_scaling_config=TrainerScalingConfig(),
)

trainer = get_rl_trainer("torch", env)
trainer.build()

return trainer
Expand Down Expand Up @@ -125,7 +113,7 @@ def set_optimizer_fn(module):
module_class=DiscreteBCTorchModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"hidden_dim": 16},
model_config={"fcnet_hiddens": [16]},
),
set_optimizer_fn=set_optimizer_fn,
)
Expand Down
45 changes: 45 additions & 0 deletions rllib/core/testing/bc_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Contains example implementation of a custom algorithm.
Note: It doesn't include any real use-case functionality; it only serves as an example
to test the algorithm construction and customization.
"""

from ray.rllib.algorithms import Algorithm, AlgorithmConfig
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer


class BCConfigTest(AlgorithmConfig):
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or BCAlgorithmTest)

def get_default_rl_module_class(self):
if self.framework_str == "torch":
return DiscreteBCTorchModule
elif self.framework_str == "tf2":
return DiscreteBCTFModule

def get_default_rl_trainer_class(self):
if self.framework_str == "torch":
return BCTorchRLTrainer
elif self.framework_str == "tf2":
return BCTfRLTrainer


class BCAlgorithmTest(Algorithm):
@classmethod
def get_default_policy_class(cls, config: AlgorithmConfig):
if config.framework_str == "torch":
return TorchPolicyV2
elif config.framework_str == "tf2":
return EagerTFPolicyV2
else:
raise ValueError("Unknown framework: {}".format(config.framework_str))

def training_step(self):
# do nothing.
return {}
Loading

0 comments on commit 7a6ad1f

Please sign in to comment.