From f4c0e5570fac9b046ac9ca126681581640ebf29b Mon Sep 17 00:00:00 2001 From: Heinrich Kuttler Date: Wed, 26 Jan 2022 12:25:14 +0000 Subject: [PATCH] Fix _underscore access, rename Nethack variable as `nethack`. (#303) --- nle/env/base.py | 38 +++++++++++++++++++------------------- nle/env/tasks.py | 12 ++++++------ nle/nethack/nethack.py | 6 +++--- nle/tests/test_envs.py | 23 ++++++++++------------- nle/tests/test_profile.py | 2 +- 5 files changed, 39 insertions(+), 42 deletions(-) diff --git a/nle/env/base.py b/nle/env/base.py index 4f2f7981a..41c9e5ca3 100644 --- a/nle/env/base.py +++ b/nle/env/base.py @@ -248,15 +248,13 @@ def __init__( if actions is None: actions = FULL_ACTIONS - self._actions = actions + self.actions = actions self.last_observation = () try: if savedir is None: self.savedir = None - self._stats_file = None - self._stats_logger = None elif savedir: self.savedir = os.path.abspath(savedir) os.makedirs(self.savedir) @@ -312,7 +310,7 @@ def __init__( else: ttyrec = None - self.env = nethack.Nethack( + self.nethack = nethack.Nethack( observation_keys=self._observation_keys, options=options, playername="Agent-" + self.character, @@ -320,7 +318,7 @@ def __init__( wizard=wizard, spawn_monsters=spawn_monsters, ) - self._close_env = weakref.finalize(self, self.env.close) + self._close_nethack = weakref.finalize(self, self.nethack.close) self._random = random.SystemRandom() @@ -332,7 +330,7 @@ def __init__( {key: space_dict[key] for key in observation_keys} ) - self.action_space = gym.spaces.Discrete(len(self._actions)) + self.action_space = gym.spaces.Discrete(len(self.actions)) def _get_observation(self, observation): return { @@ -341,7 +339,7 @@ def _get_observation(self, observation): } def print_action_meanings(self): - for a_idx, a in enumerate(self._actions): + for a_idx, a in enumerate(self.actions): print(a_idx, a) def _check_abort(self, observation): @@ -367,7 +365,7 @@ def step(self, action: int): # Careful: By default we re-use Numpy arrays, so copy before! last_observation = tuple(a.copy() for a in self.last_observation) - observation, done = self.env.step(self._actions[action]) + observation, done = self.nethack.step(self.actions[action]) is_game_over = observation[self._program_state_index][0] == 1 if is_game_over or not self._allow_all_modes: observation, done = self._perform_known_steps( @@ -395,7 +393,7 @@ def step(self, action: int): info = {} info["end_status"] = end_status - info["is_ascended"] = self.env.how_done() == nethack.ASCENDED + info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED return self._get_observation(observation), reward, done, info @@ -417,7 +415,9 @@ def reset(self, wizkit_items=None): """ self._episode += 1 new_ttyrec = self._ttyrec_pattern % self._episode if self.savedir else None - self.last_observation = self.env.reset(new_ttyrec, wizkit_items=wizkit_items) + self.last_observation = self.nethack.reset( + new_ttyrec, wizkit_items=wizkit_items + ) self._steps = 0 @@ -430,7 +430,7 @@ def reset(self, wizkit_items=None): # monster at the 0th turn and gets asked to name it. # Hence the defensive iteration above. # TODO: Detect this 'in_getlin' situation and handle it. - self.last_observation, done = self.env.step(ASCII_SPACE) + self.last_observation, done = self.nethack.step(ASCII_SPACE) assert not done, "Game ended unexpectedly" else: warnings.warn( @@ -441,7 +441,7 @@ def reset(self, wizkit_items=None): return self._get_observation(self.last_observation) def close(self): - self._close_env() + self._close_nethack() super().close() def seed(self, core=None, disp=None, reseed=False): @@ -470,7 +470,7 @@ def seed(self, core=None, disp=None, reseed=False): core = self._random.randrange(sys.maxsize) if disp is None: disp = self._random.randrange(sys.maxsize) - self.env.set_initial_seeds(core, disp, reseed) + self.nethack.set_initial_seeds(core, disp, reseed) return (core, disp, reseed) def get_seeds(self): @@ -479,7 +479,7 @@ def get_seeds(self): Returns: (tuple): Current NetHack (core, disp, reseed) state. """ - return self.env.get_current_seeds() + return self.nethack.get_current_seeds() def render(self, mode="human"): """Renders the state of the environment.""" @@ -539,7 +539,7 @@ def _is_episode_end(self, observation): def _reward_fn(self, last_observation, action, observation, end_status): """Reward function. Difference between previous score and new score.""" - if not self.env.in_normal_game(): + if not self.nethack.in_normal_game(): # Before game started or after it ended stats are zero. return 0.0 old_score = last_observation[self._blstats_index][nethack.NLE_BL_SCORE] @@ -551,7 +551,7 @@ def _reward_fn(self, last_observation, action, observation, end_status): def _perform_known_steps(self, observation, done, exceptions=True): while not done: if observation[self._internal_index][3]: # xwaitforspace - observation, done = self.env.step(ASCII_SPACE) + observation, done = self.nethack.step(ASCII_SPACE) continue internal = observation[self._internal_index] @@ -559,7 +559,7 @@ def _perform_known_steps(self, observation, done, exceptions=True): in_getlin = internal[2] if in_getlin: # Game asking for a line of text. We don't do that. - observation, done = self.env.step(ASCII_ESC) + observation, done = self.nethack.step(ASCII_ESC) continue if in_yn_function: # Game asking for a single character. @@ -577,7 +577,7 @@ def _perform_known_steps(self, observation, done, exceptions=True): break # Otherwise, auto-decline. - observation, done = self.env.step(ASCII_ESC) + observation, done = self.nethack.step(ASCII_ESC) break @@ -596,7 +596,7 @@ def _quit_game(self, observation, done): # Quit the game. actions = [0x80 | ord("q"), ord("y")] # M-q y for a in actions: - observation, done = self.env.step(a) + observation, done = self.nethack.step(a) # Answer final questions. observation, done = self._perform_known_steps( diff --git a/nle/env/tasks.py b/nle/env/tasks.py index d5c682df1..a506d96c8 100644 --- a/nle/env/tasks.py +++ b/nle/env/tasks.py @@ -198,7 +198,7 @@ def _reward_fn(self, last_observation, action, observation, end_status): """Difference between previous gold and new gold.""" del end_status # Unused del action # Unused - if not self.env.in_normal_game(): + if not self.nethack.in_normal_game(): # Before game started or after it ended stats are zero. return 0.0 @@ -230,7 +230,7 @@ def _reward_fn(self, last_observation, action, observation, end_status): del end_status # Unused del action # Unused - if not self.env.in_normal_game(): + if not self.nethack.in_normal_game(): # Before game started or after it ended stats are zero. return 0.0 @@ -262,7 +262,7 @@ def _reward_fn(self, last_observation, action, observation, end_status): del end_status # Unused del action # Unused - if not self.env.in_normal_game(): + if not self.nethack.in_normal_game(): # Before game started or after it ended stats are zero. return 0.0 @@ -344,9 +344,9 @@ def __init__( def f(*args, **kwargs): raise RuntimeError("Should not try changing seeds") - self.env.set_initial_seeds = f - self.env.set_current_seeds = f - self.env.get_current_seeds = f + self.nethack.set_initial_seeds = f + self.nethack.set_current_seeds = f + self.nethack.get_current_seeds = f def reset(self, *args, **kwargs): self._turns = None diff --git a/nle/nethack/nethack.py b/nle/nethack/nethack.py index e6a75b43b..535dd064f 100644 --- a/nle/nethack/nethack.py +++ b/nle/nethack/nethack.py @@ -191,11 +191,11 @@ def __init__( if options is None: options = NETHACKOPTIONS - self._options = list(options) + ["name:" + playername] + self.options = list(options) + ["name:" + playername] if wizard: - self._options.append("playmode:debug") + self.options.append("playmode:debug") self._wizard = wizard - self._nethackoptions = ",".join(self._options) + self._nethackoptions = ",".join(self.options) if ttyrec is None: self._pynethack = _pynethack.Nethack( self.dlpath, self._vardir, self._nethackoptions, spawn_monsters diff --git a/nle/tests/test_envs.py b/nle/tests/test_envs.py index 2a354c7c4..01199009b 100644 --- a/nle/tests/test_envs.py +++ b/nle/tests/test_envs.py @@ -121,11 +121,11 @@ def test_default_wizard_mode(self, env_name, wizard): if env_name.startswith("NetHackChallenge-"): pytest.skip("No wizard mode in NetHackChallenge") env = gym.make(env_name, wizard=wizard) - assert "playmode:debug" in env.env._options + assert "playmode:debug" in env.nethack.options else: # do not send a parameter to test a default env = gym.make(env_name) - assert "playmode:debug" not in env.env._options + assert "playmode:debug" not in env.nethack.options class TestWizkit: @@ -229,8 +229,6 @@ def test_rollout_no_archive(self, env_name, rollout_len): """Tests rollout_len steps (or until termination) of random policy.""" env = gym.make(env_name, savedir=None) assert env.savedir is None - assert env._stats_file is None - assert env._stats_logger is None rollout_env(env, rollout_len) def test_seed_interface_output(self, env_name, rollout_len): @@ -335,16 +333,15 @@ def env(self): e.close() def test_kick_and_quit(self, env): - actions = env._actions env.reset() - kick = actions.index(nethack.Command.KICK) + kick = env.actions.index(nethack.Command.KICK) obs, reward, done, _ = env.step(kick) assert b"In what direction? " in bytes(obs["message"]) env.step(nethack.MiscAction.MORE) # Hack to quit. - env.env.step(nethack.M("q")) - obs, reward, done, _ = env.step(actions.index(ord("y"))) + env.nethack.step(nethack.M("q")) + obs, reward, done, _ = env.step(env.actions.index(ord("y"))) assert done assert reward == 0.0 @@ -364,11 +361,11 @@ def test_final_reward(self, env): # Hopefully, we got some positive reward by now. # Get out of any menu / yn_function. - env.step(env._actions.index(ord("\r"))) + env.step(env.actions.index(ord("\r"))) # Hack to quit. - env.env.step(nethack.M("q")) - _, reward, done, _ = env.step(env._actions.index(ord("y"))) + env.nethack.step(nethack.M("q")) + _, reward, done, _ = env.step(env.actions.index(ord("y"))) assert done assert reward == 0.0 @@ -398,8 +395,8 @@ def test_no_seed_setting(self): ): env.seed() with pytest.raises(RuntimeError, match="Should not try changing seeds"): - env.env.set_initial_seeds(0, 0, True) + env.nethack.set_initial_seeds(0, 0, True) if not nethack.NLE_ALLOW_SEEDING: with pytest.raises(RuntimeError, match="Seeding not enabled"): - env.env._pynethack.set_initial_seeds(0, 0, True) + env.nethack._pynethack.set_initial_seeds(0, 0, True) diff --git a/nle/tests/test_profile.py b/nle/tests/test_profile.py index 35eb6d8bf..273d092bc 100644 --- a/nle/tests/test_profile.py +++ b/nle/tests/test_profile.py @@ -52,7 +52,7 @@ def test_run_1k_steps(self, observation_keys, benchmark): steps = 1000 np.random.seed(seeds) - actions = np.random.choice(len(env._actions), size=steps) + actions = np.random.choice(env.action_space.n, size=steps) def seed(): if not nle.nethack.NLE_ALLOW_SEEDING: