From f7ecb647b9abae73d9eed49c31ac60f5f6c2054a Mon Sep 17 00:00:00 2001 From: Markus28 Date: Tue, 2 Nov 2021 09:22:07 +0100 Subject: [PATCH 1/5] Implements set_env_attr and get_env_attr for vector environments (resolves #473) --- test/base/test_env.py | 18 +++++++++++- tianshou/env/venvs.py | 51 ++++++++++++++++++++++++++++++---- tianshou/env/worker/base.py | 8 ++++-- tianshou/env/worker/dummy.py | 5 +++- tianshou/env/worker/ray.py | 18 ++++++++++-- tianshou/env/worker/subproc.py | 7 ++++- 6 files changed, 93 insertions(+), 14 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index b9d6489b6..d8280b307 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -166,10 +166,26 @@ def test_vecenv(size=10, num=8, sleep=0.001): for i, v in enumerate(venv): print(f'{type(v)}: {t[i]:.6f}s') + def assert_get(v, expected): + print(expected, v.get_env_attr("size")) + assert v.get_env_attr("size") == expected + assert v.get_env_attr("size", id=0) == [expected[0]] + assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3] + for v in venv: - assert v.size == list(range(size, size + num)) + assert_get(v, list(range(size, size + num))) assert v.env_num == num assert v.action_space == [Discrete(2)] * num + + v.set_env_attr("size", 0) + assert_get(v, [0] * num) + + v.set_env_attr("size", 1, 0) + assert_get(v, [1] + [0] * (num - 1)) + + v.set_env_attr("size", 2, [1, 2, 3]) + assert_get(v, [1] + [2] * 3 + [0] * (num - 4)) + for v in venv: v.close() diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 654f55b69..f63739f1d 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -130,17 +130,56 @@ def __getattribute__(self, key: str) -> Any: if key in [ 'metadata', 'reward_range', 'spec', 'action_space', 'observation_space' ]: # reserved keys in gym.Env - return self.__getattr__(key) + return self.get_env_attr(key) else: return super().__getattribute__(key) - def __getattr__(self, key: str) -> List[Any]: - """Fetch a list of env attributes. + def get_env_attr( + self, + key: str, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> List[Any]: + """Get an attribute from the underlying environments. + + If id is an int, retrieve the attribute denoted by key from the environment + underlying the worker at index id. The result is returned as a list with one + element. Otherwise, retrieve the attribute for all workers at indices id and + return a list that is ordered correspondingly to id. + + :param str key: The key of the desired attribute + :param id: Indices of the desired workers + + :return list: The list of environment attributes + """ + self._assert_is_not_closed() + id = self._wrap_id(id) + if self.is_async: + self._assert_id(id) - This function tries to retrieve an attribute from each individual wrapped - environment, if it does not belong to the wrapping vector environment class. + return [self.workers[j].get_env_attr(key) for j in id] + + def set_env_attr( + self, + key: str, + value: Any, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> None: + """Set an attribute in the underlying environments. + + If id is an int, set the attribute denoted by key from the environment + underlying the worker at index id to value. + Otherwise, set the attribute for all workers at indices id. + + :param str key: The key of the desired attribute + :param Any value: The new value of the attribute + :param id: Indices of the desired workers """ - return [getattr(worker, key) for worker in self.workers] + self._assert_is_not_closed() + id = self._wrap_id(id) + if self.is_async: + self._assert_id(id) + for j in id: + self.workers[j].set_env_attr(key, value) def _wrap_id( self, diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 6fef9f68d..3c63be997 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -12,10 +12,14 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] - self.action_space = getattr(self, "action_space") # noqa: B009 + self.action_space = self.get_env_attr("action_space") # noqa: B009 @abstractmethod - def __getattr__(self, key: str) -> Any: + def get_env_attr(self, key: str) -> Any: + pass + + @abstractmethod + def set_env_attr(self, key: str, value: Any) -> None: pass @abstractmethod diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 9e68e9f04..542c70210 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -13,9 +13,12 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = env_fn() super().__init__(env_fn) - def __getattr__(self, key: str) -> Any: + def get_env_attr(self, key: str) -> Any: return getattr(self.env, key) + def set_env_attr(self, key: str, value: Any) -> None: + setattr(self.env, key, value) + def reset(self) -> Any: return self.env.reset() diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 5d73763f2..7f62aeb92 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -11,15 +11,27 @@ pass +class _SetAttrWrapper(gym.Wrapper): + + def set_env_attr(self, key: str, value: Any) -> None: + setattr(self.env, key, value) + + def get_env_attr(self, key: str) -> Any: + getattr(self.env, key) + + class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: - self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) + self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn()) super().__init__(env_fn) - def __getattr__(self, key: str) -> Any: - return ray.get(self.env.__getattr__.remote(key)) + def get_env_attr(self, key: str) -> Any: + return ray.get(self.env.get_env_attr.remote(key)) + + def set_env_attr(self, key: str, value: Any) -> None: + ray.get(self.env.set_env_attribute.remote(key, value)) def reset(self) -> Any: return ray.get(self.env.reset.remote()) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 8ef264360..61a69cafd 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -107,6 +107,8 @@ def _encode_obs( p.send(env.seed(data) if hasattr(env, "seed") else None) elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) + elif cmd == "setattr": + setattr(env, data["key"], data["value"]) else: p.close() raise NotImplementedError @@ -140,10 +142,13 @@ def __init__( self.child_remote.close() super().__init__(env_fn) - def __getattr__(self, key: str) -> Any: + def get_env_attr(self, key: str) -> Any: self.parent_remote.send(["getattr", key]) return self.parent_remote.recv() + def set_env_attr(self, key: str, value: Any) -> None: + self.parent_remote.send(["setattr", {"key": key, "value": value}]) + def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: def decode_obs( From d7e4add947bcbc71a623bc8d782a7f9adf17aed6 Mon Sep 17 00:00:00 2001 From: Markus28 Date: Tue, 2 Nov 2021 10:36:43 +0100 Subject: [PATCH 2/5] Fixed get_env_attr in _AttrWrapper, set_env_attr in RayEnvWorker --- test/base/test_env.py | 1 - tianshou/env/worker/ray.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index d8280b307..7f47501c3 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -167,7 +167,6 @@ def test_vecenv(size=10, num=8, sleep=0.001): print(f'{type(v)}: {t[i]:.6f}s') def assert_get(v, expected): - print(expected, v.get_env_attr("size")) assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3] diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 7f62aeb92..7917683be 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -17,7 +17,7 @@ def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env, key, value) def get_env_attr(self, key: str) -> Any: - getattr(self.env, key) + return getattr(self.env, key) class RayEnvWorker(EnvWorker): @@ -31,7 +31,7 @@ def get_env_attr(self, key: str) -> Any: return ray.get(self.env.get_env_attr.remote(key)) def set_env_attr(self, key: str, value: Any) -> None: - ray.get(self.env.set_env_attribute.remote(key, value)) + ray.get(self.env.set_env_attr.remote(key, value)) def reset(self) -> Any: return ray.get(self.env.reset.remote()) From 4f82b404aee88afb754176e04b0aaf9bfe8d32a1 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 2 Nov 2021 09:23:15 -0400 Subject: [PATCH 3/5] fix ci --- test/throughput/test_batch_profile.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/throughput/test_batch_profile.py b/test/throughput/test_batch_profile.py index fbd6fb89c..7e4c1eaa4 100644 --- a/test/throughput/test_batch_profile.py +++ b/test/throughput/test_batch_profile.py @@ -10,7 +10,7 @@ @pytest.fixture(scope="module") def data(): - print("Initialising data...") + print("Initializing data...") np.random.seed(0) batch_set = [ Batch( @@ -19,7 +19,7 @@ def data(): 'b1': (3.14, 3.14), 'b2': np.arange(1e3) }, - c=i + c=i, ) for i in np.arange(int(1e4)) ] batch0 = Batch( @@ -27,8 +27,8 @@ def data(): b=Batch( c=np.ones((1, ), dtype=np.float64), d=torch.ones((3, 3, 3), dtype=torch.float32), - e=list(range(3)) - ) + e=list(range(3)), + ), ) batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] @@ -39,13 +39,13 @@ def data(): indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False) slice_dict = { 'obs': [np.arange(20) for _ in np.arange(batch_len // 10)], - 'reward': np.arange(batch_len // 10) + 'reward': np.arange(batch_len // 10), } dict_set = [ { 'obs': np.arange(20), 'info': "this is info", - 'reward': 0 + 'reward': 0, } for _ in np.arange(1e2) ] batch4 = Batch( @@ -53,11 +53,11 @@ def data(): b=Batch( c=np.ones((1, ), dtype=np.float64), d=torch.ones((1000, 1000), dtype=torch.float32), - e=np.arange(1000) - ) + e=np.arange(1000), + ), ) - print("Initialised") + print("Initialized") return { 'batch_set': batch_set, 'batch0': batch0, @@ -67,7 +67,7 @@ def data(): 'indexs': indexs, 'dict_set': dict_set, 'slice_dict': slice_dict, - 'batch4': batch4 + 'batch4': batch4, } @@ -106,7 +106,7 @@ def test_set_attr(data): def test_numpy_torch_convert(data): """Test conversion between numpy and torch.""" - for _ in np.arange(1e5): + for _ in np.arange(1e4): # not sure what's wrong in torch==1.10.0 data['batch4'].to_torch() data['batch4'].to_numpy() From cd6a1ad1f10291f4782dd1f693317655f63ab900 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 2 Nov 2021 09:34:08 -0400 Subject: [PATCH 4/5] format --- tianshou/env/venvs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index f63739f1d..0d7757291 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -146,10 +146,10 @@ def get_env_attr( element. Otherwise, retrieve the attribute for all workers at indices id and return a list that is ordered correspondingly to id. - :param str key: The key of the desired attribute - :param id: Indices of the desired workers + :param str key: The key of the desired attribute. + :param id: Indices of the desired workers. - :return list: The list of environment attributes + :return list: The list of environment attributes. """ self._assert_is_not_closed() id = self._wrap_id(id) @@ -170,9 +170,9 @@ def set_env_attr( underlying the worker at index id to value. Otherwise, set the attribute for all workers at indices id. - :param str key: The key of the desired attribute - :param Any value: The new value of the attribute - :param id: Indices of the desired workers + :param str key: The key of the desired attribute. + :param Any value: The new value of the attribute. + :param id: Indices of the desired workers. """ self._assert_is_not_closed() id = self._wrap_id(id) From af12a7075023a46d102eb6727f7ca632ea46861d Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 2 Nov 2021 11:23:45 -0400 Subject: [PATCH 5/5] polish comment --- tianshou/env/venvs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 0d7757291..b46593c47 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -147,7 +147,7 @@ def get_env_attr( return a list that is ordered correspondingly to id. :param str key: The key of the desired attribute. - :param id: Indices of the desired workers. + :param id: Indice(s) of the desired worker(s). Default to None for all env_id. :return list: The list of environment attributes. """ @@ -172,7 +172,7 @@ def set_env_attr( :param str key: The key of the desired attribute. :param Any value: The new value of the attribute. - :param id: Indices of the desired workers. + :param id: Indice(s) of the desired worker(s). Default to None for all env_id. """ self._assert_is_not_closed() id = self._wrap_id(id)