From 2dc591b10c67503cd44403c0324064f174e97cb6 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Tue, 12 Oct 2021 10:43:19 -0700 Subject: [PATCH] Passing policy_state correctly into policy.action() function when computing Q values. PiperOrigin-RevId: 402601593 Change-Id: Ibb58d385835ccf9a1062be8a7a281a7f4dd92bdc --- tf_agents/agents/qtopt/qtopt_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf_agents/agents/qtopt/qtopt_agent.py b/tf_agents/agents/qtopt/qtopt_agent.py index e41d90fe9..d84dc49d8 100644 --- a/tf_agents/agents/qtopt/qtopt_agent.py +++ b/tf_agents/agents/qtopt/qtopt_agent.py @@ -684,7 +684,8 @@ def _compute_next_q_values(self, next_time_steps, info, network_state=()): if not self._in_graph_bellman_update: return info['target_q'] - next_action_policy_step = self._policy.action(next_time_steps) + next_action_policy_step = self._policy.action( + next_time_steps, network_state) if self._enable_td3: q_values_target_delayed, _ = self._target_q_network_delayed(