Commit 347852b

Author: Ervin T
[bug-fix] Fix memory leak when using LSTMs (#5048)
* Detach memory before storing
* Add test
* Evaluate with no_grad
1 parent 5e87e2c · commit 347852b

File tree

2 files changed: +19 -11 lines changed

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 14 additions & 11 deletions

@@ -166,18 +166,21 @@ def get_trajectory_value_estimates(

         # If we're using LSTM, we want to get all the intermediate memories.
         all_next_memories: Optional[AgentBufferField] = None
-        if self.policy.use_recurrent:
-            (
-                value_estimates,
-                all_next_memories,
-                next_memory,
-            ) = self._evaluate_by_sequence(current_obs, memory)
-        else:
-            value_estimates, next_memory = self.critic.critic_pass(
-                current_obs, memory, sequence_length=batch.num_experiences
-            )

-        # Store the memory for the next trajectory
+        # To prevent memory leak and improve performance, evaluate with no_grad.
+        with torch.no_grad():
+            if self.policy.use_recurrent:
+                (
+                    value_estimates,
+                    all_next_memories,
+                    next_memory,
+                ) = self._evaluate_by_sequence(current_obs, memory)
+            else:
+                value_estimates, next_memory = self.critic.critic_pass(
+                    current_obs, memory, sequence_length=batch.num_experiences
+                )
+
+        # Store the memory for the next trajectory. This should NOT have a gradient.
         self.critic_memory_dict[agent_id] = next_memory

         next_value_estimate, _ = self.critic.critic_pass(
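The change works because PyTorch tensors produced during a normal forward pass stay attached to the autograd graph that created them. Storing such an LSTM hidden state in critic_memory_dict across trajectories keeps every upstream activation reachable, so memory grows with each trajectory. Below is a minimal standalone sketch of the failure mode and the fix; the module sizes and helper names are illustrative, not taken from ml-agents:

    import torch
    from torch import nn

    # Illustrative module and shapes, not the ml-agents critic.
    lstm = nn.LSTM(input_size=8, hidden_size=16)
    memory = None  # persists across "trajectories", like critic_memory_dict

    def evaluate_leaky(obs, memory):
        # The returned hidden state is attached to the autograd graph, so
        # storing it keeps every previous forward pass reachable (the leak).
        out, memory = lstm(obs, memory)
        return out, memory

    def evaluate_fixed(obs, memory):
        # No graph is built, so the stored memories are plain tensors.
        with torch.no_grad():
            out, memory = lstm(obs, memory)
        return out, memory

    obs = torch.randn(4, 1, 8)  # (seq_len, batch, input_size)
    _, memory = evaluate_leaky(obs, memory)
    assert all(m.requires_grad for m in memory)      # graph attached: leaks

    _, memory = evaluate_fixed(obs, memory)
    assert all(not m.requires_grad for m in memory)  # safe to store

Calling .detach() on the memories before storing them (the first bullet in the commit message) severs the graph the same way; wrapping the whole evaluation in torch.no_grad() goes further by never building the graph at all, which also saves time and memory during the forward pass.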

ml-agents/mlagents/trainers/tests/torch/test_ppo.py

Lines changed: 5 additions & 0 deletions

@@ -207,6 +207,11 @@ def test_ppo_get_value_estimates(dummy_config, rnn, visual, discrete):
     run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
         trajectory.to_agentbuffer(), trajectory.next_obs, done=False
     )
+    if rnn:
+        # Check that memories don't have a Torch gradient
+        for mem in optimizer.critic_memory_dict.values():
+            assert not mem.requires_grad
+
     for key, val in run_out.items():
         assert type(key) is str
         assert len(val) == 15
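The regression test reduces the fix to one observable property: a stored memory must not require grad. The same property is easy to check on plain tensors; this snippet is illustrative and independent of the ml-agents test harness:

    import torch

    x = torch.ones(3, requires_grad=True)

    # Without no_grad or detach, results stay attached to the graph.
    leaked = x * 2
    assert leaked.requires_grad  # this is what the test above would catch

    # Either fix severs the graph, so storing the result cannot leak it.
    with torch.no_grad():
        safe = x * 2
    assert not safe.requires_grad
    assert not (x * 2).detach().requires_grad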
