-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Added test version of BC algorithm based on RLModules an RLTr…
…ainers (#32471) Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
- Loading branch information
1 parent
997e95e
commit 4ffa7fd
Showing
13 changed files
with
138 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Contains example implementation of a custom algorithm. | ||
Note: It doesn't include any real use-case functionality; it only serves as an example | ||
to test the algorithm construction and customization. | ||
""" | ||
|
||
from ray.rllib.algorithms import Algorithm, AlgorithmConfig | ||
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 | ||
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 | ||
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule | ||
from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer | ||
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule | ||
from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer | ||
|
||
|
||
class BCConfigTest(AlgorithmConfig): | ||
def __init__(self, algo_class=None): | ||
super().__init__(algo_class=algo_class or BCAlgorithmTest) | ||
|
||
def get_default_rl_module_class(self): | ||
if self.framework_str == "torch": | ||
return DiscreteBCTorchModule | ||
elif self.framework_str == "tf2": | ||
return DiscreteBCTFModule | ||
|
||
def get_default_rl_trainer_class(self): | ||
if self.framework_str == "torch": | ||
return BCTorchRLTrainer | ||
elif self.framework_str == "tf2": | ||
return BCTfRLTrainer | ||
|
||
|
||
class BCAlgorithmTest(Algorithm): | ||
@classmethod | ||
def get_default_policy_class(cls, config: AlgorithmConfig): | ||
if config.framework_str == "torch": | ||
return TorchPolicyV2 | ||
elif config.framework_str == "tf2": | ||
return EagerTFPolicyV2 | ||
else: | ||
raise ValueError("Unknown framework: {}".format(config.framework_str)) | ||
|
||
def training_step(self): | ||
# do nothing. | ||
return {} |
Oops, something went wrong.