From 753d0647a6eabdfce9734169103509301e1afda0 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 4 Mar 2021 19:45:08 -0500 Subject: [PATCH 1/2] Fix typo --- ml-agents/mlagents/trainers/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py index de30542a35..c2b2782969 100644 --- a/ml-agents/mlagents/trainers/buffer.py +++ b/ml-agents/mlagents/trainers/buffer.py @@ -206,7 +206,7 @@ def padded_to_batch( dimension is equal to the length of the AgentBufferField. """ if len(self) > 0 and not isinstance(self[0], list): - return np.asanyarray(self, dytpe=dtype) + return np.asanyarray(self, dtype=dtype) shape = None for _entry in self: From 17f01280c5680b53dbb05fdff3c02150bebb3497 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 4 Mar 2021 19:48:24 -0500 Subject: [PATCH 2/2] Add test --- ml-agents/mlagents/trainers/tests/test_buffer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ml-agents/mlagents/trainers/tests/test_buffer.py b/ml-agents/mlagents/trainers/tests/test_buffer.py index 6d99cb0e88..c4c612ae7d 100644 --- a/ml-agents/mlagents/trainers/tests/test_buffer.py +++ b/ml-agents/mlagents/trainers/tests/test_buffer.py @@ -154,6 +154,10 @@ def test_agentbufferfield(): assert np.array_equal(padded[0], np.array([1, 1, 1, 1])) assert np.array_equal(padded[1], np.array([2, 2, 3, 3])) + # Make sure it doesn't fail when the field isn't a list + padded_a = a.padded_to_batch() + assert np.array_equal(padded_a, a) + def fakerandint(values): return 19