diff --git a/src/modules/agents/rnn_communicating_agent.py b/src/modules/agents/rnn_communicating_agent.py index 458fbbbe..a6c70f3c 100644 --- a/src/modules/agents/rnn_communicating_agent.py +++ b/src/modules/agents/rnn_communicating_agent.py @@ -200,7 +200,7 @@ def forward(self, inputs, hidden_state): am_inputs = taus.reshape((prev_message.shape[0], -1)) new_message = self.am.forward(am_inputs, prev_message) - pi_input_message = th.cat([new_message]*5, 0).reshape(new_message.shape[0], -1) + pi_input_message = th.cat([new_message]*self.n_agents, 0).reshape(new_message.shape[0], -1) action, h_pi = self.pi(observation, pi_input_message, h_pi) hidden_state = th.cat([h_fa, h_fo, h_pi], -1)