[RLlib; Offline RL] CQL: Support multi-GPU/CPU setup and different learning rates for actor, critic, and alpha. #47402
Conversation
…rted to rewrite CQL loss. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…tes for actor, critic, and alpha. Multi-learner setups works. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@@ -84,6 +85,12 @@ def __init__(self, algo_class=None):
        self.lagrangian_thresh = 5.0
        self.min_q_weight = 5.0
        self.lr = 3e-4
Ah, so for the new stack, users have to set this to `None` manually? I guess this is ok (explicit is always good).
Yes, exactly. We discussed this in the other PR concerning SAC.
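To make the discussion above concrete: with the old single `lr` set to `None`, each of the three optimizers (actor, critic, alpha) gets its own learning rate. The sketch below is plain Python, not RLlib code; the names and values are illustrative only, showing why one rate per optimizer matters when the components need different step sizes.

```python
# Sketch (not RLlib code): three parameter groups, each updated with its
# own learning rate, mirroring the actor/critic/alpha optimizer split.

def sgd_step(params, grads, lr):
    """One plain SGD update: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]

# Hypothetical parameter groups and a shared toy gradient.
actor_params, critic_params, alpha_params = [1.0], [2.0], [0.5]
grads = [0.1]

# One learning rate per optimizer, as this PR introduces for CQL.
lrs = {"actor": 3e-4, "critic": 3e-3, "alpha": 1e-4}

actor_params = sgd_step(actor_params, grads, lrs["actor"])
critic_params = sgd_step(critic_params, grads, lrs["critic"])
alpha_params = sgd_step(alpha_params, grads, lrs["alpha"])

print(actor_params, critic_params, alpha_params)
```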
rllib/algorithms/cql/cql.py
Outdated
def _model_config_auto_includes(self):
    return super()._model_config_auto_includes | {
        "num_actions": self.num_actions,
        "_deterministic_loss": self._deterministic_loss,
nit: Let's remove this deterministic loss thing. It's a relic from a long time ago (2020) when I was trying to debug SAC on torch vs our old SAC on tf. It serves no real purpose and just bloats the code.
Great!! That saves us many lines of code!
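For context on the `_model_config_auto_includes` override discussed above: it relies on Python's dict union operator (`|`, Python 3.9+), where keys from the right-hand operand win on conflict, so subclass-specific entries override the parent's defaults. A standalone sketch of that merge semantics (class and key names here are illustrative, not the actual RLlib ones):

```python
# Sketch of the dict-merge pattern behind `_model_config_auto_includes`:
# `parent_dict | child_dict` returns a new dict; on duplicate keys, the
# right-hand operand (the child's entries) wins.

class Base:
    @property
    def _model_config_auto_includes(self):
        # Hypothetical parent defaults.
        return {"twin_q": True, "num_actions": 5}

class Child(Base):
    num_actions = 10

    @property
    def _model_config_auto_includes(self):
        # Child-specific keys override the parent's on conflict.
        return super()._model_config_auto_includes | {
            "num_actions": self.num_actions,
        }

merged = Child()._model_config_auto_includes
print(merged)  # the child's num_actions (10) overrides the parent's (5)
```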
# here). This is different from doing `.detach()` or `with torch.no_grad()`,
# as these two methods would fully block all gradient recordings, including
# the needed policy ones.
all_params = list(self.pi_encoder.parameters()) + list(self.pi.parameters())
Wait, why `pi`? We need to block the q-net gradients (same as in SAC), correct?
all_params = (
list(self.qf.parameters())
+ list(self.qf_encoder.parameters())
+ list(self.qf_twin.parameters())
+ list(self.qf_twin_encoder.parameters())
)
for param in all_params:
param.requires_grad = False
output["q_curr"] = self.compute_q_values(q_batch_curr)
for param in all_params:
param.requires_grad = True
^ from SAC
Well, I don't think so. That was my first impression, too, but using this kind of backward pass lets the CQL loss rise without limit. Which makes sense: the additional loss term then becomes merely an added constant that differs from step to step but cannot be reduced through its own gradients. Removing the `requires_grad = False` blocking in this forward pass makes the CQL loss term decrease over time, as it should.
Note that the SAC loss is computed earlier and, inside `super()._forward_train`, uses exactly the logic you posted above.
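The distinction being debated here, toggling `requires_grad` on the Q-parameters versus cutting the graph with `.detach()`/`torch.no_grad()`, can be illustrated without RLlib or torch. Below is a toy forward-mode autodiff sketch (all names hypothetical) of a composite `q(pi(theta))`: freezing only the Q-parameter still lets a gradient reach the policy parameter through the Q-computation, while a `no_grad`-style cut zeroes everything upstream as well.

```python
# Toy illustration (not torch): forward-mode "dual numbers" showing why
# flipping requires_grad=False on the Q-params differs from detach()/no_grad().

class Dual:
    """A value together with its derivative w.r.t. one tracked parameter."""

    def __init__(self, val, grad=0.0):
        self.val, self.grad = val, grad

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.val * other.val,
                    self.grad * other.val + self.val * other.grad)

def q_of_pi(theta_pi, w_q, seed_pi, seed_q):
    """q(s, pi(s)) = w_q * (theta_pi * s) for a fixed 'state' s = 2.0.
    seed_* = 1.0 means the parameter is gradient-tracked, 0.0 means frozen."""
    s = Dual(2.0)                        # constant input, never tracked
    pi = Dual(theta_pi, seed_pi) * s     # policy output
    return Dual(w_q, seed_q) * pi        # Q-value of the policy action

# Case 1: freeze only the Q-parameter (requires_grad=False analogue).
# The policy parameter still receives a gradient *through* the Q-network.
frozen_q = q_of_pi(theta_pi=0.5, w_q=3.0, seed_pi=1.0, seed_q=0.0)

# Case 2: detach()/no_grad() analogue -- nothing upstream is tracked,
# so the policy gradient is blocked as well.
detached = q_of_pi(theta_pi=0.5, w_q=3.0, seed_pi=0.0, seed_q=0.0)

print(frozen_q.grad, detached.grad)
```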
    + list(self.qf_twin.parameters())
    + list(self.qf_twin_encoder.parameters())
)
all_params = list(self.qf.parameters()) + list(self.qf_encoder.parameters())
great catch!!
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…dy sampled log-probabilities. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…l one published by Kumar et al. (2020). Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
LGTM! Thanks @simonsays1980 !!
…arning rates for actor, critic, and alpha. (ray-project#47402) Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Why are these changes needed?

This PR introduces, similar to SAC, multiple learning rates for CQL, namely one learning rate for each of the three optimizers (i.e. for actor, critic, and the hyperparameter alpha). Furthermore, it moves all forward passes from the learner into the module and thereby enables multi-learner setups.

While SAC already had all forward passes moved into `_forward_train`, CQL missed the ones that were used for the CQL loss. This PR provides a complete setup.

Related issue number
Checks

- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.