Skip to content

Commit

Permalink
Buffer fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ervin Teng committed Feb 23, 2021
1 parent fce4ad3 commit 2c03d2b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
11 changes: 4 additions & 7 deletions ml-agents/mlagents/trainers/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,10 @@ def get_batch(
)
if batch_size * training_length > len(self):
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
- return np.array(
- [padding] * (training_length - leftover) + self[:], dtype=np.float32
- )
+ return [padding] * (training_length - leftover) + self[:]
+
else:
- return np.array(
- self[len(self) - batch_size * training_length :], dtype=np.float32
- )
+ return self[len(self) - batch_size * training_length :]
else:
# The sequences will have overlapping elements
if batch_size is None:
Expand All @@ -182,7 +179,7 @@ def get_batch(
tmp_list: List[np.ndarray] = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
- return np.array(tmp_list, dtype=np.float32)
+ return tmp_list

def reset_field(self) -> None:
"""
Expand Down
4 changes: 3 additions & 1 deletion ml-agents/mlagents/trainers/coma/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def _update_policy(self):
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)

- advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
+ advantages = np.array(
+ self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32
+ )
self.update_buffer[BufferKey.ADVANTAGES].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)
Expand Down

0 comments on commit 2c03d2b

Please sign in to comment.