Commit
[RLlib; Offline RL] CQL: Support multi-GPU/CPU setup and different learning rates for actor, critic, and alpha. (ray-project#47402)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
simonsays1980 authored and ujjawal-khare committed Oct 12, 2024
1 parent b3cd36b commit 91bd1e2
Showing 8 changed files with 302 additions and 253 deletions.
48 changes: 41 additions & 7 deletions rllib/algorithms/cql/cql.py
@@ -2,19 +2,20 @@
 from typing import Optional, Type, Union
 
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
-from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
-from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
-from ray.rllib.algorithms.sac.sac import (
-    SAC,
-    SACConfig,
-)
 from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
     AddObservationsFromEpisodesToBatch,
 )
 from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import (  # noqa
     AddNextObservationsFromEpisodesToTrainBatch,
 )
 from ray.rllib.core.learner.learner import Learner
+from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.algorithms.sac.sac import (
+    SAC,
+    SACConfig,
+)
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
 )
@@ -48,7 +49,7 @@
     SAMPLE_TIMER,
     TIMERS,
 )
-from ray.rllib.utils.typing import ResultDict
+from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
 
 tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
@@ -83,7 +84,14 @@ def __init__(self, algo_class=None):
         self.lagrangian = False
         self.lagrangian_thresh = 5.0
         self.min_q_weight = 5.0
+        self.deterministic_backup = True
         self.lr = 3e-4
+        # Note: the new stack defines a separate learning rate for each
+        # component. The base learning rate `lr` has to be set to `None`
+        # when using the new stack.
+        self.actor_lr = 1e-4
+        self.critic_lr = 1e-3
+        self.alpha_lr = 1e-3
 
         # Changes to Algorithm's/SACConfig's default:
 
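Note: the per-component learning rates added above replace SAC's single `lr` on the new API stack. A minimal sketch of how a user might set them (assuming `CQLConfig.training()` forwards `actor_lr`, `critic_lr`, and `alpha_lr` to `SACConfig.training()` as keyword arguments; the environment name is only an example):

from ray.rllib.algorithms.cql.cql import CQLConfig

config = (
    CQLConfig()
    .environment("Pendulum-v1")  # example env, not part of this commit
    .training(
        # On the new stack, the base `lr` must be None so the
        # per-component rates below take effect.
        lr=None,
        actor_lr=1e-4,   # policy network
        critic_lr=1e-3,  # Q-network(s)
        alpha_lr=1e-3,   # entropy temperature
    )
)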
@@ -105,6 +113,7 @@ def training(
         lagrangian: Optional[bool] = NotProvided,
         lagrangian_thresh: Optional[float] = NotProvided,
         min_q_weight: Optional[float] = NotProvided,
+        deterministic_backup: Optional[bool] = NotProvided,
         **kwargs,
     ) -> "CQLConfig":
         """Sets the training-related configuration.
@@ -116,6 +125,8 @@
             lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
             lagrangian_thresh: Lagrangian threshold.
             min_q_weight: Min Q weight multiplier.
+            deterministic_backup: Whether to use a deterministic backup in the
+                Bellman update, i.e. omit the entropy term. Defaults to `True`.
 
         Returns:
             This updated AlgorithmConfig object.
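Note: with `deterministic_backup=True` (the new default), the entropy term is left out of the Bellman target. A sketch of the distinction only, not RLlib's actual learner code:

def bellman_target(rewards, gamma, q_next_min, logp_next, alpha,
                   deterministic_backup=True):
    # q_next_min: min over the target Q-nets at (s', a'), with a' ~ pi(.|s').
    # logp_next: log pi(a'|s') for the sampled next action.
    backup = q_next_min
    if not deterministic_backup:
        # SAC-style entropy backup: soften the target by the
        # temperature-scaled log-probability of the next action.
        backup = backup - alpha * logp_next
    return rewards + gamma * backup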
@@ -135,6 +146,8 @@
             self.lagrangian_thresh = lagrangian_thresh
         if min_q_weight is not NotProvided:
             self.min_q_weight = min_q_weight
+        if deterministic_backup is not NotProvided:
+            self.deterministic_backup = deterministic_backup
 
         return self
 
@@ -234,6 +247,27 @@ def validate(self) -> None:
                 "Set this hyperparameter in the `AlgorithmConfig.offline_data`."
             )
 
+    @override(SACConfig)
+    def get_default_rl_module_spec(self) -> RLModuleSpecType:
+        from ray.rllib.algorithms.sac.sac_catalog import SACCatalog
+
+        if self.framework_str == "torch":
+            from ray.rllib.algorithms.cql.torch.cql_torch_rl_module import (
+                CQLTorchRLModule,
+            )
+
+            return RLModuleSpec(module_class=CQLTorchRLModule, catalog_class=SACCatalog)
+        else:
+            raise ValueError(
+                f"The framework {self.framework_str} is not supported. " "Use `torch`."
+            )
+
+    @property
+    def _model_config_auto_includes(self):
+        return super()._model_config_auto_includes | {
+            "num_actions": self.num_actions,
+        }
+
 
 class CQL(SAC):
     """CQL (derived from SAC)."""
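Note: with the override above, the default RLModule spec is torch-only, and `num_actions` (the number of actions CQL samples when estimating its regularizer) is merged into the module's model config automatically. A hedged usage sketch (assuming the new-stack `model_config` property applies `_model_config_auto_includes`; the printed names follow the diff):

from ray.rllib.algorithms.cql.cql import CQLConfig

config = CQLConfig().framework("torch")

spec = config.get_default_rl_module_spec()
print(spec.module_class.__name__)  # -> CQLTorchRLModule

# Any other framework raises ValueError per the override above:
# CQLConfig().framework("tf2").get_default_rl_module_spec()

print(config.model_config.get("num_actions"))  # auto-included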