diff --git a/nle/env/base.py b/nle/env/base.py index ab669a664..4f2f7981a 100644 --- a/nle/env/base.py +++ b/nle/env/base.py @@ -1,6 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. import collections -import csv import enum import logging import os @@ -190,7 +189,6 @@ class StepStatus(enum.IntEnum): def __init__( self, savedir=None, - archivefile=None, character="mon-hum-neu-mal", max_episode_steps=5000, observation_keys=( @@ -243,8 +241,6 @@ def __init__( If set to True, do not decline menus, text input or auto 'MORE'. If set to False, only skip click through 'MORE' on death. """ - del archivefile # TODO: Remove once we change the API. - self.character = character self._max_episode_steps = max_episode_steps self._allow_all_yn_questions = allow_all_yn_questions @@ -278,12 +274,6 @@ def __init__( else: logger.info("Not saving any NLE data.") - # TODO: Fix stats_file logic. - # self._setup_statsfile = self.savedir is not None - self._setup_statsfile = False - self._stats_file = None - self._stats_logger = None - self._observation_keys = list(observation_keys) if "internal" in self._observation_keys: @@ -404,42 +394,11 @@ def step(self, action: int): done = True info = {} - # TODO: fix stats - # if end_status: - # # stats = self._collect_stats(last_observation, end_status) - # # stats = stats._asdict() - # # stats = {} - # # info["stats"] = stats - # - # # if self._stats_logger is not None: - # # self._stats_logger.writerow(stats) - info["end_status"] = end_status info["is_ascended"] = self.env.how_done() == nethack.ASCENDED return self._get_observation(observation), reward, done, info - def _collect_stats(self, message, end_status): - """Updates a stats dict tracking several env stats.""" - # Using class rather than instance to allow tasks to reuse this with - # super() - # return NLE.Stats( - # end_status=int(end_status), - # score=_get(message, "Blstats.score", required=True), - # time=_get(message, "Blstats.time", required=True), - # steps=self._steps, - # hp=_get(message, "Blstats.hitpoints", required=True), - # exp=_get(message, "Blstats.experience_points", required=True), - # exp_lev=_get(message, "Blstats.experience_level", required=True), - # gold=_get(message, "Blstats.gold", required=True), - # hunger=_get(message, "You.uhunger", required=True), - # # killer_name=self._killer_name, - # deepest_lev=_get(message, "Internal.deepest_lev_reached", required=True), - # episode=self._episode, - # seeds=self.get_seeds(), - # ttyrec=self.env._process.filename, - # ) - def _in_moveloop(self, observation): program_state = observation[self._program_state_index] return program_state[3] # in_moveloop @@ -460,21 +419,6 @@ def reset(self, wizkit_items=None): new_ttyrec = self._ttyrec_pattern % self._episode if self.savedir else None self.last_observation = self.env.reset(new_ttyrec, wizkit_items=wizkit_items) - # Only run on the first reset to initialize stats file - if self._setup_statsfile: - filename = os.path.join(self.savedir, "stats.csv") - add_header = not os.path.exists(filename) - - self._stats_file = open(filename, "a", 1) # line buffered. - self._stats_logger = csv.DictWriter( - self._stats_file, fieldnames=self.Stats._fields - ) - if add_header: - self._stats_logger.writeheader() - self._setup_statsfile = False - - # self._killer_name = "UNK" - self._steps = 0 for _ in range(1000): @@ -610,9 +554,6 @@ def _perform_known_steps(self, observation, done, exceptions=True): observation, done = self.env.step(ASCII_SPACE) continue - # TODO: Think about killer_name. - # if self._killer_name == "UNK" - internal = observation[self._internal_index] in_yn_function = internal[1] in_getlin = internal[2] diff --git a/nle/env/tasks.py b/nle/env/tasks.py index 14d88ae02..d5c682df1 100644 --- a/nle/env/tasks.py +++ b/nle/env/tasks.py @@ -137,11 +137,9 @@ def _is_episode_end(self, observation): blstats = observation[self._blstats_index] x, y = blstats[:2] - neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2].reshape(-1).tolist() - # TODO: vectorize - for glyph in neighbors: - if nethack.glyph_is_pet(glyph): - return self.StepStatus.TASK_SUCCESSFUL + neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2] + if np.any(nethack.glyph_is_pet(neighbors)): + return self.StepStatus.TASK_SUCCESSFUL return self.StepStatus.RUNNING diff --git a/nle/tests/test_nethack.py b/nle/tests/test_nethack.py index 23e064af4..ce30c5253 100644 --- a/nle/tests/test_nethack.py +++ b/nle/tests/test_nethack.py @@ -368,6 +368,178 @@ def test_glyph2tile(self): assert nethack.glyph2tile[nethack.GLYPH_PET_OFF] == 0 assert nethack.glyph2tile[nethack.GLYPH_DETECT_OFF] == 0 + def test_glyph_is(self): + assert nethack.glyph_is_monster(nethack.GLYPH_MON_OFF) + assert nethack.glyph_is_pet(nethack.GLYPH_PET_OFF) + assert nethack.glyph_is_invisible(nethack.GLYPH_INVIS_OFF) + assert nethack.glyph_is_detected_monster(nethack.GLYPH_DETECT_OFF) + assert nethack.glyph_is_body(nethack.GLYPH_BODY_OFF) + assert nethack.glyph_is_ridden_monster(nethack.GLYPH_RIDDEN_OFF) + assert nethack.glyph_is_object(nethack.GLYPH_OBJ_OFF) + assert nethack.glyph_is_cmap(nethack.GLYPH_CMAP_OFF) + # No glyph_is_explode, glyph_is_zap in NH. + assert nethack.glyph_is_swallow(nethack.GLYPH_SWALLOW_OFF) + assert nethack.glyph_is_warning(nethack.GLYPH_WARNING_OFF) + assert nethack.glyph_is_statue(nethack.GLYPH_STATUE_OFF) + + vec = np.array( + [ + nethack.GLYPH_MON_OFF, + nethack.GLYPH_PET_OFF, + nethack.GLYPH_INVIS_OFF, + nethack.GLYPH_DETECT_OFF, + nethack.GLYPH_BODY_OFF, + nethack.GLYPH_RIDDEN_OFF, + nethack.GLYPH_OBJ_OFF, + nethack.GLYPH_CMAP_OFF, + nethack.GLYPH_EXPLODE_OFF, + nethack.GLYPH_ZAP_OFF, + nethack.GLYPH_SWALLOW_OFF, + nethack.GLYPH_WARNING_OFF, + nethack.GLYPH_STATUE_OFF, + ], + dtype=np.int32, + ) + np.testing.assert_array_equal( + nethack.glyph_is_monster(vec), + np.isin( + vec, + [ + nethack.GLYPH_MON_OFF, + nethack.GLYPH_PET_OFF, + nethack.GLYPH_DETECT_OFF, + nethack.GLYPH_RIDDEN_OFF, + ], + ), + ) + np.testing.assert_array_equal( + nethack.glyph_is_pet(vec), + np.isin(vec, [nethack.GLYPH_PET_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_invisible(vec), + np.isin(vec, [nethack.GLYPH_INVIS_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_normal_object(vec), + np.isin(vec, [nethack.GLYPH_OBJ_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_detected_monster(vec), + np.isin(vec, [nethack.GLYPH_DETECT_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_body(vec), + np.isin(vec, [nethack.GLYPH_BODY_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_ridden_monster(vec), + np.isin(vec, [nethack.GLYPH_RIDDEN_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_object(vec), + np.isin( + vec, + [ + nethack.GLYPH_BODY_OFF, + nethack.GLYPH_OBJ_OFF, + nethack.GLYPH_STATUE_OFF, + ], + ), + ) + assert np.all(nethack.glyph_is_trap(vec) == 0) + for idx in range(nethack.MAXPCHARS): # Find an actual trap. + if "trap" in nethack.symdef.from_idx(idx).explanation: + assert nethack.glyph_is_trap(nethack.GLYPH_CMAP_OFF + idx) + break + np.testing.assert_array_equal( # Explosions are cmaps? + nethack.glyph_is_cmap(vec), + np.isin(vec, [nethack.GLYPH_CMAP_OFF, nethack.GLYPH_EXPLODE_OFF]), + ) + # No glyph_is_explode, glyph_is_zap in NH. + np.testing.assert_array_equal( + nethack.glyph_is_swallow(vec), + np.isin(vec, [nethack.GLYPH_SWALLOW_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_warning(vec), + np.isin(vec, [nethack.GLYPH_WARNING_OFF]), + ) + np.testing.assert_array_equal( + nethack.glyph_is_statue(vec), + np.isin(vec, [nethack.GLYPH_STATUE_OFF]), + ) + + # Test some non-offset value too. + assert nethack.glyph_is_warning( + (nethack.GLYPH_WARNING_OFF + nethack.GLYPH_STATUE_OFF) // 2 + ) + + def test_glyph_to(self): + assert np.all( + nethack.glyph_to_mon( + np.array( + [ + nethack.GLYPH_MON_OFF, + nethack.GLYPH_PET_OFF, + nethack.GLYPH_DETECT_OFF, + nethack.GLYPH_RIDDEN_OFF, + nethack.GLYPH_STATUE_OFF, + ] + ) + ) + == 0 + ) + + # STATUE and CORPSE from onames.h (generated by makedefs). + # Returned by glyph_to_obj. + corpse = get_object("corpse").oc_name_idx + statue = get_object("statue").oc_name_idx + np.testing.assert_array_equal( + nethack.glyph_to_obj( + np.array( + [ + nethack.GLYPH_BODY_OFF, + nethack.GLYPH_STATUE_OFF, + nethack.GLYPH_OBJ_OFF, + ] + ) + ), + np.array([corpse, statue, 0]), + ) + + for idx in range(nethack.MAXPCHARS): # Find the arrow trap. + if nethack.symdef.from_idx(idx).explanation == "arrow trap": + np.testing.assert_array_equal( + nethack.glyph_to_trap( + np.array([nethack.GLYPH_CMAP_OFF, nethack.GLYPH_CMAP_OFF + idx]) + ), + # Traps are one-indexed in defsym_to_trap as per rm.h. + np.array([nethack.NO_GLYPH, 1]), + ) + break + + np.testing.assert_array_equal( + nethack.glyph_to_cmap( + np.array( + [ + nethack.GLYPH_CMAP_OFF, + nethack.GLYPH_STATUE_OFF, + ] + ) + ), + np.array([0, nethack.NO_GLYPH]), + ) + + assert nethack.glyph_to_swallow(nethack.GLYPH_SWALLOW_OFF) == 0 + + np.testing.assert_array_equal( + nethack.glyph_to_warning( + np.arange(nethack.GLYPH_WARNING_OFF, nethack.GLYPH_STATUE_OFF) + ), + np.arange(nethack.WARNCOUNT), + ) + class TestNethackGlanceObservation: @pytest.fixture diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc index b057ef613..e8a91f694 100644 --- a/win/rl/pynethack.cc +++ b/win/rl/pynethack.cc @@ -535,31 +535,40 @@ PYBIND11_MODULE(_pynethack, m) py::int_(MG_OBJPILE); /* more than one stack of objects */ mn.attr("MG_BW_LAVA") = py::int_(MG_BW_LAVA); /* 'black & white lava' */ - // Expose macros as Python functions. + // Expose macros as Python functions, with optional vectorization. mn.def("glyph_is_monster", - [](int glyph) { return glyph_is_monster(glyph); }); - mn.def("glyph_is_normal_monster", - [](int glyph) { return glyph_is_normal_monster(glyph); }); - mn.def("glyph_is_pet", [](int glyph) { return glyph_is_pet(glyph); }); - mn.def("glyph_is_body", [](int glyph) { return glyph_is_body(glyph); }); + py::vectorize([](int glyph) { return glyph_is_monster(glyph); })); + mn.def("glyph_is_normal_monster", py::vectorize([](int glyph) { + return glyph_is_normal_monster(glyph); + })); + mn.def("glyph_is_pet", + py::vectorize([](int glyph) { return glyph_is_pet(glyph); })); + mn.def("glyph_is_body", + py::vectorize([](int glyph) { return glyph_is_body(glyph); })); mn.def("glyph_is_statue", - [](int glyph) { return glyph_is_statue(glyph); }); - mn.def("glyph_is_ridden_monster", - [](int glyph) { return glyph_is_ridden_monster(glyph); }); - mn.def("glyph_is_detected_monster", - [](int glyph) { return glyph_is_detected_monster(glyph); }); - mn.def("glyph_is_invisible", - [](int glyph) { return glyph_is_invisible(glyph); }); - mn.def("glyph_is_normal_object", - [](int glyph) { return glyph_is_normal_object(glyph); }); + py::vectorize([](int glyph) { return glyph_is_statue(glyph); })); + mn.def("glyph_is_ridden_monster", py::vectorize([](int glyph) { + return glyph_is_ridden_monster(glyph); + })); + mn.def("glyph_is_detected_monster", py::vectorize([](int glyph) { + return glyph_is_detected_monster(glyph); + })); + mn.def("glyph_is_invisible", py::vectorize([](int glyph) { + return glyph_is_invisible(glyph); + })); + mn.def("glyph_is_normal_object", py::vectorize([](int glyph) { + return glyph_is_normal_object(glyph); + })); mn.def("glyph_is_object", - [](int glyph) { return glyph_is_object(glyph); }); - mn.def("glyph_is_trap", [](int glyph) { return glyph_is_trap(glyph); }); - mn.def("glyph_is_cmap", [](int glyph) { return glyph_is_cmap(glyph); }); + py::vectorize([](int glyph) { return glyph_is_object(glyph); })); + mn.def("glyph_is_trap", + py::vectorize([](int glyph) { return glyph_is_trap(glyph); })); + mn.def("glyph_is_cmap", + py::vectorize([](int glyph) { return glyph_is_cmap(glyph); })); mn.def("glyph_is_swallow", - [](int glyph) { return glyph_is_swallow(glyph); }); + py::vectorize([](int glyph) { return glyph_is_swallow(glyph); })); mn.def("glyph_is_warning", - [](int glyph) { return glyph_is_warning(glyph); }); + py::vectorize([](int glyph) { return glyph_is_warning(glyph); })); #ifdef NLE_USE_TILES mn.attr("glyph2tile") = @@ -647,14 +656,24 @@ PYBIND11_MODULE(_pynethack, m) + "' explain='" + std::string(cs.explain) + "'>"; }); - mn.def("glyph_to_mon", [](int glyph) { return glyph_to_mon(glyph); }); - mn.def("glyph_to_obj", [](int glyph) { return glyph_to_obj(glyph); }); - mn.def("glyph_to_trap", [](int glyph) { return glyph_to_trap(glyph); }); - mn.def("glyph_to_cmap", [](int glyph) { return glyph_to_cmap(glyph); }); - mn.def("glyph_to_swallow", - [](int glyph) { return glyph_to_swallow(glyph); }); - mn.def("glyph_to_warning", - [](int glyph) { return glyph_to_warning(glyph); }); + mn.def("glyph_to_mon", py::vectorize([](int glyph) -> int { + return glyph_to_mon(glyph); + })); + mn.def("glyph_to_obj", py::vectorize([](int glyph) -> int { + return glyph_to_obj(glyph); + })); + mn.def("glyph_to_trap", py::vectorize([](int glyph) -> int { + return glyph_to_trap(glyph); + })); + mn.def("glyph_to_cmap", py::vectorize([](int glyph) -> int { + return glyph_to_cmap(glyph); + })); + mn.def("glyph_to_swallow", py::vectorize([](int glyph) -> int { + return glyph_to_swallow(glyph); + })); + mn.def("glyph_to_warning", py::vectorize([](int glyph) -> int { + return glyph_to_warning(glyph); + })); py::class_( mn, "objclass",