diff --git a/tf_agents/agents/qtopt/qtopt_agent.py b/tf_agents/agents/qtopt/qtopt_agent.py index e41d90fe9..d84dc49d8 100644 --- a/tf_agents/agents/qtopt/qtopt_agent.py +++ b/tf_agents/agents/qtopt/qtopt_agent.py @@ -684,7 +684,8 @@ def _compute_next_q_values(self, next_time_steps, info, network_state=()): if not self._in_graph_bellman_update: return info['target_q'] - next_action_policy_step = self._policy.action(next_time_steps) + next_action_policy_step = self._policy.action( + next_time_steps, network_state) if self._enable_td3: q_values_target_delayed, _ = self._target_q_network_delayed(