-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathfpo.py
44 lines (37 loc) · 1.55 KB
/
fpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from sandbox.cpo.algos.safe.policy_gradient_safe import PolicyGradientSafe
from sandbox.cpo.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
from rllab.core.serializable import Serializable
import rllab.misc.logger as logger
class FPO(PolicyGradientSafe, Serializable):
    """
    Fixed Penalty Optimization.

    A thin configuration of ``PolicyGradientSafe``: the constructor pins a
    fixed set of safety-related options on the parent class and forwards
    everything else unchanged.
    """

    def __init__(
            self,
            optimizer=None,
            optimizer_args=None,
            safety_constraint=None,
            **kwargs):
        """
        Build the algorithm with the fixed-penalty settings applied.

        :param optimizer: optimizer instance to use. When ``None``, a
            ``ConjugateGradientOptimizer`` is constructed instead.
        :param optimizer_args: keyword arguments for the default
            ``ConjugateGradientOptimizer``; only consulted when
            ``optimizer`` is ``None``. ``None`` means no arguments.
        :param safety_constraint: forwarded as-is to the parent constructor.
        :param kwargs: remaining options forwarded to the parent
            constructor. Any of ``safety_constrained_optimizer``,
            ``safety_tradeoff``, ``learn_safety_tradeoff_coeff``,
            ``safety_key`` and ``pdo_vf_mode`` found here are discarded,
            because this class always supplies its own values for them.
        """
        # Must remain the first statement so locals() holds exactly the
        # constructor arguments when they are recorded.
        Serializable.quick_init(self, locals())

        if optimizer is None:
            default_args = {} if optimizer_args is None else optimizer_args
            optimizer = ConjugateGradientOptimizer(**default_args)

        # Drop caller-supplied values for the options pinned below so they
        # cannot collide with the explicit keyword arguments.
        pinned_options = (
            'safety_constrained_optimizer',
            'safety_tradeoff',
            'learn_safety_tradeoff_coeff',
            'safety_key',
            'pdo_vf_mode',
        )
        for option in pinned_options:
            kwargs.pop(option, None)

        super(FPO, self).__init__(
            optimizer=optimizer,
            safety_constrained_optimizer=False,
            safety_constraint=safety_constraint,
            safety_tradeoff=True,
            learn_safety_tradeoff_coeff=False,
            safety_key='returns',
            pdo_vf_mode=1,
            **kwargs)