-
Notifications
You must be signed in to change notification settings - Fork 16
/
DQN_With_Fixed_Q_Targets.py
23 lines (19 loc) · 1.11 KB
/
DQN_With_Fixed_Q_Targets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import copy
from DQN import DQN
from DQN import DQN
class DQN_With_Fixed_Q_Targets(DQN):
    """DQN variant that evaluates next-state Q values with a separate target
    network — a lagged copy of the local Q network that is soft-updated
    towards it after every learning step (the "fixed Q targets" trick)."""
    agent_name = "DQN with Fixed Q Targets"

    def __init__(self, config):
        DQN.__init__(self, config)
        # The target network starts out as an exact copy of the local network.
        self.q_network_target = self.create_NN(input_dim=self.state_size,
                                               output_dim=self.action_size)
        DQN.copy_model_over(from_model=self.q_network_local,
                            to_model=self.q_network_target)

    def learn(self, experiences=None):
        """Runs one learning iteration on the local network, then nudges the
        target network towards it by the interpolation factor tau."""
        super(DQN_With_Fixed_Q_Targets, self).learn(experiences=experiences)
        tau = self.hyperparameters["tau"]
        # Soft update keeps the target network slowly tracking the local one.
        self.soft_update_of_target_network(self.q_network_local,
                                           self.q_network_target, tau)

    def compute_q_values_for_next_states(self, next_states):
        """Returns max-over-actions Q values for next_states, taken from the
        target network and detached so no gradients flow through the target."""
        target_outputs = self.q_network_target(next_states).detach()
        best_action_values, _ = target_outputs.max(dim=1)
        return best_action_values.unsqueeze(1)