-
Notifications
You must be signed in to change notification settings - Fork 16
/
DDQN.py
18 lines (13 loc) · 871 Bytes
/
DDQN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from DQN_With_Fixed_Q_Targets import DQN_With_Fixed_Q_Targets
class DDQN(DQN_With_Fixed_Q_Targets):
    """A Double DQN agent.

    Identical to DQN-with-fixed-Q-targets except for how the bootstrap
    target is computed: the local network *selects* the greedy action and
    the target network *evaluates* it. Decoupling selection from evaluation
    reduces the Q-value overestimation bias of vanilla DQN.
    """

    agent_name = "DDQN"

    def __init__(self, config):
        # Use super() rather than an explicit parent-class call so the MRO
        # stays cooperative if mixins are ever introduced.
        super().__init__(config)

    def compute_q_values_for_next_states(self, next_states):
        """Compute the Q-values for the next states used in the TD target.

        The local network picks the argmax action for each next state, and
        the target network supplies the Q-value of that chosen action.

        :param next_states: batch of next-state observations
            (assumes shape (batch, state_dim) — TODO confirm against caller)
        :return: tensor of shape (batch, 1) with the target Q-values
        """
        # detach(): action selection must not contribute gradients to the
        # local network; only the loss computed by the caller should.
        max_action_indexes = self.q_network_local(next_states).detach().argmax(1)
        # gather(1, ...) picks, per row, the target network's Q-value for
        # the action the local network selected.
        Q_targets_next = self.q_network_target(next_states).gather(1, max_action_indexes.unsqueeze(1))
        return Q_targets_next