Saving and loading replay buffer with HDF5 (#261)
As mentioned in #260, this pull request implements saving and loading of the replay buffer via HDF5.
nicoguertler authored Dec 17, 2020
1 parent cd48142 commit 5d13d8a
Showing 7 changed files with 211 additions and 5 deletions.
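For orientation, here is a minimal usage sketch of the workflow this commit enables, based on the save_hdf5/load_hdf5 API added to tianshou/data/buffer.py below (the file name buf.hdf5 is arbitrary):

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})

# New in this commit: write the whole buffer to an HDF5 file ...
buf.save_hdf5('buf.hdf5')
# ... and restore it later via the new classmethod.
buf2 = ReplayBuffer.load_hdf5('buf.hdf5')
assert len(buf2) == len(buf) == 3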
1 change: 1 addition & 0 deletions .gitignore
@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
.DS_Store
*.zip
*.pstats
*.swp
9 changes: 9 additions & 0 deletions docs/bibtex.json
@@ -0,0 +1,9 @@
{
"cited": {
"tutorials/dqn": [
"DQN",
"DDPG",
"PPO"
]
}
}
1 change: 1 addition & 0 deletions docs/conf.py
@@ -70,6 +70,7 @@
]
)
}
bibtex_bibfiles = ['refs.bib']

# -- Options for HTML output -------------------------------------------------

3 changes: 2 additions & 1 deletion setup.py
@@ -47,10 +47,11 @@ def get_version() -> str:
install_requires=[
"gym>=0.15.4",
"tqdm",
"numpy",
"numpy!=1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=3.1.0"
],
extras_require={
"dev": [
70 changes: 70 additions & 0 deletions test/base/test_buffer.py
@@ -1,11 +1,15 @@
import os
import torch
import pickle
import pytest
import tempfile
import h5py
import numpy as np
from timeit import timeit

from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5

if __name__ == '__main__':
from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
pbuf.weight[np.arange(len(pbuf))])


def test_hdf5():
size = 100
buffers = {
"array": ReplayBuffer(size, stack_num=2),
"list": ListReplayBuffer(),
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': rew,
'done': 0,
'info': {"number": {"n": i}, 'extra': None},
}
buffers["array"].add(**kwargs)
buffers["list"].add(**kwargs)
buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path

# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k]._maxsize == buffers[k]._maxsize
assert _buffers[k]._index == buffers[k]._index
assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]:
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
assert np.all(
buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
assert np.all(
buffers[k][:].info.extra == _buffers[k][:].info.extra)

for path in paths.values():
os.remove(path)

# raise exception when value cannot be pickled
data = {"not_supported": lambda x: x*x}
grp = h5py.Group
with pytest.raises(NotImplementedError):
to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled
data = {"not_supported": np.array(lambda x: x*x)}
grp = h5py.Group
with pytest.raises(RuntimeError):
to_hdf5(data, grp)


if __name__ == '__main__':
test_hdf5()
test_replaybuffer()
test_ignore_obs_next()
test_stack()
41 changes: 38 additions & 3 deletions tianshou/data/buffer.py
@@ -1,10 +1,12 @@
import h5py
import torch
import numpy as np
from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional

from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.utils.converter import to_hdf5, from_hdf5


class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@ class ReplayBuffer:
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[incide].
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
@@ -63,6 +68,15 @@
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
>>> # load contents of HDF5 file into existing buffer
>>> # (only possible if size of buffer and data in file match)
>>> buf.load_contents_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
We need it because pickling the buffer does not work out-of-the-box
("buffer.__getattr__" is customized).
"""
self._indices = np.arange(state["_maxsize"])
self.__dict__.update(state)

def __getstate__(self) -> dict:
exclude = {"_indices"}
state = {k: v for k, v in self.__dict__.items() if k not in exclude}
return state

def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
value = self._meta.__dict__[name]
@@ -359,6 +379,21 @@ def __getitem__(
policy=self.get(index, "policy"),
)

def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__getstate__(), f)

@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf


class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer.
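A note on the optional device argument of load_hdf5 above: it is passed on to from_hdf5, which uses it when rebuilding torch tensors. A short sketch, assuming a buffer was previously saved to buf.hdf5 as in the example near the top (the CUDA check is illustrative, not part of the commit):

import torch
from tianshou.data import ReplayBuffer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Tensor-valued entries stored in the file are created on the requested
# device while the buffer is restored.
buf = ReplayBuffer.load_hdf5('buf.hdf5', device=device)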
91 changes: 90 additions & 1 deletion tianshou/data/utils/converter.py
@@ -1,8 +1,10 @@
import h5py
import torch
import pickle
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Union, Optional
from typing import Dict, Union, Optional

from tianshou.data.batch import _parse_value, Batch

@@ -80,3 +82,90 @@ def to_torch_as(
"""
assert isinstance(y, torch.Tensor)
return to_torch(x, dtype=y.dtype, device=y.device)


# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[ # type: ignore
int, float, Batch, np.ndarray, torch.Tensor, object,
'Hdf5ConvertibleType', # type: ignore
]

Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore


def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
"""Copy object into HDF5 group."""

def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
"""Pickle, convert to numpy array and write to HDF5 dataset."""
data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
y.create_dataset(key, data=data)

for k, v in x.items():
if isinstance(v, (Batch, dict)):
# dicts and batches are both represented by groups
subgrp = y.create_group(k)
if isinstance(v, Batch):
subgrp_data = v.__getstate__()
subgrp.attrs["__data_type__"] = "Batch"
else:
subgrp_data = v
to_hdf5(subgrp_data, subgrp)
elif isinstance(v, torch.Tensor):
# PyTorch tensors are written to datasets
y.create_dataset(k, data=to_numpy(v))
y[k].attrs["__data_type__"] = "Tensor"
elif isinstance(v, np.ndarray):
try:
# NumPy arrays are written to datasets
y.create_dataset(k, data=v)
y[k].attrs["__data_type__"] = "ndarray"
except TypeError:
# If data type is not supported by HDF5 fall back to pickle.
# This happens if dtype=object (e.g. due to entries being None)
# and possibly in other cases like structured arrays.
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
raise RuntimeError(
f"Attempted to pickle {v.__class__.__name__} due to "
"data type not supported by HDF5 and failed."
) from e
y[k].attrs["__data_type__"] = "pickled_ndarray"
elif isinstance(v, (int, float)):
# ints and floats are stored as attributes of groups
y.attrs[k] = v
else: # resort to pickle for any other type of object
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
raise NotImplementedError(
f"No conversion to HDF5 for object of type '{type(v)}' "
"implemented and fallback to pickle failed."
) from e
y[k].attrs["__data_type__"] = v.__class__.__name__


def from_hdf5(
x: h5py.Group, device: Optional[str] = None
) -> Hdf5ConvertibleType:
"""Restore object from HDF5 group."""
if isinstance(x, h5py.Dataset):
# handle datasets
if x.attrs["__data_type__"] == "ndarray":
y = np.array(x)
elif x.attrs["__data_type__"] == "Tensor":
y = torch.tensor(x, device=device)
else:
y = pickle.loads(x[()])
else:
# handle groups representing a dict or a Batch
y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
for k, v in x.items():
y[k] = from_hdf5(v, device)
if "__data_type__" in x.attrs:
# if dictionary represents Batch, convert to Batch
if x.attrs["__data_type__"] == "Batch":
y = Batch(y)
return y
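
For completeness, a small round-trip sketch of the to_hdf5/from_hdf5 pair in isolation, assuming the nested dict only holds types handled above (Batch, ndarray, torch.Tensor, int/float, or picklable fallbacks); the file name demo.hdf5 is arbitrary:

import h5py
import numpy as np
import torch
from tianshou.data import Batch
from tianshou.data.utils.converter import to_hdf5, from_hdf5

data = {
    "meta": Batch(step=np.arange(4)),  # nested Batch -> HDF5 group tagged "Batch"
    "rew": torch.zeros(4),             # tensor -> dataset tagged "Tensor"
    "count": 4,                        # int -> group attribute
    "extra": None,                     # unpicklable by HDF5 itself -> pickle fallback
}
with h5py.File('demo.hdf5', 'w') as f:
    to_hdf5(data, f)
with h5py.File('demo.hdf5', 'r') as f:
    restored = from_hdf5(f)
assert isinstance(restored["meta"], Batch) and restored["count"] == 4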
