-
Notifications
You must be signed in to change notification settings - Fork 66
/
rl_factory.py
31 lines (30 loc) · 1.54 KB
/
rl_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from src.rlagents.dqnagent import DQNAgent
from src.rlagents.ddpgagent import DDPGAgent
def rl_factory(rl_type, args, neural_network, exp_replay, rl_stats, n_actions, eps):
    """Build the reinforcement-learning agent selected by ``rl_type``.

    Both supported agent classes are constructed with the same positional
    arguments, so ``rl_type`` only chooses the class; the argument list is
    shared and built in one place.

    Args:
        rl_type: Agent identifier: ``'dqn'`` builds a ``DQNAgent`` and
            ``'ddpg'`` builds a ``DDPGAgent``.
        args: Configuration object from which the attributes ``nsteps``,
            ``batch``, ``nreplay``, ``gamma``, ``mode`` and ``updates`` are
            read (presumably an ``argparse.Namespace`` -- confirm with caller).
        neural_network: Passed through unchanged as the first constructor
            argument.
        exp_replay: Passed through unchanged to the agent constructor.
        rl_stats: Passed through unchanged to the agent constructor.
        n_actions: Passed through unchanged to the agent constructor.
        eps: Passed through unchanged as the second constructor argument
            (presumably an exploration epsilon -- confirm in the agent classes).

    Returns:
        A newly constructed ``DQNAgent`` or ``DDPGAgent``.

    Raises:
        ValueError: If ``rl_type`` is not one of the supported agent types.
    """
    if rl_type == 'dqn':
        agent_cls = DQNAgent
    elif rl_type == 'ddpg':
        agent_cls = DDPGAgent
    else:
        # An explicit raise (unlike ``assert``) is not stripped by
        # ``python -O``, so an unknown type can never fall through and
        # silently return None.
        raise ValueError(
            'Supplied rl argument type ' + str(rl_type) + ' does not exist.'
        )

    # Positional order must match the DQNAgent/DDPGAgent constructors.
    return agent_cls(
        neural_network,
        eps,
        exp_replay,
        n_actions,
        args.nsteps,
        args.batch,
        args.nreplay,
        args.gamma,
        rl_stats,
        args.mode,
        args.updates,
    )