Skip to content

Commit

Permalink
Merge branch 'master' into clip_action
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Mar 21, 2021
2 parents 6293deb + 0c7117d commit 60dee40
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions docs/tutorials/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ The following code snippet illustrates its usage, including:
::

>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> from tianshou.data import Batch, ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))

>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
Expand All @@ -96,7 +96,7 @@ The following code snippet illustrates its usage, including:
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... done = i % 4 == 0
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
>>> len(buf2)
10
>>> buf2.obs
Expand Down Expand Up @@ -147,25 +147,26 @@ The following code snippet illustrates its usage, including:
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
... done=done, obs_next={'id': i + 1})
... ptr, ep_rew, ep_len, ep_idx = buf.add(
... Batch(obs={'id': i}, act=i, rew=i,
... done=done, obs_next={'id': i + 1}))
... print(i, ep_len, ep_rew)
0 1 0.0
1 0 0.0
2 0 0.0
3 0 0.0
4 0 0.0
5 5 15.0
6 0 0.0
7 0 0.0
8 0 0.0
9 0 0.0
10 5 40.0
11 0 0.0
12 0 0.0
13 0 0.0
14 0 0.0
15 5 65.0
0 [1] [0.]
1 [0] [0.]
2 [0] [0.]
3 [0] [0.]
4 [0] [0.]
5 [5] [15.]
6 [0] [0.]
7 [0] [0.]
8 [0] [0.]
9 [0] [0.]
10 [5] [40.]
11 [0] [0.]
12 [0] [0.]
13 [0] [0.]
14 [0] [0.]
15 [5] [65.]
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
obs: Batch(
Expand All @@ -175,8 +176,6 @@ The following code snippet illustrates its usage, including:
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([False, True, False, False, False, False, True, False,
False]),
info: Batch(),
policy: Batch(),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
Expand All @@ -194,16 +193,21 @@ The following code snippet illustrates its usage, including:
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> # however, [:] will select the item according to timestamp,
>>> # that equals to index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
>>> print(buf[:].obs_next.id)
[[ 7 8 9 10]
[[ 7 7 7 8]
[ 7 7 8 9]
[ 7 8 9 10]
[ 7 8 9 10]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[12 13 14 15]
[ 7 7 7 8]
[ 7 7 8 9]]
[12 13 14 15]]
>>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
>>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
True
.. raw:: html

Expand Down

0 comments on commit 60dee40

Please sign in to comment.