Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Fix _underscore access, rename Nethack variable as nethack. (#303)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Heinrich Kuttler authored Jan 26, 2022
1 parent 367e6d8 commit f4c0e55
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 42 deletions.
38 changes: 19 additions & 19 deletions nle/env/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -248,15 +248,13 @@ def __init__(

if actions is None:
actions = FULL_ACTIONS
self._actions = actions
self.actions = actions

self.last_observation = ()

try:
if savedir is None:
self.savedir = None
self._stats_file = None
self._stats_logger = None
elif savedir:
self.savedir = os.path.abspath(savedir)
os.makedirs(self.savedir)
Expand Down Expand Up @@ -312,15 +310,15 @@ def __init__(
else:
ttyrec = None

self.env = nethack.Nethack(
self.nethack = nethack.Nethack(
observation_keys=self._observation_keys,
options=options,
playername="Agent-" + self.character,
ttyrec=ttyrec,
wizard=wizard,
spawn_monsters=spawn_monsters,
)
self._close_env = weakref.finalize(self, self.env.close)
self._close_nethack = weakref.finalize(self, self.nethack.close)

self._random = random.SystemRandom()

Expand All @@ -332,7 +330,7 @@ def __init__(
{key: space_dict[key] for key in observation_keys}
)

self.action_space = gym.spaces.Discrete(len(self._actions))
self.action_space = gym.spaces.Discrete(len(self.actions))

def _get_observation(self, observation):
return {
Expand All @@ -341,7 +339,7 @@ def _get_observation(self, observation):
}

def print_action_meanings(self):
for a_idx, a in enumerate(self._actions):
for a_idx, a in enumerate(self.actions):
print(a_idx, a)

def _check_abort(self, observation):
Expand All @@ -367,7 +365,7 @@ def step(self, action: int):
# Careful: By default we re-use Numpy arrays, so copy before!
last_observation = tuple(a.copy() for a in self.last_observation)

observation, done = self.env.step(self._actions[action])
observation, done = self.nethack.step(self.actions[action])
is_game_over = observation[self._program_state_index][0] == 1
if is_game_over or not self._allow_all_modes:
observation, done = self._perform_known_steps(
Expand Down Expand Up @@ -395,7 +393,7 @@ def step(self, action: int):

info = {}
info["end_status"] = end_status
info["is_ascended"] = self.env.how_done() == nethack.ASCENDED
info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED

return self._get_observation(observation), reward, done, info

Expand All @@ -417,7 +415,9 @@ def reset(self, wizkit_items=None):
"""
self._episode += 1
new_ttyrec = self._ttyrec_pattern % self._episode if self.savedir else None
self.last_observation = self.env.reset(new_ttyrec, wizkit_items=wizkit_items)
self.last_observation = self.nethack.reset(
new_ttyrec, wizkit_items=wizkit_items
)

self._steps = 0

Expand All @@ -430,7 +430,7 @@ def reset(self, wizkit_items=None):
# monster at the 0th turn and gets asked to name it.
# Hence the defensive iteration above.
# TODO: Detect this 'in_getlin' situation and handle it.
self.last_observation, done = self.env.step(ASCII_SPACE)
self.last_observation, done = self.nethack.step(ASCII_SPACE)
assert not done, "Game ended unexpectedly"
else:
warnings.warn(
Expand All @@ -441,7 +441,7 @@ def reset(self, wizkit_items=None):
return self._get_observation(self.last_observation)

def close(self):
self._close_env()
self._close_nethack()
super().close()

def seed(self, core=None, disp=None, reseed=False):
Expand Down Expand Up @@ -470,7 +470,7 @@ def seed(self, core=None, disp=None, reseed=False):
core = self._random.randrange(sys.maxsize)
if disp is None:
disp = self._random.randrange(sys.maxsize)
self.env.set_initial_seeds(core, disp, reseed)
self.nethack.set_initial_seeds(core, disp, reseed)
return (core, disp, reseed)

def get_seeds(self):
Expand All @@ -479,7 +479,7 @@ def get_seeds(self):
Returns:
(tuple): Current NetHack (core, disp, reseed) state.
"""
return self.env.get_current_seeds()
return self.nethack.get_current_seeds()

def render(self, mode="human"):
"""Renders the state of the environment."""
Expand Down Expand Up @@ -539,7 +539,7 @@ def _is_episode_end(self, observation):

def _reward_fn(self, last_observation, action, observation, end_status):
"""Reward function. Difference between previous score and new score."""
if not self.env.in_normal_game():
if not self.nethack.in_normal_game():
# Before game started or after it ended stats are zero.
return 0.0
old_score = last_observation[self._blstats_index][nethack.NLE_BL_SCORE]
Expand All @@ -551,15 +551,15 @@ def _reward_fn(self, last_observation, action, observation, end_status):
def _perform_known_steps(self, observation, done, exceptions=True):
while not done:
if observation[self._internal_index][3]: # xwaitforspace
observation, done = self.env.step(ASCII_SPACE)
observation, done = self.nethack.step(ASCII_SPACE)
continue

internal = observation[self._internal_index]
in_yn_function = internal[1]
in_getlin = internal[2]

if in_getlin: # Game asking for a line of text. We don't do that.
observation, done = self.env.step(ASCII_ESC)
observation, done = self.nethack.step(ASCII_ESC)
continue

if in_yn_function: # Game asking for a single character.
Expand All @@ -577,7 +577,7 @@ def _perform_known_steps(self, observation, done, exceptions=True):
break

# Otherwise, auto-decline.
observation, done = self.env.step(ASCII_ESC)
observation, done = self.nethack.step(ASCII_ESC)

break

Expand All @@ -596,7 +596,7 @@ def _quit_game(self, observation, done):
# Quit the game.
actions = [0x80 | ord("q"), ord("y")] # M-q y
for a in actions:
observation, done = self.env.step(a)
observation, done = self.nethack.step(a)

# Answer final questions.
observation, done = self._perform_known_steps(
Expand Down
12 changes: 6 additions & 6 deletions nle/env/tasks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -198,7 +198,7 @@ def _reward_fn(self, last_observation, action, observation, end_status):
"""Difference between previous gold and new gold."""
del end_status # Unused
del action # Unused
if not self.env.in_normal_game():
if not self.nethack.in_normal_game():
# Before game started or after it ended stats are zero.
return 0.0

Expand Down Expand Up @@ -230,7 +230,7 @@ def _reward_fn(self, last_observation, action, observation, end_status):
del end_status # Unused
del action # Unused

if not self.env.in_normal_game():
if not self.nethack.in_normal_game():
# Before game started or after it ended stats are zero.
return 0.0

Expand Down Expand Up @@ -262,7 +262,7 @@ def _reward_fn(self, last_observation, action, observation, end_status):
del end_status # Unused
del action # Unused

if not self.env.in_normal_game():
if not self.nethack.in_normal_game():
# Before game started or after it ended stats are zero.
return 0.0

Expand Down Expand Up @@ -344,9 +344,9 @@ def __init__(
def f(*args, **kwargs):
raise RuntimeError("Should not try changing seeds")

self.env.set_initial_seeds = f
self.env.set_current_seeds = f
self.env.get_current_seeds = f
self.nethack.set_initial_seeds = f
self.nethack.set_current_seeds = f
self.nethack.get_current_seeds = f

def reset(self, *args, **kwargs):
self._turns = None
Expand Down
6 changes: 3 additions & 3 deletions nle/nethack/nethack.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -191,11 +191,11 @@ def __init__(

if options is None:
options = NETHACKOPTIONS
self._options = list(options) + ["name:" + playername]
self.options = list(options) + ["name:" + playername]
if wizard:
self._options.append("playmode:debug")
self.options.append("playmode:debug")
self._wizard = wizard
self._nethackoptions = ",".join(self._options)
self._nethackoptions = ",".join(self.options)
if ttyrec is None:
self._pynethack = _pynethack.Nethack(
self.dlpath, self._vardir, self._nethackoptions, spawn_monsters
Expand Down
23 changes: 10 additions & 13 deletions nle/tests/test_envs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -121,11 +121,11 @@ def test_default_wizard_mode(self, env_name, wizard):
if env_name.startswith("NetHackChallenge-"):
pytest.skip("No wizard mode in NetHackChallenge")
env = gym.make(env_name, wizard=wizard)
assert "playmode:debug" in env.env._options
assert "playmode:debug" in env.nethack.options
else:
# do not send a parameter to test a default
env = gym.make(env_name)
assert "playmode:debug" not in env.env._options
assert "playmode:debug" not in env.nethack.options


class TestWizkit:
Expand Down Expand Up @@ -229,8 +229,6 @@ def test_rollout_no_archive(self, env_name, rollout_len):
"""Tests rollout_len steps (or until termination) of random policy."""
env = gym.make(env_name, savedir=None)
assert env.savedir is None
assert env._stats_file is None
assert env._stats_logger is None
rollout_env(env, rollout_len)

def test_seed_interface_output(self, env_name, rollout_len):
Expand Down Expand Up @@ -335,16 +333,15 @@ def env(self):
e.close()

def test_kick_and_quit(self, env):
actions = env._actions
env.reset()
kick = actions.index(nethack.Command.KICK)
kick = env.actions.index(nethack.Command.KICK)
obs, reward, done, _ = env.step(kick)
assert b"In what direction? " in bytes(obs["message"])
env.step(nethack.MiscAction.MORE)

# Hack to quit.
env.env.step(nethack.M("q"))
obs, reward, done, _ = env.step(actions.index(ord("y")))
env.nethack.step(nethack.M("q"))
obs, reward, done, _ = env.step(env.actions.index(ord("y")))

assert done
assert reward == 0.0
Expand All @@ -364,11 +361,11 @@ def test_final_reward(self, env):
# Hopefully, we got some positive reward by now.

# Get out of any menu / yn_function.
env.step(env._actions.index(ord("\r")))
env.step(env.actions.index(ord("\r")))

# Hack to quit.
env.env.step(nethack.M("q"))
_, reward, done, _ = env.step(env._actions.index(ord("y")))
env.nethack.step(nethack.M("q"))
_, reward, done, _ = env.step(env.actions.index(ord("y")))

assert done
assert reward == 0.0
Expand Down Expand Up @@ -398,8 +395,8 @@ def test_no_seed_setting(self):
):
env.seed()
with pytest.raises(RuntimeError, match="Should not try changing seeds"):
env.env.set_initial_seeds(0, 0, True)
env.nethack.set_initial_seeds(0, 0, True)

if not nethack.NLE_ALLOW_SEEDING:
with pytest.raises(RuntimeError, match="Seeding not enabled"):
env.env._pynethack.set_initial_seeds(0, 0, True)
env.nethack._pynethack.set_initial_seeds(0, 0, True)
2 changes: 1 addition & 1 deletion nle/tests/test_profile.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_run_1k_steps(self, observation_keys, benchmark):
steps = 1000

np.random.seed(seeds)
actions = np.random.choice(len(env._actions), size=steps)
actions = np.random.choice(env.action_space.n, size=steps)

def seed():
if not nle.nethack.NLE_ALLOW_SEEDING:
Expand Down

0 comments on commit f4c0e55

Please sign in to comment.