From 167d4bb78d70a5261da113d9a638801d43adf4c8 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 1 May 2024 01:22:04 +0200 Subject: [PATCH 01/12] Batch: important new functionality 1. apply_array_func for applying array operations recursively. Use it in to_numpy and to_torch 2. isnull, hasnull, dropnull 3. set_array_at_key for setting a subarray at a desired index inplace Added extensive tests for the new methods --- test/base/test_batch.py | 98 +++++++++++++++-- tianshou/data/batch.py | 233 +++++++++++++++++++++++++++++++++------- 2 files changed, 284 insertions(+), 47 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 0530d8232..fce82f919 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -703,9 +703,7 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: assert not DeepDiff(batch.to_dict(recursive=True), expected) -class TestToNumpy: - """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` .""" - +class TestConversions: @staticmethod def test_to_numpy() -> None: batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) @@ -726,10 +724,6 @@ def test_to_numpy_() -> None: assert isinstance(batch.b, np.ndarray) assert isinstance(batch.c.d, np.ndarray) - -class TestToTorch: - """Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` .""" - @staticmethod def test_to_torch() -> None: batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])}) @@ -749,3 +743,93 @@ def test_to_torch_() -> None: assert id_batch == id(batch) assert isinstance(batch.b, torch.Tensor) assert isinstance(batch.c.d, torch.Tensor) + + @staticmethod + def test_apply_array_func() -> None: + batch = Batch(a=1, b=np.arange(3), c={"d": np.array([1, 2, 3])}) + batch_with_max = batch.apply_array_func(np.max) + assert np.array_equal(batch_with_max.a, np.array(1)) + assert np.array_equal(batch_with_max.b, np.array(2)) + assert np.array_equal(batch_with_max.c.d, np.array(3)) + + batch_array_added = batch.apply_array_func(lambda x: x + np.array([1, 2, 3])) + assert np.array_equal(batch_array_added.a, np.array([2, 3, 4])) + assert np.array_equal(batch_array_added.c.d, np.array([2, 4, 6])) + + +class TestAssignment: + @staticmethod + def test_assign_full_length_array() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2, 3]), "a") + batch.set_array_at_key(np.array([4, 5, 6]), "new_key") + assert np.array_equal(batch.a, np.array([1, 2, 3])) + assert np.array_equal(batch.new_key, np.array([4, 5, 6])) + + # other keys are not affected + assert np.array_equal(batch.b, np.array([7, 8, 9])) + assert np.array_equal(batch.c.d, np.array([1, 2, 3])) + + with pytest.raises(ValueError): + # wrong length + batch.set_array_at_key(np.array([1, 2]), "a") + + @staticmethod + def test_assign_subarray_existing_key() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2]), "a", index=[0, 1]) + assert np.array_equal(batch.a, np.array([1, 2, 6])) + batch.set_array_at_key(np.array([10, 12]), "a", index=slice(0, 2)) + assert np.array_equal(batch.a, np.array([10, 12, 6])) + batch.set_array_at_key(np.array([1, 2]), "a", index=[0, 2]) + assert np.array_equal(batch.a, np.array([1, 12, 2])) + batch.set_array_at_key(np.array([1, 2]), "a", index=[2, 0]) + assert np.array_equal(batch.a, np.array([2, 12, 1])) + batch.set_array_at_key(np.array([1, 2, 3]), "a", index=[2, 1, 0]) + assert 
np.array_equal(batch.a, np.array([3, 2, 1])) + + with pytest.raises(IndexError): + # Index out of bounds + batch.set_array_at_key(np.array([1, 2]), "a", index=[10, 11]) + + # other keys are not affected + assert np.array_equal(batch.b, np.array([7, 8, 9])) + assert np.array_equal(batch.c.d, np.array([1, 2, 3])) + + @staticmethod + def test_assign_subarray_new_key() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2]), "new_key", index=[0, 1], default_value=0) + assert np.array_equal(batch.new_key, np.array([1, 2, 0])) + # with float, None can be cast to NaN + batch.set_array_at_key(np.array([1.0, 2.0]), "new_key2", index=[0, 1]) + assert np.array_equal(batch.new_key2, np.array([1.0, 2.0, np.nan]), equal_nan=True) + + @staticmethod + def test_isnull() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, None, 3])}) + batch_isnan = batch.isnull() + assert not batch_isnan.a.any() + assert batch_isnan.b[2] + assert not batch_isnan.b[:2].any() + assert np.array_equal(batch_isnan.c.d, np.array([False, True, False])) + + @staticmethod + def test_hasnull() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, 2, 3])}) + assert batch.hasnull() + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + assert not batch.hasnull() + batch = Batch(a=[4, 5, 6], c={"d": np.array([1, None, 3])}) + assert batch.hasnull() + + @staticmethod + def test_dropnull(): + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([None, 2.1, 3.0])}) + assert batch.dropnull() == Batch(a=[5], b=[8], c={"d": np.array([2.1])}).apply_array_func( + np.atleast_1d, + ) + batch2 = Batch(a=[4, 5, 6, 7], b=[7, 8, None, 10], c={"d": np.array([None, 2, 3, 4])}) + assert batch2.dropnull() == Batch(a=[5, 7], b=[8, 10], c={"d": np.array([2, 4])}) + batch_no_nan = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + assert batch_no_nan.dropnull() == batch_no_nan diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 03b3d9849..90e943a17 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,6 +1,6 @@ import pprint import warnings -from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence +from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence from copy import deepcopy from numbers import Number from types import EllipsisType @@ -16,14 +16,19 @@ ) import numpy as np +import pandas as pd import torch from deepdiff import DeepDiff +from tianshou.utils import logging + _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray +log = logging.getLogger(__name__) + def _is_batch_set(obj: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, @@ -417,6 +422,71 @@ def to_dict(self, recurse: bool = True) -> dict[str, Any]: def to_list_of_dicts(self) -> list[dict[str, Any]]: ... + def get_keys(self) -> KeysView: + ... + + def set_array_at_key( + self, + seq: np.ndarray, + key: str, + index: Sequence[int] | None = None, + default_value: float | None = None, + ) -> None: + """Set a sequence of values at a given key. + + If index is not passed, the sequence must have the same length as the batch. + :param seq: the array of values to set. + :param key: the key to set the sequence at. + :param index: the indices to set the sequence at. 
If None, the sequence must have
+            the same length as the batch and will be set at all indices.
+        :param default_value: this only applies if index is passed and the key does not exist yet
+            in the batch. In that case entries outside the passed index will be filled
+            with this default value.
+            Note that the array at the key will be of the same dtype as the passed sequence,
+            so the default value should be such that numpy can cast it to this dtype.
+        """
+        ...
+
+    def isnull(self) -> Self:
+        """Return a boolean mask of the same shape, indicating missing values."""
+        ...
+
+    def hasnull(self) -> bool:
+        """Return whether the batch has missing values."""
+        ...
+
+    def dropnull(self) -> Self:
+        """Return a batch where all items in which any value is null are dropped.
+
+        Note that it is not the same as just dropping the entries of the sequence.
+        For example, with
+
+        >>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7])
+        >>> b.dropnull()
+
+        will result in
+
+        >>> Batch(a=[2, 4], b=[5, 7])
+
+        This logic is applied recursively to all nested batches. The result is
+        the same as if the batch was flattened, entries were dropped,
+        and then the batch was reshaped back to the original nested structure.
+        """
+        ...
+
+    def apply_array_func(
+        self,
+        array_func: Callable[[np.ndarray | torch.Tensor], Any],
+        inplace: bool = False,
+    ) -> None | Self:
+        """Apply a function to all arrays in the batch, including nested ones.
+
+        :param array_func: the function to apply to the arrays.
+        :param inplace: whether to apply the function in-place. If False, a new batch is returned,
+            otherwise the batch is modified in-place and None is returned.
+        """
+        ...
+
 
 class Batch(BatchProtocol):
     """See :class:`~tianshou.data.batch.BatchProtocol`."""
@@ -630,22 +700,17 @@ def __repr__(self) -> str:
 
     @staticmethod
     def to_numpy(batch: TBatch) -> TBatch:
-        batch_dict = deepcopy(batch)
-        for batch_key, obj in batch_dict.items():
-            if isinstance(obj, torch.Tensor):
-                batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
-            elif isinstance(obj, Batch):
-                obj = Batch.to_numpy(obj)
-                batch_dict.__dict__[batch_key] = obj
-
-        return batch_dict
+        result = deepcopy(batch)
+        result.to_numpy_()
+        return result
 
     def to_numpy_(self) -> None:
-        for batch_key, obj in self.items():
-            if isinstance(obj, torch.Tensor):
-                self.__dict__[batch_key] = obj.detach().cpu().numpy()
-            elif isinstance(obj, Batch):
-                obj.to_numpy_()
+        def arr_to_numpy(arr: arr_type) -> arr_type:
+            if isinstance(arr, torch.Tensor):
+                return arr.detach().cpu().numpy()
+            return arr
+
+        self.apply_array_func(arr_to_numpy, inplace=True)
 
     @staticmethod
     def to_torch(
@@ -653,10 +718,9 @@ def to_torch(
         dtype: torch.dtype | None = None,
         device: str | int | torch.device = "cpu",
     ) -> TBatch:
-        new_batch = Batch(batch, copy=True)
-        new_batch.to_torch_(dtype=dtype, device=device)
-
-        return new_batch  # type: ignore[return-value]
+        result = deepcopy(batch)
+        result.to_torch_(dtype=dtype, device=device)
+        return result
 
     def to_torch_(
         self,
@@ -666,28 +730,23 @@ def to_torch_(
         if not isinstance(device, torch.device):
             device = torch.device(device)
 
-        for batch_key, obj in self.items():
-            if isinstance(obj, torch.Tensor):
-                if (
-                    dtype is not None
-                    and obj.dtype != dtype
-                    or obj.device.type != device.type
-                    or device.index != obj.device.index
-                ):
-                    if dtype is not None:
-                        self.__dict__[batch_key] = obj.type(dtype).to(device)
-                    else:
-                        self.__dict__[batch_key] = obj.to(device)
-            elif isinstance(obj, Batch):
-                obj.to_torch_(dtype, device)
-            else:
-                # ndarray or scalar
-                if not 
isinstance(obj, np.ndarray): - obj = np.asanyarray(obj) - obj = torch.from_numpy(obj).to(device) + def arr_to_torch(arr: arr_type) -> arr_type: + if isinstance(arr, np.ndarray): + return torch.tensor(arr, dtype=dtype, device=device) + + # TODO: simplify + if ( + dtype is not None + and arr.dtype != dtype + or arr.device.type != device.type + or device.index != arr.device.index + ): if dtype is not None: - obj = obj.type(dtype) - self.__dict__[batch_key] = obj + arr = arr.type(dtype) + return arr.to(device) + return None + + self.apply_array_func(arr_to_torch, inplace=True) def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: """Private method for Batch.cat_. @@ -967,3 +1026,97 @@ def split( yield self[indices[idx:]] break yield self[indices[idx : idx + size]] + + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: bool = False, + ) -> None | Self: + return _apply_array_func_recursively(self, array_func, inplace=inplace) + + def set_array_at_key( + self, + arr: np.ndarray, + key: str, + index: IndexType | None = None, + default_value: float | None = None, + ) -> None: + if index is not None: + if key not in self.get_keys(): + log.info( + f"Key {key} not found in batch, " + f"creating a sequence of len {len(self)} with {default_value=} for it.", + ) + try: + self[key] = np.array([default_value] * len(self), dtype=arr.dtype) + except TypeError as exception: + raise TypeError( + f"Cannot create a sequence of dtype {arr.dtype} with default value {default_value}. " + f"You can fix this either by passing an array with the correct dtype or by passing " + f"a different default value that can be cast to the array's dtype (or both).", + ) from exception + else: + existing_entry = self[key] + if isinstance(existing_entry, BatchProtocol): + raise ValueError( + f"Cannot set sequence at key {key} because it is a nested batch, " + f"can only set a subsequence of an array.", + ) + self[key][index] = arr + else: + if len(arr) != len(self): + raise ValueError( + f"Sequence length {len(arr)} does not match " + f"batch length {len(self)}. For setting a subsequence with missing " + f"entries filled up by default values, consider passing an index.", + ) + self[key] = arr + + def isnull(self) -> Self: + return self.apply_array_func(pd.isnull, inplace=False) + + def hasnull(self) -> bool: + isnan_batch = self.isnull() + is_any_null_batch = isnan_batch.apply_array_func(np.any, inplace=False) + + def is_any_true(boolean_batch: BatchProtocol): + for val in boolean_batch.values(): + if isinstance(val, BatchProtocol): + if is_any_true(val): + return True + else: + assert val.size == 1, "This shouldn't have happened, it's a bug!" + # an unsized array with a boolean, e.g. np.array(False). behaves like the boolean itself + if val: + return True + return None + + return is_any_true(is_any_null_batch) + + def dropnull(self) -> Self: + # we need to use dicts since a batch retrieved for a single index has no length and cat fails + # TODO: make cat work with batches containing scalars? 
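+        # How this works: iterating over the batch yields one "row" per index;
+        # rows containing any null value (checked recursively by hasnull) are
+        # skipped, and the remaining values are made at least 1-dimensional so
+        # that Batch.cat can concatenate them back into a single batch.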
+ sub_batches = [] + for b in self: + if b.hasnull(): + continue + # needed for cat to work + b = b.apply_array_func(np.atleast_1d) + sub_batches.append(b) + return Batch.cat(sub_batches) + + +def _apply_array_func_recursively( + batch: TBatch, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: bool = False, +) -> TBatch | None: + result = batch if inplace else deepcopy(batch) + for key, val in batch.__dict__.items(): + if isinstance(val, BatchProtocol): + result[key] = _apply_array_func_recursively(val, array_func, inplace=False) + else: + result[key] = array_func(val) + if not inplace: + return result + return None From d3905c34617f95de43f7988c4c8c7dff79b79b67 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 3 May 2024 22:22:49 +0200 Subject: [PATCH 02/12] Batch: fixed hasnull return --- tianshou/data/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 90e943a17..40d6ffd4e 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1089,7 +1089,7 @@ def is_any_true(boolean_batch: BatchProtocol): # an unsized array with a boolean, e.g. np.array(False). behaves like the boolean itself if val: return True - return None + return False return is_any_true(is_any_null_batch) From ce6e34b1d6350cd8ca655c40d227d2ce28676fbd Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 29 Jul 2024 15:32:28 +0200 Subject: [PATCH 03/12] Improved typing, bugfix in Batch.apply_array_func A typo led to None instead of arr being returned --- examples/inverse/irl_gail.py | 18 +- test/base/test_batch.py | 2 +- test/base/test_buffer.py | 531 +++++++++++++++---------- test/base/test_collector.py | 14 +- test/base/test_env_finite.py | 11 +- test/base/test_returns.py | 150 ++++--- tianshou/data/batch.py | 53 ++- tianshou/data/types.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 2 +- 9 files changed, 482 insertions(+), 301 deletions(-) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 42e5bc2c9..e327fd490 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -4,7 +4,7 @@ import datetime import os import pprint -from typing import SupportsFloat +from typing import SupportsFloat, cast import d4rl import gymnasium as gym @@ -16,6 +16,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy from tianshou.policy.base import BasePolicy @@ -185,12 +186,15 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: for i in range(dataset_size): expert_buffer.add( - Batch( - obs=dataset["observations"][i], - act=dataset["actions"][i], - rew=dataset["rewards"][i], - done=dataset["terminals"][i], - obs_next=dataset["next_observations"][i], + cast( + RolloutBatchProtocol, + Batch( + obs=dataset["observations"][i], + act=dataset["actions"][i], + rew=dataset["rewards"][i], + done=dataset["terminals"][i], + obs_next=dataset["next_observations"][i], + ), ), ) print("dataset loaded") diff --git a/test/base/test_batch.py b/test/base/test_batch.py index fce82f919..a26d5d679 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -824,7 +824,7 @@ def test_hasnull() -> None: assert batch.hasnull() @staticmethod - def test_dropnull(): + def test_dropnull() -> None: batch = Batch(a=[4, 5, 6], b=[7, 8, None], 
c={"d": np.array([None, 2.1, 3.0])}) assert batch.dropnull() == Batch(a=[5], b=[8], c={"d": np.array([2.1])}).apply_array_func( np.atleast_1d, diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 9f3f40828..75ff919c1 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -21,6 +21,7 @@ SegmentTree, VectorReplayBuffer, ) +from tianshou.data.types import RolloutBatchProtocol from tianshou.data.utils.converter import to_hdf5 @@ -34,14 +35,17 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) buf.add( - Batch( - obs=obs, - act=[act], - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + ), ), ) obs = obs_next @@ -58,33 +62,36 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert (data.terminated <= 1).all() assert (data.truncated >= 0).all() assert (data.truncated <= 1).all() - b = ReplayBuffer(size=10) + replay_buffer = ReplayBuffer(size=10) # neg bsz should return empty index - assert b.sample_indices(-1).tolist() == [] - ptr, ep_rew, ep_len, ep_idx = b.add( - Batch( - obs=1, - act=1, - rew=1, - terminated=1, - truncated=0, - obs_next="str", - info={"a": 3, "b": {"c": 5.0}}, + assert replay_buffer.sample_indices(-1).tolist() == [] + ptr, ep_rew, ep_len, ep_idx = replay_buffer.add( + cast( + RolloutBatchProtocol, + Batch( + obs=1, + act=1, + rew=1, + terminated=1, + truncated=0, + obs_next="str", + info={"a": 3, "b": {"c": 5.0}}, + ), ), ) - assert b.obs[0] == 1 - assert b.done[0] - assert b.terminated[0] - assert not b.truncated[0] - assert b.obs_next[0] == "str" - assert np.all(b.obs[1:] == 0) - assert np.all(b.obs_next[1:] == np.array(None)) - assert b.info.a[0] == 3 - assert b.info.a.dtype == int - assert np.all(b.info.a[1:] == 0) - assert b.info.b.c[0] == 5.0 - assert b.info.b.c.dtype == float - assert np.all(b.info.b.c[1:] == 0.0) + assert replay_buffer.obs[0] == 1 + assert replay_buffer.done[0] + assert replay_buffer.terminated[0] + assert not replay_buffer.truncated[0] + assert replay_buffer.obs_next[0] == "str" + assert np.all(replay_buffer.obs[1:] == 0) + assert np.all(replay_buffer.obs_next[1:] == np.array(None)) + assert replay_buffer.info.a[0] == 3 + assert replay_buffer.info.a.dtype == int + assert np.all(replay_buffer.info.a[1:] == 0) + assert replay_buffer.info.b.c[0] == 5.0 + assert replay_buffer.info.b.c.dtype == float + assert np.all(replay_buffer.info.b.c[1:] == 0.0) assert ptr.shape == (1,) assert ptr[0] == 0 assert ep_rew.shape == (1,) @@ -94,28 +101,32 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert ep_idx.shape == (1,) assert ep_idx[0] == 0 # test extra keys pop up, the buffer should handle it dynamically - batch = Batch( - obs=2, - act=2, - rew=2, - terminated=0, - truncated=0, - obs_next="str2", - info={"a": 4, "d": {"e": -np.inf}}, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=2, + act=2, + rew=2, + terminated=0, + truncated=0, + obs_next="str2", + info={"a": 4, "d": {"e": -np.inf}}, + ), ) - b.add(batch) + replay_buffer.add(batch) info_keys = ["a", "b", "d"] - assert set(b.info.keys()) == set(info_keys) - assert b.info.a[1] == 4 - assert b.info.b.c[1] == 0 - assert b.info.d.e[1] == -np.inf + assert set(replay_buffer.info.keys()) == set(info_keys) + assert 
replay_buffer.info.a[1] == 4 + assert replay_buffer.info.b.c[1] == 0 + assert replay_buffer.info.d.e[1] == -np.inf # test batch-style adding method, where len(batch) == 1 batch.done = [1] - batch.terminated = [0] - batch.truncated = [1] + batch.terminated = [0] # type: ignore[assignment] + batch.truncated = [1] # type: ignore[assignment] + assert isinstance(batch.info, Batch) batch.info.e = np.zeros([1, 4]) batch = Batch.stack([batch]) - ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) + ptr, ep_rew, ep_len, ep_idx = replay_buffer.add(batch, buffer_ids=[0]) assert ptr.shape == (1,) assert ptr[0] == 2 assert ep_rew.shape == (1,) @@ -124,17 +135,17 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert ep_len[0] == 2 assert ep_idx.shape == (1,) assert ep_idx[0] == 1 - assert set(b.info.keys()) == {*info_keys, "e"} - assert b.info.e.shape == (b.maxsize, 1, 4) + assert set(replay_buffer.info.keys()) == {*info_keys, "e"} + assert replay_buffer.info.e.shape == (replay_buffer.maxsize, 1, 4) with pytest.raises(IndexError): - b[22] + replay_buffer[22] # test prev / next - assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) - assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) + assert np.all(replay_buffer.prev(np.array([0, 1, 2])) == [0, 1, 1]) + assert np.all(replay_buffer.next(np.array([0, 1, 2])) == [0, 2, 2]) batch.done = [0] - b.add(batch, buffer_ids=[0]) - assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) - assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) + replay_buffer.add(batch, buffer_ids=[0]) + assert np.all(replay_buffer.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) + assert np.all(replay_buffer.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) def test_ignore_obs_next(size: int = 10) -> None: @@ -142,17 +153,20 @@ def test_ignore_obs_next(size: int = 10) -> None: buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): buf.add( - Batch( - obs={ - "mask1": np.array([i, 1, 1, 0, 0]), - "mask2": np.array([i + 4, 0, 1, 0, 0]), - "mask": i, - }, - act={"act_id": i, "position_id": i + 3}, - rew=i, - terminated=i % 3 == 0, - truncated=False, - info={"if": i}, + cast( + RolloutBatchProtocol, + Batch( + obs={ + "mask1": np.array([i, 1, 1, 0, 0]), + "mask2": np.array([i + 4, 0, 1, 0, 0]), + "mask": i, + }, + act={"act_id": i, "position_id": i + 3}, + rew=i, + terminated=i % 3 == 0, + truncated=False, + info={"if": i}, + ), ), ) indices = np.arange(len(buf)) @@ -224,34 +238,43 @@ def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or truncated buf.add( - Batch( - obs=obs, - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + info=info, + ), ), ) buf2.add( - Batch( - obs=obs, - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + info=info, + ), ), ) buf3.add( - Batch( - obs=[obs, obs, obs], - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=[obs, obs], - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=[obs, obs, obs], + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=[obs, obs], + info=info, + ), ), ) obs = obs_next @@ -293,15 +316,18 @@ def 
test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=act, - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, - policy=np.random.randn() - 0.5, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=act, + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + policy=np.random.randn() - 0.5, + ), ) batch_stack = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) @@ -362,14 +388,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + ), ) buf.add(batch) buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2]) @@ -448,14 +477,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: for i in range(ep_len): act = 1 obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=(i == ep_len - 1), - truncated=(i == ep_len - 1), - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info, + ), ) buf.add(batch) obs = obs_next @@ -476,14 +508,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: for i in range(ep_len): act = 1 obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=(i == ep_len - 1), - truncated=(i == ep_len - 1), - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info, + ), ) if x == 1 and obs["observation"] < 10: obs = obs_next @@ -501,13 +536,16 @@ def test_update() -> None: buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): buf1.add( - Batch( - obs=np.array([i]), - act=float(i), - rew=i * i, - terminated=i % 2 == 0, - truncated=False, - info={"incident": "found"}, + cast( + RolloutBatchProtocol, + Batch( + obs=np.array([i]), + act=float(i), + rew=i * i, + terminated=i % 2 == 0, + truncated=False, + info={"incident": "found"}, + ), ), ) assert len(buf1) > len(buf2) @@ -610,23 +648,29 @@ def test_pickle() -> None: rew = np.array([1, 1]) for i in range(4): vbuf.add( - Batch( - obs=Batch(index=np.array([i])), - act=0, - rew=rew, - terminated=0, - truncated=0, + cast( + RolloutBatchProtocol, + Batch( + obs=Batch(index=np.array([i])), + act=0, + rew=rew, + terminated=0, + truncated=0, + ), ), ) for i in range(5): pbuf.add( - Batch( - obs=Batch(index=np.array([i])), - act=2, - rew=rew, - terminated=0, - truncated=0, - info=np.random.rand(), + cast( + RolloutBatchProtocol, + Batch( + obs=Batch(index=np.array([i])), + act=2, + rew=rew, + terminated=0, + truncated=0, + info=np.random.rand(), + ), ), ) # save & load @@ -660,8 +704,8 @@ def 
test_hdf5() -> None: "done": i % 3 == 2, "info": {"number": {"n": i, "t": info_t}, "extra": None}, } - buffers["array"].add(Batch(kwargs)) - buffers["prioritized"].add(Batch(kwargs)) + buffers["array"].add(cast(RolloutBatchProtocol, Batch(kwargs))) + buffers["prioritized"].add(cast(RolloutBatchProtocol, Batch(kwargs))) # save paths = {} @@ -703,12 +747,15 @@ def test_hdf5() -> None: def test_replaybuffermanager() -> None: buf = VectorReplayBuffer(20, 4) - batch = Batch( - obs=[1, 2, 3], - act=[1, 2, 3], - rew=[1, 2, 3], - terminated=[0, 0, 1], - truncated=[0, 0, 0], + batch = cast( + RolloutBatchProtocol, + Batch( + obs=[1, 2, 3], + act=[1, 2, 3], + rew=[1, 2, 3], + terminated=[0, 0, 1], + truncated=[0, 0, 0], + ), ) ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) assert np.all(ep_len == [0, 0, 1]) @@ -728,7 +775,10 @@ def test_replaybuffermanager() -> None: indices_next = buf.next(indices) assert np.allclose(indices_next, indices), indices_next assert np.allclose(buf.unfinished_index(), [0, 5]) - buf.add(Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0]), buffer_ids=[3]) + buf.add( + cast(RolloutBatchProtocol, Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0])), + buffer_ids=[3], + ) assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indices = buf.sample(10) batch, indices = buf.sample(0) @@ -739,20 +789,32 @@ def test_replaybuffermanager() -> None: assert np.allclose(indices_next, indices), indices_next data = np.array([0, 0, 0, 0]) buf.add( - Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) buf.add( - Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) assert len(buf) == 12 buf.add( - Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) buf.add( - Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) assert len(buf) == 20 @@ -839,7 +901,7 @@ def test_replaybuffermanager() -> None: ) assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0])), buffer_ids=[2], ) assert np.all(ep_len == [3]) @@ -915,7 +977,7 @@ def test_cachedbuffer() -> None: assert buf.sample_indices(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0])), buffer_ids=[1], ) obs = np.zeros(buf.maxsize) @@ -930,7 +992,7 @@ def test_cachedbuffer() -> None: assert np.all(ptr == [15]) assert np.all(ep_idx == [15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0])), buffer_ids=[3], ) obs[[0, 
25]] = 2 @@ -946,7 +1008,10 @@ def test_cachedbuffer() -> None: assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_indices(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]), + cast( + RolloutBatchProtocol, + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]), + ), buffer_ids=[3, 1], # TODO ) assert np.all(ep_len == [0, 2]) @@ -968,12 +1033,35 @@ def test_cachedbuffer() -> None: buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) data = np.zeros(4) rew = np.ones([4, 4]) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 1, 1], truncated=[0, 0, 0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[1, 1, 1, 1], truncated=[0, 0, 0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 1, 1], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[1, 1, 1, 1], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0]), + ), + ) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=data, act=data, rew=rew, terminated=[0, 1, 0, 1], truncated=[0, 0, 0, 0]), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 1, 0, 1], truncated=[0, 0, 0, 0]), + ), ) assert np.all(ptr == [1, -1, 11, -1]) assert np.all(ep_idx == [0, -1, 10, -1]) @@ -1041,14 +1129,17 @@ def test_multibuf_stack() -> None: truncated_list = [truncated] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num - batch = Batch( - obs=obs_list, - act=act_list, - rew=rew_list, - terminated=terminated_list, - truncated=truncated_list, - obs_next=obs_next_list, - info=info_list, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs_list, + act=act_list, + rew=rew_list, + terminated=terminated_list, + truncated=truncated_list, + obs_next=obs_next_list, + info=info_list, + ), ) buf5.add(batch) buf4.add(batch) @@ -1184,13 +1275,16 @@ def test_multibuf_stack() -> None: ) obs = np.random.rand(size, 4, 84, 84) buf6.add( - Batch( - obs=[obs[2], obs[0]], - act=[1, 1], - rew=[0, 0], - terminated=[0, 1], - truncated=[0, 0], - obs_next=[obs[3], obs[1]], + cast( + RolloutBatchProtocol, + Batch( + obs=[obs[2], obs[0]], + act=[1, 1], + rew=[0, 0], + terminated=[0, 1], + truncated=[0, 0], + obs_next=[obs[3], obs[1]], + ), ), buffer_ids=[1, 2], ) @@ -1309,49 +1403,52 @@ def test_from_data() -> None: def test_custom_key() -> None: - batch = Batch( - obs_next=np.array( - [ + batch = cast( + RolloutBatchProtocol, + Batch( + obs_next=np.array( [ - 1.174, - -0.1151, - -0.609, - -0.5205, - -0.9316, - 3.236, - -2.418, - 0.386, - 0.2227, - -0.5117, - 2.293, + [ + 1.174, + -0.1151, + -0.609, + -0.5205, + -0.9316, + 3.236, + -2.418, + 0.386, + 0.2227, + -0.5117, + 2.293, + ], ], - ], - ), - rew=np.array([4.28125]), - act=np.array([[-0.3088, -0.4636, 0.4956]]), - truncated=np.array([False]), - obs=np.array( - [ + ), + rew=np.array([4.28125]), + act=np.array([[-0.3088, -0.4636, 0.4956]]), + 
truncated=np.array([False]), + obs=np.array( [ - 1.193, - -0.1203, - -0.6123, - -0.519, - -0.9434, - 3.32, - -2.266, - 0.9116, - 0.623, - 0.1259, - 0.363, + [ + 1.193, + -0.1203, + -0.6123, + -0.519, + -0.9434, + 3.32, + -2.266, + 0.9116, + 0.623, + 0.1259, + 0.363, + ], ], - ], + ), + terminated=np.array([False]), + done=np.array([False]), + returns=np.array([74.70343082]), + info=Batch(), + policy=Batch(), ), - terminated=np.array([False]), - done=np.array([False]), - returns=np.array([74.70343082]), - info=Batch(), - policy=Batch(), ) buffer_size = len(batch.rew) buffer = ReplayBuffer(buffer_size) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d03a54df7..95b604905 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Sequence from test.base.env import MoveToRightEnv, NXEnv -from typing import Any +from typing import Any, cast import gymnasium as gym import numpy as np @@ -17,7 +17,11 @@ VectorReplayBuffer, ) from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActStateBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy, TrainingStats @@ -54,7 +58,7 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> Batch: + ) -> ActStateBatchProtocol: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) @@ -69,9 +73,9 @@ def forward( action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) - return Batch(act=np.ones(action_shape), state=state) + return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) action_shape = self.action_shape if self.action_shape else len(batch.obs) - return Batch(act=np.ones(action_shape), state=state) + return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: raise NotImplementedError diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index ce8a93640..287e79677 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -12,7 +12,12 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler from tianshou.data import Batch, Collector -from tianshou.data.types import BatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.policy import BasePolicy @@ -208,8 +213,8 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> Batch: - return Batch(act=np.stack([1] * len(batch))) + ) -> ActBatchProtocol: + return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: pass diff --git a/test/base/test_returns.py b/test/base/test_returns.py index ab4430b85..078893113 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,7 +1,10 @@ +from typing import cast + import numpy as np import torch from tianshou.data import Batch, ReplayBuffer, to_numpy +from 
tianshou.data.types import RolloutBatchProtocol from tianshou.policy import BasePolicy @@ -20,56 +23,68 @@ def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: def test_episodic_returns(size: int = 2560) -> None: fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) - batch = Batch( - terminated=np.array([1, 0, 0, 1, 0, 0, 0, 1.0]), - truncated=np.array([0, 0, 0, 0, 0, 1, 0, 0]), - rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.0]), - info=Batch( - { - "TimeLimit.truncated": np.array( - [False, False, False, False, False, True, False, False], - ), - }, + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([1, 0, 0, 1, 0, 0, 0, 1.0]), + truncated=np.array([0, 0, 0, 0, 0, 1, 0, 0]), + rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.0]), + info=Batch( + { + "TimeLimit.truncated": np.array( + [False, False, False, False, False, True, False, False], + ), + }, + ), ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 1, 0, 1, 0, 1, 0.0]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 1, 0, 1, 0, 1, 0.0]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 1, 0, 1, 0, 0, 1.0]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 1, 0, 1, 0, 0, 1.0]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 0, 0, 1.0, 0, 0, 0, 1, 0, 0, 0, 1]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), - rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 0, 0, 1.0, 0, 0, 0, 1, 0, 0, 0, 1]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + ), ) for b in batch: - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) @@ -91,33 +106,36 @@ def test_episodic_returns(size: int = 2560) -> None: ) assert np.allclose(returns, ground_truth) buf.reset() - batch = Batch( - terminated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), - truncated=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), - rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), - info=Batch( - { - "TimeLimit.truncated": np.array( - [ - False, - False, - False, - True, - 
False, - False, - False, - True, - False, - False, - False, - False, - ], - ), - }, + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + truncated=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), + rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + info=Batch( + { + "TimeLimit.truncated": np.array( + [ + False, + False, + False, + True, + False, + False, + False, + True, + False, + False, + False, + False, + ], + ), + }, + ), ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) @@ -180,12 +198,15 @@ def test_nstep_returns(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( - Batch( - obs=0, - act=0, - rew=i + 1, - terminated=i % 4 == 3, - truncated=False, + cast( + RolloutBatchProtocol, + Batch( + obs=0, + act=0, + rew=i + 1, + terminated=i % 4 == 3, + truncated=False, + ), ), ) batch, indices = buf.sample(0) @@ -258,13 +279,16 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( - Batch( - obs=0, - act=0, - rew=i + 1, - terminated=i % 4 == 3 and i != 3, - truncated=i == 3, - info={"TimeLimit.truncated": i == 3}, + cast( + RolloutBatchProtocol, + Batch( + obs=0, + act=0, + rew=i + 1, + terminated=i % 4 == 3 and i != 3, + truncated=i == 3, + info={"TimeLimit.truncated": i == 3}, + ), ), ) batch, indices = buf.sample(0) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 40d6ffd4e..dbd0afc54 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -6,6 +6,7 @@ from types import EllipsisType from typing import ( Any, + Literal, Protocol, Self, TypeVar, @@ -429,7 +430,7 @@ def set_array_at_key( self, seq: np.ndarray, key: str, - index: Sequence[int] | None = None, + index: IndexType | None = None, default_value: float | None = None, ) -> None: """Set a sequence of values at a given key. @@ -474,6 +475,29 @@ def dropnull(self) -> Self: """ ... + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + ) -> Self: + ... + + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: Literal[True], + ) -> None: + ... + + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: Literal[False], + ) -> Self: + ... + def apply_array_func( self, array_func: Callable[[np.ndarray | torch.Tensor], Any], @@ -744,7 +768,7 @@ def arr_to_torch(arr: arr_type) -> arr_type: if dtype is not None: arr = arr.type(dtype) return arr.to(device) - return None + return arr self.apply_array_func(arr_to_torch, inplace=True) @@ -1027,6 +1051,29 @@ def split( break yield self[indices[idx : idx + size]] + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + ) -> Self: + ... + + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: Literal[True], + ) -> None: + ... + + @overload + def apply_array_func( + self, + array_func: Callable[[np.ndarray | torch.Tensor], Any], + inplace: Literal[False], + ) -> Self: + ... 
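+    # NOTE: the overloads above only narrow the static return type (None for
+    # inplace=True, a new batch otherwise); the actual dispatch happens in the
+    # single implementation below.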
+ def apply_array_func( self, array_func: Callable[[np.ndarray | torch.Tensor], Any], @@ -1079,7 +1126,7 @@ def hasnull(self) -> bool: isnan_batch = self.isnull() is_any_null_batch = isnan_batch.apply_array_func(np.any, inplace=False) - def is_any_true(boolean_batch: BatchProtocol): + def is_any_true(boolean_batch: BatchProtocol) -> bool: for val in boolean_batch.values(): if isinstance(val, BatchProtocol): if is_any_true(val): diff --git a/tianshou/data/types.py b/tianshou/data/types.py index 3572e5484..1200b0df4 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -20,7 +20,7 @@ class ObsBatchProtocol(BatchProtocol, Protocol): """ obs: arr_type | BatchProtocol - info: arr_type + info: arr_type | BatchProtocol class RolloutBatchProtocol(ObsBatchProtocol, Protocol): diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index e88214d45..05cc8db8f 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -158,7 +158,7 @@ def process_fn( # type: ignore results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew - return Batch(results) + return cast(MAPRolloutBatchProtocol, Batch(results)) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") From 2480a62a61e8e4630ecb39bfe2ed90259a21be7c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 29 Jul 2024 20:47:08 +0200 Subject: [PATCH 04/12] Improved typing in Batch and Protocol, added get and pop explicitly --- test/base/test_batch.py | 2 +- tianshou/data/batch.py | 169 ++++++++++++++++++++++++----------- tianshou/data/buffer/base.py | 2 +- tianshou/data/buffer/prio.py | 3 + tianshou/data/types.py | 18 ++-- tianshou/policy/base.py | 8 +- 6 files changed, 134 insertions(+), 68 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a26d5d679..c6e1e5810 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -378,7 +378,7 @@ def test_batch_over_batch_to_torch() -> None: a=np.float64(1.0), b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), ) - batch.b.__dict__["e"] = 1 # bypass the check + batch.b.set_array_at_key(np.array([1]), "e") batch.to_torch_() assert isinstance(batch.a, torch.Tensor) assert isinstance(batch.b.c, torch.Tensor) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index dbd0afc54..a76a31144 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -12,6 +12,7 @@ TypeVar, Union, cast, + get_args, overload, runtime_checkable, ) @@ -20,13 +21,16 @@ import pandas as pd import torch from deepdiff import DeepDiff +from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.utils import logging _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] TBatch = TypeVar("TBatch", bound="BatchProtocol") -arr_type = torch.Tensor | np.ndarray +TDistribution = TypeVar("TDistribution", bound=Distribution) +T = TypeVar("T") +TArr = torch.Tensor | np.ndarray log = logging.getLogger(__name__) @@ -202,6 +206,33 @@ def alloc_by_keys_diff( meta[key] = create_value(batch[key], size, stack) +class ProtocolCalledException(Exception): + """The methods of a Protocol should never be called. + + Currently, no static type checker actually verifies that a class that inherits + from a Protocol does in fact provide the correct interface. 
Thus, it may happen + that a method of the protocol is called accidentally (this is an + implementation error). The normal error for that is a somewhat cryptical + AttributeError, wherefore we instead raise this custom exception in the + BatchProtocol. + + Finally and importantly: using this in BatchProtocol makes mypy verify the fields + in the various sub-protocols and thus renders is MUCH more useful! + """ + + +def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution: + """Slice a distribution object by the given index.""" + if isinstance(dist, Categorical): + return Categorical(probs=dist.probs[index]) # type: ignore[return-value] + if isinstance(dist, Normal): + return Normal(loc=dist.loc[index], scale=dist.scale[index]) # type: ignore[return-value] + if isinstance(dist, Independent): + return Independent(get_sliced_dist(dist.base_dist, index), dist.reinterpreted_batch_ndims) # type: ignore[return-value] + else: + raise NotImplementedError(f"Unsupported distribution for slicing: {dist}") + + # Note: This is implemented as a protocol because the interface # of Batch is always extended by adding new fields. Having a hierarchy of # protocols building off this one allows for type safety and IDE support despite @@ -220,72 +251,75 @@ class BatchProtocol(Protocol): @property def shape(self) -> list[int]: - ... - + raise ProtocolCalledException + + # NOTE: even though setattr and getattr are defined for any object, we need + # to explicitly define them for the BatchProtocol, since otherwise mypy will + # complain about new fields being added dynamically. For example, things like + # `batch.new_field = ...` followed by using `batch.new_field` become type errors + # if getattr and setattr are missing in the BatchProtocol. + # + # For the moment, tianshou relies on this kind of dynamic-field-addition + # in many, many places. In principle, it would be better to construct new + # objects with new combinations of fields instead of mutating existing ones - the + # latter is error-prone and can't properly be expressed with types. May be in a + # future, rather different version of tianshou it would be feasible to have stricter + # typing. Then the need for Protocols would in fact disappear def __setattr__(self, key: str, value: Any) -> None: - ... + raise ProtocolCalledException def __getattr__(self, key: str) -> Any: - ... + raise ProtocolCalledException - def __contains__(self, key: str) -> bool: - ... - - def __getstate__(self) -> dict: - ... - - def __setstate__(self, state: dict) -> None: - ... + def __iter__(self) -> Iterator[Self]: + raise ProtocolCalledException @overload def __getitem__(self, index: str) -> Any: - ... + raise ProtocolCalledException @overload def __getitem__(self, index: IndexType) -> Self: - ... + raise ProtocolCalledException def __getitem__(self, index: str | IndexType) -> Any: - ... + raise ProtocolCalledException def __setitem__(self, index: str | IndexType, value: Any) -> None: - ... + raise ProtocolCalledException def __iadd__(self, other: Self | Number | np.number) -> Self: - ... + raise ProtocolCalledException def __add__(self, other: Self | Number | np.number) -> Self: - ... + raise ProtocolCalledException def __imul__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __mul__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __itruediv__(self, value: Number | np.number) -> Self: - ... 
+ raise ProtocolCalledException def __truediv__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __repr__(self) -> str: - ... - - def __iter__(self) -> Iterator[Self]: - ... + raise ProtocolCalledException def __eq__(self, other: Any) -> bool: - ... + raise ProtocolCalledException @staticmethod def to_numpy(batch: TBatch) -> TBatch: """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" - ... + raise ProtocolCalledException def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" - ... + raise ProtocolCalledException @staticmethod def to_torch( @@ -294,7 +328,7 @@ def to_torch( device: str | int | torch.device = "cpu", ) -> TBatch: """Change all numpy.ndarray to torch.Tensor and return a new Batch.""" - ... + raise ProtocolCalledException def to_torch_( self, @@ -302,11 +336,11 @@ def to_torch_( device: str | int | torch.device = "cpu", ) -> None: """Change all numpy.ndarray to torch.Tensor in-place.""" - ... + raise ProtocolCalledException def cat_(self, batches: Self | Sequence[dict | Self]) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" - ... + raise ProtocolCalledException @staticmethod def cat(batches: Sequence[dict | TBatch]) -> TBatch: @@ -326,11 +360,11 @@ def cat(batches: Sequence[dict | TBatch]) -> TBatch: >>> c.common.c.shape (7, 5) """ - ... + raise ProtocolCalledException def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None: """Stack a list of Batch object into current batch.""" - ... + raise ProtocolCalledException @staticmethod def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: @@ -355,7 +389,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: If there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception. """ - ... + raise ProtocolCalledException def empty_(self, index: slice | IndexType | None = None) -> Self: """Return an empty Batch object with 0 or None filled. @@ -382,7 +416,7 @@ def empty_(self, index: slice | IndexType | None = None) -> Self: ), ) """ - ... + raise ProtocolCalledException @staticmethod def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: @@ -390,14 +424,14 @@ def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: The shape is the same as the given Batch. """ - ... + raise ProtocolCalledException def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: """Update this batch from another dict/Batch.""" - ... + raise ProtocolCalledException def __len__(self) -> int: - ... + raise ProtocolCalledException def split( self, @@ -415,16 +449,16 @@ def split( :param merge_last: merge the last batch into the previous one. Default to False. """ - ... + raise ProtocolCalledException def to_dict(self, recurse: bool = True) -> dict[str, Any]: - ... + raise ProtocolCalledException def to_list_of_dicts(self) -> list[dict[str, Any]]: - ... + raise ProtocolCalledException def get_keys(self) -> KeysView: - ... + raise ProtocolCalledException def set_array_at_key( self, @@ -446,15 +480,15 @@ def set_array_at_key( Note that the array at the key will be of the same dtype as the passed sequence, so default value should be such that numpy can cast it to this dtype. """ - ... + raise ProtocolCalledException def isnull(self) -> Self: """Return a boolean mask of the same shape, indicating missing values.""" - ... 
+ raise ProtocolCalledException def hasnull(self) -> bool: """Return whether the batch has missing values.""" - ... + raise ProtocolCalledException def dropnull(self) -> Self: """Return a batch where all items in which any value is null are dropped. @@ -509,7 +543,13 @@ def apply_array_func( :param inplace: whether to apply the function in-place. If False, a new batch is returned, otherwise the batch is modified in-place and None is returned. """ - ... + raise ProtocolCalledException + + def get(self, key: str, default: Any | None = None) -> Any: + raise ProtocolCalledException + + def pop(self, key: str, default: Any | None = None) -> Any: + raise ProtocolCalledException class Batch(BatchProtocol): @@ -553,6 +593,12 @@ def to_dict(self, recursive: bool = True) -> dict[str, Any]: def get_keys(self) -> KeysView: return self.__dict__.keys() + def get(self, key: str, default: Any | None = None) -> Any: + return self.__dict__.get(key, default) + + def pop(self, key: str, default: Any | None = None) -> Any: + return self.__dict__.pop(key, default) + def to_list_of_dicts(self) -> list[dict[str, Any]]: return [entry.to_dict() for entry in self] @@ -598,15 +644,23 @@ def __getitem__(self, index: IndexType) -> Self: ... def __getitem__(self, index: str | IndexType) -> Any: - """Return self[index].""" + """Returns either the value of a key or a sliced Batch object.""" if isinstance(index, str): return self.__dict__[index] batch_items = self.items() if len(batch_items) > 0: new_batch = Batch() for batch_key, obj in batch_items: - if isinstance(obj, Batch) and len(obj.get_keys()) == 0: + # None and empty Batches as values are added to any slice + if obj is None: + new_batch.__dict__[batch_key] = None + elif isinstance(obj, Batch) and len(obj.get_keys()) == 0: new_batch.__dict__[batch_key] = Batch() + # We attempt slicing of a distribution. 
This is hacky, but presents an important special case
+                elif isinstance(obj, Distribution):
+                    new_batch.__dict__[batch_key] = get_sliced_dist(obj, index)
+                # All other objects are either array-like or Batch-like, so hopefully sliceable
+                # A batch should have no scalars, and if it does, slicing them is not supported
                 else:
                     new_batch.__dict__[batch_key] = obj[index]
             return new_batch
@@ -729,7 +783,7 @@ def to_numpy(batch: TBatch) -> TBatch:
         return result
 
     def to_numpy_(self) -> None:
-        def arr_to_numpy(arr: arr_type) -> arr_type:
+        def arr_to_numpy(arr: TArr) -> TArr:
             if isinstance(arr, torch.Tensor):
                 return arr.detach().cpu().numpy()
             return arr
@@ -754,7 +808,7 @@ def to_torch_(
         if not isinstance(device, torch.device):
             device = torch.device(device)
 
-        def arr_to_torch(arr: arr_type) -> arr_type:
+        def arr_to_torch(arr: TArr) -> TArr:
             if isinstance(arr, np.ndarray):
                 return torch.tensor(arr, dtype=dtype, device=device)
 
@@ -858,6 +912,9 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
         if len(batch_list) == 0:
             return
         batches = batch_list
+
+        # TODO: lots of the remaining logic is devoted to filling up remaining keys with zeros
+        #  this should be removed, and also the check above should be extended to nested keys
         try:
             # len(batch) here means batch is a nested empty batch
             # like Batch(a=Batch), and we have to treat it as length zero and
@@ -871,6 +928,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
             ) from exception
         if len(self.get_keys()) != 0:
             batches = [self, *list(batches)]
+            # len of zero means that the item is Batch() and should be ignored
             lens = [0 if len(self) == 0 else len(self), *lens]
         self.__cat(batches, lens)
@@ -1155,7 +1213,7 @@ def dropnull(self) -> Self:
 
 def _apply_array_func_recursively(
     batch: TBatch,
-    array_func: Callable[[np.ndarray | torch.Tensor], Any],
+    array_func: Callable[[TArr], Any],
     inplace: bool = False,
 ) -> TBatch | None:
     result = batch if inplace else deepcopy(batch)
@@ -1163,6 +1221,11 @@ def _apply_array_func_recursively(
         if isinstance(val, BatchProtocol):
             result[key] = _apply_array_func_recursively(val, array_func, inplace=False)
         else:
+            if not isinstance(val, get_args(TArr)):
+                raise TypeError(
+                    f"Unsupported type {type(val)} for value of {key=}. "
+                    f"Supported values are torch tensors or numpy arrays.",
+                )
             result[key] = array_func(val)
     if not inplace:
         return result
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 2ecddc5ce..2699d92d7 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -355,7 +355,7 @@ def get(
             set, return this default_value.
         :param stack_num: Default to self.stack_num.
""" - if key not in self._meta and default_value is not None: + if key not in self._meta.get_keys() and default_value is not None: return default_value val = self._meta[key] if stack_num is None: diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index bef6a06a0..406e39afd 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -103,5 +103,8 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchP batch.weight = weight / np.max(weight) if self._weight_norm else weight return cast(PrioBatchProtocol, batch) + def sample(self, batch_size: int | None) -> tuple[PrioBatchProtocol, np.ndarray]: + return cast(tuple[PrioBatchProtocol, np.ndarray], super().sample(batch_size=batch_size)) + def set_beta(self, beta: float) -> None: self._beta = beta diff --git a/tianshou/data/types.py b/tianshou/data/types.py index 1200b0df4..a4fd43543 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -4,7 +4,7 @@ import torch from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol, arr_type +from tianshou.data.batch import BatchProtocol, TArr TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] @@ -19,24 +19,24 @@ class ObsBatchProtocol(BatchProtocol, Protocol): Typically used inside a policy's forward """ - obs: arr_type | BatchProtocol - info: arr_type | BatchProtocol + obs: TArr | BatchProtocol + info: TArr | BatchProtocol class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" - obs_next: arr_type | BatchProtocol - act: arr_type + obs_next: TArr | BatchProtocol + act: TArr rew: np.ndarray - terminated: arr_type - truncated: arr_type + terminated: TArr + truncated: TArr class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" - returns: arr_type + returns: TArr class PrioBatchProtocol(RolloutBatchProtocol, Protocol): @@ -55,7 +55,7 @@ class RecurrentStateBatch(BatchProtocol, Protocol): class ActBatchProtocol(BatchProtocol, Protocol): """Simplest batch, just containing the action. Useful e.g., for random policy.""" - act: arr_type + act: TArr class ActStateBatchProtocol(ActBatchProtocol, Protocol): diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index c4fc9af3b..d886180a5 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -15,7 +15,7 @@ from torch import nn from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as -from tianshou.data.batch import Batch, BatchProtocol, arr_type +from tianshou.data.batch import Batch, BatchProtocol, TArr from tianshou.data.buffer.base import TBuffer from tianshou.data.types import ( ActBatchProtocol, @@ -355,7 +355,7 @@ def forward( """ @staticmethod - def _action_to_numpy(act: arr_type) -> np.ndarray: + def _action_to_numpy(act: TArr) -> np.ndarray: act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch if not isinstance(act, np.ndarray): raise ValueError( @@ -365,7 +365,7 @@ def _action_to_numpy(act: arr_type) -> np.ndarray: def map_action( self, - act: arr_type, + act: TArr, ) -> np.ndarray: """Map raw network output to action range in gym's env.action_space. @@ -400,7 +400,7 @@ def map_action( def map_action_inverse( self, - act: arr_type, + act: TArr, ) -> np.ndarray: """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`. 
From 539c1590bc99e031276099cb6d4979d10ffdf0b1 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 30 Jul 2024 00:01:06 +0200 Subject: [PATCH 05/12] Fixed Batch.to_torch (was previously copying the array) --- test/base/test_batch.py | 136 +++++++++++++++++++++++++++------------- tianshou/data/batch.py | 99 ++++++++++++++++------------- 2 files changed, 150 insertions(+), 85 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index c6e1e5810..9e7015690 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,8 +9,10 @@ import pytest import torch from deepdiff import DeepDiff +from torch.distributions.categorical import Categorical from tianshou.data import Batch, to_numpy, to_torch +from tianshou.data.batch import IndexType, get_sliced_dist def test_batch() -> None: @@ -122,8 +124,8 @@ def test_batch() -> None: with pytest.raises(TypeError): len(batch2[0]) assert isinstance(batch2[0].a.c, np.ndarray) - assert isinstance(batch2[0].a.b, np.float64) - assert isinstance(batch2[0].a.d.e, np.float64) + assert isinstance(batch2[0].a.b, float) + assert isinstance(batch2[0].a.d.e, float) batch2_from_list = Batch(list(batch2)) batch2_from_comp = Batch(list(batch2)) assert batch2_from_list.a.b == batch2.a.b @@ -373,31 +375,6 @@ def test_batch_cat_and_stack() -> None: Batch.stack([b1, b2], axis=1) -def test_batch_over_batch_to_torch() -> None: - batch = Batch( - a=np.float64(1.0), - b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), - ) - batch.b.set_array_at_key(np.array([1]), "e") - batch.to_torch_() - assert isinstance(batch.a, torch.Tensor) - assert isinstance(batch.b.c, torch.Tensor) - assert isinstance(batch.b.d, torch.Tensor) - assert isinstance(batch.b.e, torch.Tensor) - assert batch.a.dtype == torch.float64 - assert batch.b.c.dtype == torch.float32 - assert batch.b.d.dtype == torch.float64 - if sys.platform in ["win32", "cygwin"]: # windows - assert batch.b.e.dtype == torch.int32 - else: - assert batch.b.e.dtype == torch.int64 - batch.to_torch_(dtype=torch.float32) - assert batch.a.dtype == torch.float32 - assert batch.b.c.dtype == torch.float32 - assert batch.b.d.dtype == torch.float32 - assert batch.b.e.dtype == torch.float32 - - def test_utils_to_torch_numpy() -> None: batch = Batch( a=np.float64(1.0), @@ -473,18 +450,6 @@ def test_batch_pickle() -> None: assert np.all(batch.np == batch_pk.np) -def test_batch_from_to_numpy_without_copy() -> None: - batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) - a_mem_addr_orig = batch.a.__array_interface__["data"][0] - c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] - batch.to_torch_() - batch.to_numpy_() - a_mem_addr_new = batch.a.__array_interface__["data"][0] - c_mem_addr_new = batch.b.c.__array_interface__["data"][0] - assert a_mem_addr_new == a_mem_addr_orig - assert c_mem_addr_new == c_mem_addr_orig - - def test_batch_copy() -> None: batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6])) batch2 = Batch({"c": np.array([6, 7, 8]), "b": batch}) @@ -703,7 +668,7 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: assert not DeepDiff(batch.to_dict(recursive=True), expected) -class TestConversions: +class TestBatchConversions: @staticmethod def test_to_numpy() -> None: batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) @@ -747,15 +712,63 @@ def test_to_torch_() -> None: @staticmethod def test_apply_array_func() -> None: batch = Batch(a=1, b=np.arange(3), c={"d": np.array([1, 2, 3])}) - batch_with_max = 
batch.apply_array_func(np.max) + batch_with_max = batch.apply_values_transform(np.max) assert np.array_equal(batch_with_max.a, np.array(1)) assert np.array_equal(batch_with_max.b, np.array(2)) assert np.array_equal(batch_with_max.c.d, np.array(3)) - batch_array_added = batch.apply_array_func(lambda x: x + np.array([1, 2, 3])) + batch_array_added = batch.apply_values_transform(lambda x: x + np.array([1, 2, 3])) assert np.array_equal(batch_array_added.a, np.array([2, 3, 4])) assert np.array_equal(batch_array_added.c.d, np.array([2, 4, 6])) + @staticmethod + def test_batch_to_numpy_without_copy() -> None: + batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + a_mem_addr_orig = batch.a.__array_interface__["data"][0] + c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] + batch.to_numpy_() + a_mem_addr_new = batch.a.__array_interface__["data"][0] + c_mem_addr_new = batch.b.c.__array_interface__["data"][0] + assert a_mem_addr_new == a_mem_addr_orig + assert c_mem_addr_new == c_mem_addr_orig + + @staticmethod + def test_batch_from_to_numpy_without_copy() -> None: + batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + a_mem_addr_orig = batch.a.__array_interface__["data"][0] + c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] + batch.to_torch_() + batch.to_numpy_() + a_mem_addr_new = batch.a.__array_interface__["data"][0] + c_mem_addr_new = batch.b.c.__array_interface__["data"][0] + assert a_mem_addr_new == a_mem_addr_orig + assert c_mem_addr_new == c_mem_addr_orig + + @staticmethod + def test_batch_over_batch_to_torch() -> None: + batch = Batch( + a=np.float64(1.0), + b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), + ) + batch.b.set_array_at_key(np.array([1]), "e") + batch.to_torch_() + assert isinstance(batch.a, torch.Tensor) + assert isinstance(batch.b.c, torch.Tensor) + assert isinstance(batch.b.d, torch.Tensor) + assert isinstance(batch.b.e, torch.Tensor) + assert batch.a.dtype == torch.float64 + assert batch.b.c.dtype == torch.float32 + assert batch.b.d.dtype == torch.float64 + if sys.platform in ["win32", "cygwin"]: # windows + assert batch.b.e.dtype == torch.int32 + else: + assert batch.b.e.dtype == torch.int64 + batch.to_torch_(dtype=torch.float32) + assert batch.a.dtype == torch.float32 + assert batch.b.c.dtype == torch.float32 + assert batch.b.d.dtype == torch.float32 + assert batch.b.e.dtype == torch.float32 + class TestAssignment: @staticmethod @@ -826,10 +839,47 @@ def test_hasnull() -> None: @staticmethod def test_dropnull() -> None: batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([None, 2.1, 3.0])}) - assert batch.dropnull() == Batch(a=[5], b=[8], c={"d": np.array([2.1])}).apply_array_func( + assert batch.dropnull() == Batch( + a=[5], + b=[8], + c={"d": np.array([2.1])}, + ).apply_values_transform( np.atleast_1d, ) batch2 = Batch(a=[4, 5, 6, 7], b=[7, 8, None, 10], c={"d": np.array([None, 2, 3, 4])}) assert batch2.dropnull() == Batch(a=[5, 7], b=[8, 10], c={"d": np.array([2, 4])}) batch_no_nan = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) assert batch_no_nan.dropnull() == batch_no_nan + + +class TestSlicing: + # TODO: parametrize with other dists + @staticmethod + def test_slice_distribution() -> None: + cat_probs = torch.randint(1, 10, (10, 3)) + dist = Categorical(probs=cat_probs) + batch = Batch(dist=dist) + selected_idx = [1, 3] + sliced_batch = batch[selected_idx] + sliced_probs = cat_probs[selected_idx] + assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all() 
+ assert ( + Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs + ).all() + # retrieving a single index + assert (batch[0].dist.probs == dist.probs[0]).all() + + @staticmethod + def test_getitem_with_int_gives_scalars() -> None: + batch = Batch(a=[1, 2], b=Batch(c=[3, 4])) + batch_sliced = batch[0] + assert batch_sliced.a == np.array(1) + assert batch_sliced.b.c == np.array(3) + + @staticmethod + @pytest.mark.parametrize("index", ([0, 1], np.array([0, 1]), torch.tensor([0, 1]), slice(0, 2))) + def test_getitem_with_slice_gives_subslice(index: IndexType) -> None: + batch = Batch(a=[1, 2, 3], b=Batch(c=torch.tensor([4, 5, 6]))) + batch_sliced = batch[index] + assert (batch_sliced.a == batch.a[index]).all() + assert (batch_sliced.b.c == batch.b.c[index]).all() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a76a31144..854707f9c 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -12,7 +12,6 @@ TypeVar, Union, cast, - get_args, overload, runtime_checkable, ) @@ -26,7 +25,7 @@ from tianshou.utils import logging _SingleIndexType = slice | int | EllipsisType -IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] +IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType] TBatch = TypeVar("TBatch", bound="BatchProtocol") TDistribution = TypeVar("TDistribution", bound=Distribution) T = TypeVar("T") @@ -228,7 +227,10 @@ def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution: if isinstance(dist, Normal): return Normal(loc=dist.loc[index], scale=dist.scale[index]) # type: ignore[return-value] if isinstance(dist, Independent): - return Independent(get_sliced_dist(dist.base_dist, index), dist.reinterpreted_batch_ndims) # type: ignore[return-value] + return Independent( + get_sliced_dist(dist.base_dist, index), + dist.reinterpreted_batch_ndims, + ) # type: ignore[return-value] else: raise NotImplementedError(f"Unsupported distribution for slicing: {dist}") @@ -510,36 +512,36 @@ def dropnull(self) -> Self: ... @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable[[np.ndarray | torch.Tensor], Any], ) -> Self: ... @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable, inplace: Literal[True], ) -> None: ... @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable[[np.ndarray | torch.Tensor], Any], inplace: Literal[False], ) -> Self: ... - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable[[np.ndarray | torch.Tensor], Any], inplace: bool = False, ) -> None | Self: """Apply a function to all arrays in the batch, including nested ones. - :param array_func: the function to apply to the arrays. + :param values_transform: the function to apply to the arrays. :param inplace: whether to apply the function in-place. If False, a new batch is returned, otherwise the batch is modified in-place and None is returned. 
""" @@ -650,19 +652,22 @@ def __getitem__(self, index: str | IndexType) -> Any: batch_items = self.items() if len(batch_items) > 0: new_batch = Batch() + + sliced_obj: Any for batch_key, obj in batch_items: # None and empty Batches as values are added to any slice if obj is None: - new_batch.__dict__[batch_key] = None + sliced_obj = None elif isinstance(obj, Batch) and len(obj.get_keys()) == 0: - new_batch.__dict__[batch_key] = Batch() + sliced_obj = Batch() # We attempt slicing of a distribution. This is hacky, but presents an important special case elif isinstance(obj, Distribution): - new_batch.__dict__[batch_key] = get_sliced_dist(obj, index) + sliced_obj = get_sliced_dist(obj, index) # All other objects are either array-like or Batch-like, so hopefully sliceable - # A batch should have no scalars, and if it does, slicing them is not supported + # A batch should have no scalars else: - new_batch.__dict__[batch_key] = obj[index] + sliced_obj = obj[index] + new_batch.__dict__[batch_key] = sliced_obj return new_batch raise IndexError("Cannot access item from empty Batch object.") @@ -788,7 +793,7 @@ def arr_to_numpy(arr: TArr) -> TArr: return arr.detach().cpu().numpy() return arr - self.apply_array_func(arr_to_numpy, inplace=True) + self.apply_values_transform(arr_to_numpy, inplace=True) @staticmethod def to_torch( @@ -810,7 +815,7 @@ def to_torch_( def arr_to_torch(arr: TArr) -> TArr: if isinstance(arr, np.ndarray): - return torch.tensor(arr, dtype=dtype, device=device) + return torch.from_numpy(arr).to(device) # TODO: simplify if ( @@ -824,7 +829,7 @@ def arr_to_torch(arr: TArr) -> TArr: return arr.to(device) return arr - self.apply_array_func(arr_to_torch, inplace=True) + self.apply_values_transform(arr_to_torch, inplace=True) def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: """Private method for Batch.cat_. @@ -1060,7 +1065,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: self.update(kwargs) def __len__(self) -> int: - """Return len(self).""" lens = [] for obj in self.__dict__.values(): # TODO: causes inconsistent behavior to batch with empty batches @@ -1110,34 +1114,46 @@ def split( yield self[indices[idx : idx + size]] @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable, ) -> Self: ... @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable, inplace: Literal[True], ) -> None: ... @overload - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable, inplace: Literal[False], ) -> Self: ... - def apply_array_func( + def apply_values_transform( self, - array_func: Callable[[np.ndarray | torch.Tensor], Any], + values_transform: Callable, inplace: bool = False, ) -> None | Self: - return _apply_array_func_recursively(self, array_func, inplace=inplace) + """Applies a function to all non-batch-values in the batch, including + values in nested batches. + + A batch with keys pointing to either batches or to non-batch values can + be thought of as a tree of Batch nodes. This function traverses the tree + and applies the function to all leaf nodes (i.e. values that are not + batches themselves). 
+ + The values are usually arrays, but can also be scalar values of an + arbitrary type since retrieving a single entry from a Batch a la + `batch[0]` will return a batch with scalar values. + """ + return _apply_batch_values_func_recursively(self, values_transform, inplace=inplace) def set_array_at_key( self, @@ -1178,11 +1194,11 @@ def set_array_at_key( self[key] = arr def isnull(self) -> Self: - return self.apply_array_func(pd.isnull, inplace=False) + return self.apply_values_transform(pd.isnull, inplace=False) def hasnull(self) -> bool: isnan_batch = self.isnull() - is_any_null_batch = isnan_batch.apply_array_func(np.any, inplace=False) + is_any_null_batch = isnan_batch.apply_values_transform(np.any, inplace=False) def is_any_true(boolean_batch: BatchProtocol) -> bool: for val in boolean_batch.values(): @@ -1206,27 +1222,26 @@ def dropnull(self) -> Self: if b.hasnull(): continue # needed for cat to work - b = b.apply_array_func(np.atleast_1d) + b = b.apply_values_transform(np.atleast_1d) sub_batches.append(b) return Batch.cat(sub_batches) -def _apply_array_func_recursively( +def _apply_batch_values_func_recursively( batch: TBatch, - array_func: Callable[[TArr], Any], + values_transform: Callable, inplace: bool = False, ) -> TBatch | None: + """Applies the desired function on all values of the batch recursively. + + See docstring of the corresponding method in the Batch class for more details. + """ result = batch if inplace else deepcopy(batch) for key, val in batch.__dict__.items(): if isinstance(val, BatchProtocol): - result[key] = _apply_array_func_recursively(val, array_func, inplace=False) + result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False) else: - if not isinstance(val, get_args(TArr)): - raise TypeError( - f"Unsupported type {type(val)} for value of {key=}. " - f"Supported values are torch tensors or numpy arrays.", - ) - result[key] = array_func(val) + result[key] = values_transform(val) if not inplace: return result return None From 94a66a0006d011166f005aa1b9f2f7d077cf8824 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 30 Jul 2024 00:05:48 +0200 Subject: [PATCH 06/12] Spelling --- docs/spelling_wordlist.txt | 2 +- tianshou/data/batch.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 83de82356..c30b9f2cb 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -270,4 +270,4 @@ v_s v_s_ obs obs_next - +dtype diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 854707f9c..f65c9cddb 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -211,7 +211,7 @@ class ProtocolCalledException(Exception): Currently, no static type checker actually verifies that a class that inherits from a Protocol does in fact provide the correct interface. Thus, it may happen that a method of the protocol is called accidentally (this is an - implementation error). The normal error for that is a somewhat cryptical + implementation error). The normal error for that is a somewhat cryptic AttributeError, wherefore we instead raise this custom exception in the BatchProtocol. @@ -471,16 +471,17 @@ def set_array_at_key( ) -> None: """Set a sequence of values at a given key. - If index is not passed, the sequence must have the same length as the batch. + If `index` is not passed, the sequence must have the same length as the batch. + :param seq: the array of values to set. :param key: the key to set the sequence at. 
:param index: the indices to set the sequence at. If None, the sequence
             must have the same length as the batch and will be set at all indices.
-        :param default_value: this only applies if index is passed an the key does not exist yet
-            in the batch. In that case entries outside the passed index will be filled
+        :param default_value: this only applies if `index` is passed and the key does not exist yet
+            in the batch. In that case, entries outside the passed index will be filled
             with this default value.
             Note that the array at the key will be of the same dtype as the passed sequence,
-            so default value should be such that numpy can cast it to this dtype.
+            so `default_value` should be such that numpy can cast it to this dtype.
         """
         raise ProtocolCalledException

From cbba46129180a6aa575c08193f61dedbbffe0e95 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Tue, 30 Jul 2024 00:26:23 +0200
Subject: [PATCH 07/12] More input validation in Batch.cat_

---
 tianshou/data/batch.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index f65c9cddb..bd78255ac 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -906,17 +906,32 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
             batches = [batches]
         # check input format
         batch_list = []
+
+        original_keys_only_batch = self.apply_array_func(lambda x: None) if len(self) > 0 else None
+        """A batch with all values removed, just keys left. Can be considered a sort of schema."""
+
         for batch in batches:
             if isinstance(batch, dict):
-                if len(batch) > 0:
-                    batch_list.append(Batch(batch))
-            elif isinstance(batch, Batch):
-                if len(batch.get_keys()) != 0:
-                    batch_list.append(batch)
-            else:
+                batch = Batch(batch)
+            if not isinstance(batch, Batch):
                 raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
+            if len(batch.get_keys()) == 0:
+                continue
+            if original_keys_only_batch is None:
+                original_keys_only_batch = batch.apply_values_transform(lambda x: None)
+                batch_list.append(batch)
+                continue
+
+            cur_keys_only_batch = batch.apply_values_transform(lambda x: None)
+            if original_keys_only_batch != cur_keys_only_batch:
+                raise ValueError(
+                    f"Batch.cat_ only supports concatenation of batches with the same structure but got "
+                    f"structures {original_keys_only_batch} and {cur_keys_only_batch}.",
+                )
+            batch_list.append(batch)
         if len(batch_list) == 0:
             return
+
         batches = batch_list
 
         # TODO: lots of the remaining logic is devoted to filling up remaining keys with zeros

From 61c4ffd6bc975f7aab1fa6d09b34496e98cb31e0 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Tue, 30 Jul 2024 12:49:54 +0200
Subject: [PATCH 08/12] Fixed batch schema comparison in cat_

Removed part of the tests of cat_ that were handling incompatible batches
---
 test/base/test_batch.py | 37 ++++++-------------------------------
 tianshou/data/batch.py  | 28 +++++++++++++++++++++++++---
 2 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 9e7015690..9accff86c 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -246,13 +246,14 @@ def test_batch_cat_and_stack() -> None:
     assert b12_cat_in.a.d.e.ndim == 1
 
     a = Batch(a=Batch(a=np.random.randn(3, 4)))
+    a_empty = Batch(a=Batch(a=Batch()))
     assert np.allclose(
         np.concatenate([a.a.a, a.a.a]),
-        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a,
+        Batch.cat([a, a_empty, a]).a.a,
     )
 
     # test cat with lens infer
-    a =
Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) + a = Batch(a=Batch(a=np.random.randn(3, 4), t=Batch()), b=np.random.randn(3, 4)) b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) ans = Batch.cat([a, b, a]) assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) @@ -263,34 +264,8 @@ def test_batch_cat_and_stack() -> None: assert isinstance(b1.a.d.e, np.ndarray) assert b1.a.d.e.ndim == 2 - # test cat with incompatible keys - b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) - test = Batch.cat([b1, b2]) - ans = Batch( - a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c])), - ) - assert np.allclose(test.a, ans.a) - assert torch.allclose(test.b, ans.b) - assert np.allclose(test.common.c, ans.common.c) - - # test cat with reserved keys (values are Batch()) - b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) - test = Batch.cat([b1, b2]) - ans = Batch( - a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c])), - ) - assert np.allclose(test.a, ans.a) - assert torch.allclose(test.b, ans.b) - assert np.allclose(test.common.c, ans.common.c) - # test cat with all reserved keys (values are Batch()) - b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) + b1 = Batch(a=Batch(), b=torch.zeros(3, 3), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) ans = Batch( @@ -385,7 +360,7 @@ def test_utils_to_torch_numpy() -> None: a_torch_double = to_torch(batch.a, dtype=torch.float64) assert a_torch_double.dtype == torch.float64 batch_torch_float = to_torch(batch, dtype=torch.float32) - assert batch_torch_float.a.dtype == torch.float32 + assert batch_torch_float.a.dtype == torch.float64 assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32 data_list = [float("nan"), 1] @@ -867,7 +842,7 @@ def test_slice_distribution() -> None: Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs ).all() # retrieving a single index - assert (batch[0].dist.probs == dist.probs[0]).all() + assert torch.allclose(batch[0].dist.probs, dist.probs[0]) @staticmethod def test_getitem_with_int_gives_scalars() -> None: diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index bd78255ac..d147ac0e6 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -907,8 +907,13 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: # check input format batch_list = [] - original_keys_only_batch = self.apply_array_func(lambda x: None) if len(self) > 0 else None - """A batch with all values removed, just keys left. Can be considered a sort of schema.""" + original_keys_only_batch = None + """A batch with all values removed, just keys left. Can be considered a sort of schema. + Will be either the schema of self, or of the first non-empty batch in the sequence. 
+        """
+        if len(self) > 0:
+            original_keys_only_batch = self.apply_values_transform(lambda x: None)
+            original_keys_only_batch.replace_empty_batches_by_none()
 
         for batch in batches:
             if isinstance(batch, dict):
@@ -919,14 +924,16 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
                 continue
             if original_keys_only_batch is None:
                 original_keys_only_batch = batch.apply_values_transform(lambda x: None)
+                original_keys_only_batch.replace_empty_batches_by_none()
                 batch_list.append(batch)
                 continue
 
             cur_keys_only_batch = batch.apply_values_transform(lambda x: None)
+            cur_keys_only_batch.replace_empty_batches_by_none()
             if original_keys_only_batch != cur_keys_only_batch:
                 raise ValueError(
                     f"Batch.cat_ only supports concatenation of batches with the same structure but got "
-                    f"structures {original_keys_only_batch} and {cur_keys_only_batch}.",
+                    f"structures: \n{original_keys_only_batch}\n and\n{cur_keys_only_batch}.",
                 )
             batch_list.append(batch)
         if len(batch_list) == 0:
@@ -1242,6 +1249,21 @@ def dropnull(self) -> Self:
             sub_batches.append(b)
         return Batch.cat(sub_batches)
 
+    def replace_empty_batches_by_none(self) -> None:
+        """Goes through the batch-tree recursively and replaces empty batches by None.
+
+        This is useful for extracting the structure of a batch without the actual data,
+        especially in combination with `apply_values_transform` with a
+        transform function à la `lambda x: None`.
+        """
+        empty_batch = Batch()
+        for key, val in self.items():
+            if isinstance(val, Batch):
+                if val == empty_batch:
+                    self[key] = None
+                else:
+                    val.replace_empty_batches_by_none()

From 2a87a9a593830612b9aef4a4f9ea871cda17a3b2 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Tue, 30 Jul 2024 13:04:08 +0200
Subject: [PATCH 09/12] Typo [skip ci]

---
 tianshou/data/batch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index d147ac0e6..650a5ccdf 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1254,7 +1254,7 @@ def replace_empty_batches_by_none(self) -> None:
         This is useful for extracting the structure of a batch without the actual data,
         especially in combination with `apply_values_transform` with a
-        transform function à la `lambda x: None`.
+        transform function a la `lambda x: None`.
         """
         empty_batch = Batch()

From 2d4266dc3512d54e4e630f8d1964d850557e9791 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Tue, 30 Jul 2024 23:47:20 +0200
Subject: [PATCH 10/12] Tests: fixed bug in MoveToRightEnv (info dict had wrong structure)

---
 test/base/env.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/base/env.py b/test/base/env.py
index 2a7b09278..49156aca8 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -147,6 +147,8 @@ def step(self, action: np.ndarray | int):  # type: ignore[no-untyped-def]  # cf.
         if self.index == self.size:
             self.terminated = True
             return self._get_state(), self._get_reward(), self.terminated, False, {}
+
+        info_dict = {"key": 1, "env": self} if self.dict_state else {}
         if action == 0:
             self.index = max(self.index - 1, 0)
             return (
                 self._get_state(),
self._get_reward(), self.terminated, False, - {"key": 1, "env": self} if self.dict_state else {}, + info_dict, ) if action == 1: self.index += 1 @@ -164,7 +166,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. self._get_reward(), self.terminated, False, - {"key": 1, "env": self}, + info_dict, ) return None From defcffb152609c2c8991a72f10f4f1bc3d0706bc Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 31 Jul 2024 19:14:41 +0200 Subject: [PATCH 11/12] Fix in tensor type handling in bcq and cql learn --- tianshou/policy/imitation/bcq.py | 4 +++- tianshou/policy/imitation/cql.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index dee1a80a3..991c4aace 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -196,8 +196,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQT # now target_Q: (batch_size, 1) target_Q = ( - batch.rew.reshape(-1, 1) + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q + batch.rew.reshape(-1, 1) + + torch.logical_not(batch.done).reshape(-1, 1) * self.gamma * target_Q ) + target_Q = target_Q.float() current_Q1 = self.critic(obs, act) current_Q2 = self.critic2(obs, act) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 1ce6d83d4..66438c758 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -280,7 +280,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLT target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi - target_Q = rew + self.gamma * (1 - batch.done) * target_Q.flatten() + target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten() + target_Q = target_Q.float() # shape: (batch_size) # compute critic loss From a9a8af44ca4e10a3417a2d6db0cb7b7b8fea048c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 31 Jul 2024 19:15:44 +0200 Subject: [PATCH 12/12] Tests: fixed info of MoveToRightEnv needed now due to the stricter Batch.cat_ --- test/base/env.py | 2 +- test/base/test_buffer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 49156aca8..02f76ad2d 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -148,7 +148,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. self.terminated = True return self._get_state(), self._get_reward(), self.terminated, False, {} - info_dict = {"key": 1, "env": self} if self.dict_state else {} + info_dict = {"key": 1, "env": self} if action == 0: self.index = max(self.index - 1, 0) return ( diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 75ff919c1..0996f2436 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -312,7 +312,7 @@ def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) - obs, info = env.reset() + obs, _ = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act)
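A closing note on the PATCH 11 change above: once done flags reach the learn step as boolean tensors, the old (1 - batch.done) idiom breaks, since torch does not define the - operator for bool tensors; torch.logical_not works for any dtype, and the trailing .float() guards against float64 inputs (for example rewards converted from numpy float64 arrays, which the torch.from_numpy path of PATCH 05 keeps as float64). A standalone illustration with made-up values (the discount factor is chosen arbitrarily):

    import torch

    done = torch.tensor([False, True, False])
    rew = torch.tensor([0.1, 0.2, 0.3])
    target_q = torch.tensor([1.5, 2.0, 0.5], dtype=torch.float64)

    # 1 - done would raise "Subtraction, the `-` operator, with a bool tensor
    # is not supported" in recent torch versions
    not_done = torch.logical_not(done)

    # bool * float64 promotes to float64, hence the explicit cast back to float32
    target = (rew + not_done * 0.99 * target_q).float()
    assert target.dtype == torch.float32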