Skip to content

Commit

Permalink
[bug-fix] Use correct memories for LSTM SAC (#5228)
Browse files Browse the repository at this point in the history
* Use correct memories for LSTM SAC

* Add some comments
  • Loading branch information
Ervin T committed Apr 8, 2021
1 parent 02b77dd commit 7077302
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions ml-agents/mlagents/trainers/sac/optimizer_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,30 +497,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]
offset = 1 if self.policy.sequence_length > 1 else 0
next_value_memories_list = [
ModelUtils.list_to_tensor(
batch[BufferKey.CRITIC_MEMORY][i]
) # only pass value part of memory to target network
for i in range(
offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]

if len(memories_list) > 0:
memories = torch.stack(memories_list).unsqueeze(0)
value_memories = torch.stack(value_memories_list).unsqueeze(0)
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0)
else:
memories = None
value_memories = None
next_value_memories = None

# Q and V network memories are 0'ed out, since we don't have them during inference.
q_memories = (
torch.zeros_like(next_value_memories)
if next_value_memories is not None
else None
torch.zeros_like(value_memories) if value_memories is not None else None
)

# Copy normalizers from policy
Expand Down Expand Up @@ -568,6 +555,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
q1_stream, q2_stream = q1_out, q2_out

with torch.no_grad():
# Since we didn't record the next value memories, evaluate one step in the critic to
# get them.
if value_memories is not None:
# Get the first observation in each sequence
just_first_obs = [
_obs[:: self.policy.sequence_length] for _obs in current_obs
]
_, next_value_memories = self._critic.critic_pass(
just_first_obs, value_memories, sequence_length=1
)
else:
next_value_memories = None
target_values, _ = self.target_network(
next_obs,
memories=next_value_memories,
Expand Down

0 comments on commit 7077302

Please sign in to comment.