From f66dbe6204f9faf12b98aab1809b7583aa7a5c0e Mon Sep 17 00:00:00 2001
From: Ervin Teng
Date: Fri, 5 Mar 2021 23:29:17 -0500
Subject: [PATCH 1/3] Detach memory before storing

---
 ml-agents/mlagents/trainers/optimizer/torch_optimizer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
index 7ce18ac1bc..641800ee93 100644
--- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
+++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -177,8 +177,9 @@ def get_trajectory_value_estimates(
                 current_obs, memory, sequence_length=batch.num_experiences
             )
 
-        # Store the memory for the next trajectory
-        self.critic_memory_dict[agent_id] = next_memory
+        # Store the memory for the next trajectory.
+        # Must detach from graph to prevent memory leak
+        self.critic_memory_dict[agent_id] = next_memory.detach()
 
         next_value_estimate, _ = self.critic.critic_pass(
             next_obs, next_memory, sequence_length=1

From b137059a97e1444d4288b2ad6d4cea0e540a0390 Mon Sep 17 00:00:00 2001
From: Ervin Teng
Date: Fri, 5 Mar 2021 23:37:42 -0500
Subject: [PATCH 2/3] Add test

---
 ml-agents/mlagents/trainers/tests/torch/test_ppo.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py
index 0b4c2c3472..ca1a18b1e7 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py
@@ -207,6 +207,11 @@ def test_ppo_get_value_estimates(dummy_config, rnn, visual, discrete):
     run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
         trajectory.to_agentbuffer(), trajectory.next_obs, done=False
     )
+    if rnn:
+        # Check that memories don't have a Torch gradient
+        for mem in optimizer.critic_memory_dict.values():
+            assert not mem.requires_grad
+
     for key, val in run_out.items():
         assert type(key) is str
         assert len(val) == 15

From 54f45ee8c4eacd16da7bfa39b0f079e708acc745 Mon Sep 17 00:00:00 2001
From: Ervin Teng
Date: Mon, 8 Mar 2021 10:47:15 -0500
Subject: [PATCH 3/3] Evaluate with no_grad

---
 .../trainers/optimizer/torch_optimizer.py | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
index 641800ee93..8578c887a0 100644
--- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
+++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -166,20 +166,22 @@ def get_trajectory_value_estimates(
 
         # If we're using LSTM, we want to get all the intermediate memories.
         all_next_memories: Optional[AgentBufferField] = None
-        if self.policy.use_recurrent:
-            (
-                value_estimates,
-                all_next_memories,
-                next_memory,
-            ) = self._evaluate_by_sequence(current_obs, memory)
-        else:
-            value_estimates, next_memory = self.critic.critic_pass(
-                current_obs, memory, sequence_length=batch.num_experiences
-            )
 
-        # Store the memory for the next trajectory.
-        # Must detach from graph to prevent memory leak
-        self.critic_memory_dict[agent_id] = next_memory.detach()
+        # To prevent memory leak and improve performance, evaluate with no_grad.
+        with torch.no_grad():
+            if self.policy.use_recurrent:
+                (
+                    value_estimates,
+                    all_next_memories,
+                    next_memory,
+                ) = self._evaluate_by_sequence(current_obs, memory)
+            else:
+                value_estimates, next_memory = self.critic.critic_pass(
+                    current_obs, memory, sequence_length=batch.num_experiences
+                )
+
+        # Store the memory for the next trajectory. This should NOT have a gradient.
+        self.critic_memory_dict[agent_id] = next_memory
 
         next_value_estimate, _ = self.critic.critic_pass(
             next_obs, next_memory, sequence_length=1
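
Note on the series: storing next_memory while it is still attached to the autograd
graph keeps every prior trajectory's graph reachable through the stored tensor,
which is the leak patch 1 fixes with detach() and patch 3 avoids entirely by
evaluating under torch.no_grad(). Below is a minimal, self-contained PyTorch
sketch of the failure mode and both fixes. It is illustrative only: the LSTMCell
and tensor shapes are stand-ins, not ml-agents code.

import torch

# Stand-in for the recurrent critic; not the ml-agents implementation.
cell = torch.nn.LSTMCell(4, 8)
memory = (torch.zeros(1, 8), torch.zeros(1, 8))
for _ in range(3):  # one iteration ~ one trajectory
    obs = torch.randn(1, 4)
    h, c = cell(obs, memory)
    # Patch 1: detach before storing. Storing (h, c) directly would keep
    # each trajectory's autograd graph alive through the stored memory.
    memory = (h.detach(), c.detach())

# Patch 3 goes further: under no_grad, no graph is built in the first place.
with torch.no_grad():
    h, c = cell(torch.randn(1, 4), memory)
assert not memory[0].requires_grad and not h.requires_grad

The detach() approach still pays the cost of building the graph during the
forward pass; wrapping the evaluation in no_grad skips that work, which is why
patch 3 can drop the explicit detach while patch 2's requires_grad test keeps
passing.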