Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Cleanup and add vectorization #301

Merged
merged 4 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 0 additions & 59 deletions nle/env/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import collections
import csv
import enum
import logging
import os
Expand Down Expand Up @@ -190,7 +189,6 @@ class StepStatus(enum.IntEnum):
def __init__(
self,
savedir=None,
archivefile=None,
character="mon-hum-neu-mal",
max_episode_steps=5000,
observation_keys=(
Expand Down Expand Up @@ -243,8 +241,6 @@ def __init__(
If set to True, do not decline menus, text input or auto 'MORE'.
If set to False, only skip click through 'MORE' on death.
"""
del archivefile # TODO: Remove once we change the API.

self.character = character
self._max_episode_steps = max_episode_steps
self._allow_all_yn_questions = allow_all_yn_questions
Expand Down Expand Up @@ -278,12 +274,6 @@ def __init__(
else:
logger.info("Not saving any NLE data.")

# TODO: Fix stats_file logic.
# self._setup_statsfile = self.savedir is not None
self._setup_statsfile = False
self._stats_file = None
self._stats_logger = None

self._observation_keys = list(observation_keys)

if "internal" in self._observation_keys:
Expand Down Expand Up @@ -404,42 +394,11 @@ def step(self, action: int):
done = True

info = {}
# TODO: fix stats
# if end_status:
# # stats = self._collect_stats(last_observation, end_status)
# # stats = stats._asdict()
# # stats = {}
# # info["stats"] = stats
#
# # if self._stats_logger is not None:
# # self._stats_logger.writerow(stats)

info["end_status"] = end_status
info["is_ascended"] = self.env.how_done() == nethack.ASCENDED

return self._get_observation(observation), reward, done, info

def _collect_stats(self, message, end_status):
"""Updates a stats dict tracking several env stats."""
# Using class rather than instance to allow tasks to reuse this with
# super()
# return NLE.Stats(
# end_status=int(end_status),
# score=_get(message, "Blstats.score", required=True),
# time=_get(message, "Blstats.time", required=True),
# steps=self._steps,
# hp=_get(message, "Blstats.hitpoints", required=True),
# exp=_get(message, "Blstats.experience_points", required=True),
# exp_lev=_get(message, "Blstats.experience_level", required=True),
# gold=_get(message, "Blstats.gold", required=True),
# hunger=_get(message, "You.uhunger", required=True),
# # killer_name=self._killer_name,
# deepest_lev=_get(message, "Internal.deepest_lev_reached", required=True),
# episode=self._episode,
# seeds=self.get_seeds(),
# ttyrec=self.env._process.filename,
# )

def _in_moveloop(self, observation):
program_state = observation[self._program_state_index]
return program_state[3] # in_moveloop
Expand All @@ -460,21 +419,6 @@ def reset(self, wizkit_items=None):
new_ttyrec = self._ttyrec_pattern % self._episode if self.savedir else None
self.last_observation = self.env.reset(new_ttyrec, wizkit_items=wizkit_items)

# Only run on the first reset to initialize stats file
if self._setup_statsfile:
filename = os.path.join(self.savedir, "stats.csv")
add_header = not os.path.exists(filename)

self._stats_file = open(filename, "a", 1) # line buffered.
self._stats_logger = csv.DictWriter(
self._stats_file, fieldnames=self.Stats._fields
)
if add_header:
self._stats_logger.writeheader()
self._setup_statsfile = False

# self._killer_name = "UNK"

self._steps = 0

for _ in range(1000):
Expand Down Expand Up @@ -610,9 +554,6 @@ def _perform_known_steps(self, observation, done, exceptions=True):
observation, done = self.env.step(ASCII_SPACE)
continue

# TODO: Think about killer_name.
# if self._killer_name == "UNK"

internal = observation[self._internal_index]
in_yn_function = internal[1]
in_getlin = internal[2]
Expand Down
8 changes: 3 additions & 5 deletions nle/env/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,9 @@ def _is_episode_end(self, observation):
blstats = observation[self._blstats_index]
x, y = blstats[:2]

neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2].reshape(-1).tolist()
# TODO: vectorize
for glyph in neighbors:
if nethack.glyph_is_pet(glyph):
return self.StepStatus.TASK_SUCCESSFUL
neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2]
if np.any(nethack.glyph_is_pet(neighbors)):
return self.StepStatus.TASK_SUCCESSFUL
return self.StepStatus.RUNNING


Expand Down
172 changes: 172 additions & 0 deletions nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,178 @@ def test_glyph2tile(self):
assert nethack.glyph2tile[nethack.GLYPH_PET_OFF] == 0
assert nethack.glyph2tile[nethack.GLYPH_DETECT_OFF] == 0

def test_glyph_is(self):
assert nethack.glyph_is_monster(nethack.GLYPH_MON_OFF)
assert nethack.glyph_is_pet(nethack.GLYPH_PET_OFF)
assert nethack.glyph_is_invisible(nethack.GLYPH_INVIS_OFF)
assert nethack.glyph_is_detected_monster(nethack.GLYPH_DETECT_OFF)
assert nethack.glyph_is_body(nethack.GLYPH_BODY_OFF)
assert nethack.glyph_is_ridden_monster(nethack.GLYPH_RIDDEN_OFF)
assert nethack.glyph_is_object(nethack.GLYPH_OBJ_OFF)
assert nethack.glyph_is_cmap(nethack.GLYPH_CMAP_OFF)
# No glyph_is_explode, glyph_is_zap in NH.
assert nethack.glyph_is_swallow(nethack.GLYPH_SWALLOW_OFF)
assert nethack.glyph_is_warning(nethack.GLYPH_WARNING_OFF)
assert nethack.glyph_is_statue(nethack.GLYPH_STATUE_OFF)

vec = np.array(
[
nethack.GLYPH_MON_OFF,
nethack.GLYPH_PET_OFF,
nethack.GLYPH_INVIS_OFF,
nethack.GLYPH_DETECT_OFF,
nethack.GLYPH_BODY_OFF,
nethack.GLYPH_RIDDEN_OFF,
nethack.GLYPH_OBJ_OFF,
nethack.GLYPH_CMAP_OFF,
nethack.GLYPH_EXPLODE_OFF,
nethack.GLYPH_ZAP_OFF,
nethack.GLYPH_SWALLOW_OFF,
nethack.GLYPH_WARNING_OFF,
nethack.GLYPH_STATUE_OFF,
],
dtype=np.int32,
)
np.testing.assert_array_equal(
nethack.glyph_is_monster(vec),
np.isin(
vec,
[
nethack.GLYPH_MON_OFF,
nethack.GLYPH_PET_OFF,
nethack.GLYPH_DETECT_OFF,
nethack.GLYPH_RIDDEN_OFF,
],
),
)
np.testing.assert_array_equal(
nethack.glyph_is_pet(vec),
np.isin(vec, [nethack.GLYPH_PET_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_invisible(vec),
np.isin(vec, [nethack.GLYPH_INVIS_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_normal_object(vec),
np.isin(vec, [nethack.GLYPH_OBJ_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_detected_monster(vec),
np.isin(vec, [nethack.GLYPH_DETECT_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_body(vec),
np.isin(vec, [nethack.GLYPH_BODY_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_ridden_monster(vec),
np.isin(vec, [nethack.GLYPH_RIDDEN_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_object(vec),
np.isin(
vec,
[
nethack.GLYPH_BODY_OFF,
nethack.GLYPH_OBJ_OFF,
nethack.GLYPH_STATUE_OFF,
],
),
)
assert np.all(nethack.glyph_is_trap(vec) == 0)
for idx in range(nethack.MAXPCHARS): # Find an actual trap.
if "trap" in nethack.symdef.from_idx(idx).explanation:
assert nethack.glyph_is_trap(nethack.GLYPH_CMAP_OFF + idx)
break
np.testing.assert_array_equal( # Explosions are cmaps?
nethack.glyph_is_cmap(vec),
np.isin(vec, [nethack.GLYPH_CMAP_OFF, nethack.GLYPH_EXPLODE_OFF]),
)
# No glyph_is_explode, glyph_is_zap in NH.
np.testing.assert_array_equal(
nethack.glyph_is_swallow(vec),
np.isin(vec, [nethack.GLYPH_SWALLOW_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_warning(vec),
np.isin(vec, [nethack.GLYPH_WARNING_OFF]),
)
np.testing.assert_array_equal(
nethack.glyph_is_statue(vec),
np.isin(vec, [nethack.GLYPH_STATUE_OFF]),
)

# Test some non-offset value too.
assert nethack.glyph_is_warning(
(nethack.GLYPH_WARNING_OFF + nethack.GLYPH_STATUE_OFF) // 2
)

def test_glyph_to(self):
assert np.all(
nethack.glyph_to_mon(
np.array(
[
nethack.GLYPH_MON_OFF,
nethack.GLYPH_PET_OFF,
nethack.GLYPH_DETECT_OFF,
nethack.GLYPH_RIDDEN_OFF,
nethack.GLYPH_STATUE_OFF,
]
)
)
== 0
)

# STATUE and CORPSE from onames.h (generated by makedefs).
# Returned by glyph_to_obj.
corpse = get_object("corpse").oc_name_idx
statue = get_object("statue").oc_name_idx
np.testing.assert_array_equal(
nethack.glyph_to_obj(
np.array(
[
nethack.GLYPH_BODY_OFF,
nethack.GLYPH_STATUE_OFF,
nethack.GLYPH_OBJ_OFF,
]
)
),
np.array([corpse, statue, 0]),
)

for idx in range(nethack.MAXPCHARS): # Find the arrow trap.
if nethack.symdef.from_idx(idx).explanation == "arrow trap":
np.testing.assert_array_equal(
nethack.glyph_to_trap(
np.array([nethack.GLYPH_CMAP_OFF, nethack.GLYPH_CMAP_OFF + idx])
),
# Traps are one-indexed in defsym_to_trap as per rm.h.
np.array([nethack.NO_GLYPH, 1]),
)
break

np.testing.assert_array_equal(
nethack.glyph_to_cmap(
np.array(
[
nethack.GLYPH_CMAP_OFF,
nethack.GLYPH_STATUE_OFF,
]
)
),
np.array([0, nethack.NO_GLYPH]),
)

assert nethack.glyph_to_swallow(nethack.GLYPH_SWALLOW_OFF) == 0

np.testing.assert_array_equal(
nethack.glyph_to_warning(
np.arange(nethack.GLYPH_WARNING_OFF, nethack.GLYPH_STATUE_OFF)
),
np.arange(nethack.WARNCOUNT),
)


class TestNethackGlanceObservation:
@pytest.fixture
Expand Down
75 changes: 47 additions & 28 deletions win/rl/pynethack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,31 +535,40 @@ PYBIND11_MODULE(_pynethack, m)
py::int_(MG_OBJPILE); /* more than one stack of objects */
mn.attr("MG_BW_LAVA") = py::int_(MG_BW_LAVA); /* 'black & white lava' */

// Expose macros as Python functions.
// Expose macros as Python functions, with optional vectorization.
mn.def("glyph_is_monster",
[](int glyph) { return glyph_is_monster(glyph); });
mn.def("glyph_is_normal_monster",
[](int glyph) { return glyph_is_normal_monster(glyph); });
mn.def("glyph_is_pet", [](int glyph) { return glyph_is_pet(glyph); });
mn.def("glyph_is_body", [](int glyph) { return glyph_is_body(glyph); });
py::vectorize([](int glyph) { return glyph_is_monster(glyph); }));
mn.def("glyph_is_normal_monster", py::vectorize([](int glyph) {
return glyph_is_normal_monster(glyph);
}));
mn.def("glyph_is_pet",
py::vectorize([](int glyph) { return glyph_is_pet(glyph); }));
mn.def("glyph_is_body",
py::vectorize([](int glyph) { return glyph_is_body(glyph); }));
mn.def("glyph_is_statue",
[](int glyph) { return glyph_is_statue(glyph); });
mn.def("glyph_is_ridden_monster",
[](int glyph) { return glyph_is_ridden_monster(glyph); });
mn.def("glyph_is_detected_monster",
[](int glyph) { return glyph_is_detected_monster(glyph); });
mn.def("glyph_is_invisible",
[](int glyph) { return glyph_is_invisible(glyph); });
mn.def("glyph_is_normal_object",
[](int glyph) { return glyph_is_normal_object(glyph); });
py::vectorize([](int glyph) { return glyph_is_statue(glyph); }));
mn.def("glyph_is_ridden_monster", py::vectorize([](int glyph) {
return glyph_is_ridden_monster(glyph);
}));
mn.def("glyph_is_detected_monster", py::vectorize([](int glyph) {
return glyph_is_detected_monster(glyph);
}));
mn.def("glyph_is_invisible", py::vectorize([](int glyph) {
return glyph_is_invisible(glyph);
}));
mn.def("glyph_is_normal_object", py::vectorize([](int glyph) {
return glyph_is_normal_object(glyph);
}));
mn.def("glyph_is_object",
[](int glyph) { return glyph_is_object(glyph); });
mn.def("glyph_is_trap", [](int glyph) { return glyph_is_trap(glyph); });
mn.def("glyph_is_cmap", [](int glyph) { return glyph_is_cmap(glyph); });
py::vectorize([](int glyph) { return glyph_is_object(glyph); }));
mn.def("glyph_is_trap",
py::vectorize([](int glyph) { return glyph_is_trap(glyph); }));
mn.def("glyph_is_cmap",
py::vectorize([](int glyph) { return glyph_is_cmap(glyph); }));
mn.def("glyph_is_swallow",
[](int glyph) { return glyph_is_swallow(glyph); });
py::vectorize([](int glyph) { return glyph_is_swallow(glyph); }));
mn.def("glyph_is_warning",
[](int glyph) { return glyph_is_warning(glyph); });
py::vectorize([](int glyph) { return glyph_is_warning(glyph); }));

#ifdef NLE_USE_TILES
mn.attr("glyph2tile") =
Expand Down Expand Up @@ -647,14 +656,24 @@ PYBIND11_MODULE(_pynethack, m)
+ "' explain='" + std::string(cs.explain) + "'>";
});

mn.def("glyph_to_mon", [](int glyph) { return glyph_to_mon(glyph); });
mn.def("glyph_to_obj", [](int glyph) { return glyph_to_obj(glyph); });
mn.def("glyph_to_trap", [](int glyph) { return glyph_to_trap(glyph); });
mn.def("glyph_to_cmap", [](int glyph) { return glyph_to_cmap(glyph); });
mn.def("glyph_to_swallow",
[](int glyph) { return glyph_to_swallow(glyph); });
mn.def("glyph_to_warning",
[](int glyph) { return glyph_to_warning(glyph); });
mn.def("glyph_to_mon", py::vectorize([](int glyph) -> int {
return glyph_to_mon(glyph);
}));
mn.def("glyph_to_obj", py::vectorize([](int glyph) -> int {
return glyph_to_obj(glyph);
}));
mn.def("glyph_to_trap", py::vectorize([](int glyph) -> int {
return glyph_to_trap(glyph);
}));
mn.def("glyph_to_cmap", py::vectorize([](int glyph) -> int {
return glyph_to_cmap(glyph);
}));
mn.def("glyph_to_swallow", py::vectorize([](int glyph) -> int {
return glyph_to_swallow(glyph);
}));
mn.def("glyph_to_warning", py::vectorize([](int glyph) -> int {
return glyph_to_warning(glyph);
}));

py::class_<objclass>(
mn, "objclass",
Expand Down