diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py index 7ce18ac1bc..8578c887a0 100644 --- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py @@ -166,18 +166,21 @@ def get_trajectory_value_estimates( # If we're using LSTM, we want to get all the intermediate memories. all_next_memories: Optional[AgentBufferField] = None - if self.policy.use_recurrent: - ( - value_estimates, - all_next_memories, - next_memory, - ) = self._evaluate_by_sequence(current_obs, memory) - else: - value_estimates, next_memory = self.critic.critic_pass( - current_obs, memory, sequence_length=batch.num_experiences - ) - # Store the memory for the next trajectory + # To prevent memory leak and improve performance, evaluate with no_grad. + with torch.no_grad(): + if self.policy.use_recurrent: + ( + value_estimates, + all_next_memories, + next_memory, + ) = self._evaluate_by_sequence(current_obs, memory) + else: + value_estimates, next_memory = self.critic.critic_pass( + current_obs, memory, sequence_length=batch.num_experiences + ) + + # Store the memory for the next trajectory. This should NOT have a gradient. self.critic_memory_dict[agent_id] = next_memory next_value_estimate, _ = self.critic.critic_pass( diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py index 0b4c2c3472..ca1a18b1e7 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py @@ -207,6 +207,11 @@ def test_ppo_get_value_estimates(dummy_config, rnn, visual, discrete): run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates( trajectory.to_agentbuffer(), trajectory.next_obs, done=False ) + if rnn: + # Check that memories don't have a Torch gradient + for mem in optimizer.critic_memory_dict.values(): + assert not mem.requires_grad + for key, val in run_out.items(): assert type(key) is str assert len(val) == 15