-
Notifications
You must be signed in to change notification settings - Fork 14
/
vizdoom.py
337 lines (273 loc) · 11.3 KB
/
vizdoom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import copy
import functools
import itertools
import os
import tempfile

import gym
from gym import error, spaces
from gym.utils import seeding
import numpy as np
import omg
import vizdoom
# Directory holding the bundled ViZDoom scenario WADs and sampling lists.
ASSET_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'vizdoom')


def _read_asset_lines(filename, convert=str):
    """Read a newline-separated list from a file in ``ASSET_PATH``.

    Each line is stripped of surrounding whitespace and passed through
    *convert*. Blank lines are skipped so that a stray empty line never
    becomes an empty texture name (or an ``int('')`` crash).

    :param filename: File name relative to ``ASSET_PATH``
    :param convert: Callable applied to every non-blank, stripped line
    :return: List of converted entries in file order
    """
    # ``with`` guarantees the handle is closed (the previous bare ``open`` leaked it).
    with open(os.path.join(ASSET_PATH, filename)) as handle:
        return [convert(line.strip()) for line in handle if line.strip()]


# Texture set A.
TEXTURES_SET_A = _read_asset_lines('texture_set_a.txt')
# Texture set B.
TEXTURES_SET_B = _read_asset_lines('texture_set_b.txt')
# Thing set A (thing type ids).
THINGS_SET_A = _read_asset_lines('thing_set_a.txt', int)
# Thing set B (thing type ids).
THINGS_SET_B = _read_asset_lines('thing_set_b.txt', int)
# Map cache to avoid re-parsing the map multiple times.
# Keyed by (absolute scenario WAD path, upper-cased map name).
_MAP_CACHE = {}
def sampler_with_map_editor(sampler):
    """Wrap sampler function to provide the map editor.

    The returned ``wrapper(env, config)`` loads map ``config['map']``
    (default ``'MAP01'``) from the WAD ``config['scenario']`` under
    ``ASSET_PATH``. The parsed editor is cached in ``_MAP_CACHE`` so each map
    is parsed only once. A deep copy of the cached editor is passed to
    ``sampler(env, config, editor)`` for in-place modification, and the
    result is written to a new temporary ``.wad`` file.

    Maps must be in UDMF format.

    :param sampler: Callable ``(env, config, editor)`` that mutates *editor*
    :return: Callable ``(env, config)`` returning the path of the written WAD;
        the caller owns that file and is responsible for deleting it
    """
    @functools.wraps(sampler)
    def wrapper(env, config):
        # Load source WAD (parsed at most once per scenario/map pair).
        scenario = os.path.join(ASSET_PATH, config['scenario'])
        map_name = config.get('map', 'MAP01').upper()
        cache_key = (scenario, map_name)
        editor = _MAP_CACHE.get(cache_key)
        if editor is None:
            wad = omg.WadIO(scenario)
            editor = omg.UDMFMapEditor(wad)
            editor.load(map_name)
            _MAP_CACHE[cache_key] = editor

        # Always work on a copy so the cached map stays pristine; otherwise
        # one sample's modifications would leak into every later sample.
        editor = copy.deepcopy(editor)
        sampler(env, config, editor)

        # Create temporary WAD and write updated level there. mkstemp creates
        # the file atomically (unlike the deprecated, race-prone mktemp); the
        # descriptor is closed so the editor can write to the path itself.
        fd, updated_wad = tempfile.mkstemp(suffix='.wad')
        os.close(fd)
        try:
            editor.save(updated_wad)
        except Exception:
            # Do not leave an empty/partial temp file behind on failure.
            try:
                os.remove(updated_wad)
            except OSError:
                pass
            raise
        return updated_wad
    return wrapper
def sample_textures(textures, sides=True, sectors=True):
    """Build a map sampler that randomizes textures.

    :param textures: Sequence of texture names to draw from
    :param sides: When true, randomize every sidedef's middle (wall) texture
    :param sectors: When true, randomize every sector's floor and ceiling
        textures
    :return: The sampler wrapped by ``sampler_with_map_editor``
    """
    @sampler_with_map_editor
    def sampler(env, config, editor):
        """Assign random textures to *editor* in place.

        All draws come from ``env.np_random``.

        :param env: Environment instance (provides ``np_random``)
        :param config: Configuration dictionary (not used here)
        :param editor: Map editor to modify
        """
        def pick():
            return str(env.np_random.choice(textures))

        if sides:
            for sidedef in editor.sidedefs:
                sidedef.texturemiddle = pick()
        if sectors:
            for sec in editor.sectors:
                # One draw for the floor, then one for the ceiling.
                sec.texturefloor = pick()
                sec.textureceiling = pick()

    return sampler
def sample_things(things, modify_things):
    """Build a map sampler that replaces selected things with random ones.

    :param things: Sequence of thing type ids to draw replacements from
    :param modify_things: Collection of thing type ids eligible for
        replacement; things of any other type (e.g. the player start) are
        left untouched
    :return: The sampler wrapped by ``sampler_with_map_editor``
    """
    # Build the eligibility set once: O(1) membership per thing instead of a
    # list scan on every thing of every sample.
    modifiable = frozenset(modify_things)

    @sampler_with_map_editor
    def sampler(env, config, editor):
        """Replace eligible things on *editor* in place.

        :param env: Environment instance (provides ``np_random``)
        :param config: Configuration dictionary (not used here)
        :param editor: Map editor to modify
        """
        for thing in editor.things:
            # Skip every thing whose type is not marked modifiable.
            if thing.type not in modifiable:
                continue
            thing.type = int(env.np_random.choice(things))

    return sampler
class VizDoomEnvironment(gym.Env):
    """Gym environment running one variant of a ViZDoom scenario.

    Each entry of ``scenarios`` has a ``baseline`` variant. Every other
    variant's settings are layered on top of the baseline. Variants with a
    ``sampler`` regenerate their map WAD on every reset.

    Observations have shape ``(480, 640, 3)`` in BGR channel order
    (``ScreenFormat.BGR24``). The discrete action space enumerates every
    on/off combination of ``buttons`` that never presses both buttons of a
    pair in ``opposite_button_pairs`` (3**3 * 2**2 = 108 actions).

    Implements the legacy gym hooks ``_seed``/``_reset``/``_step``/``_close``.
    """
    metadata = {
        'render.modes': ['rgb_array'],
        'video.frames_per_second': 35,
    }

    # Scenario definitions. Within each scenario definition, configuration is inherited
    # from the baseline variant to avoid repetition.
    scenarios = {
        'basic': {
            'baseline': {
                'scenario': 'basic.wad',
                'living_reward': 1,
                'death_penalty': 0,
                'reward': 'health',
            },
            'floor_ceiling_flipped': {'scenario': 'basic_floor_ceiling_flipped.wad'},
            'torches': {'scenario': 'basic_torches.wad'},
            'random_textures_set_a': {'sampler': sample_textures(TEXTURES_SET_A)},
            'random_textures_set_b': {'sampler': sample_textures(TEXTURES_SET_B)},
            'random_things_set_a': {
                'scenario': 'basic_torches.wad',
                'sampler': sample_things(THINGS_SET_A, modify_things=[56]),
            },
            'random_things_set_b': {
                'scenario': 'basic_torches.wad',
                'sampler': sample_things(THINGS_SET_B, modify_things=[56]),
            },
        },
        'navigation': {
            'baseline': {
                'scenario': 'navigation.wad',
                'living_reward': 1,
                'death_penalty': 0,
                'reward': 'health',
            },
            'new_layout': {'scenario': 'navigation_new_layout.wad'},
            'floor_ceiling_flipped': {'scenario': 'navigation_floor_ceiling_flipped.wad'},
            'torches': {'scenario': 'navigation_torches.wad'},
            'random_textures_set_a': {'sampler': sample_textures(TEXTURES_SET_A)},
            'random_textures_set_b': {'sampler': sample_textures(TEXTURES_SET_B)},
            'random_things_set_a': {
                'scenario': 'navigation_torches.wad',
                'sampler': sample_things(THINGS_SET_A, modify_things=[56]),
            },
            'random_things_set_b': {
                'scenario': 'navigation_torches.wad',
                'sampler': sample_things(THINGS_SET_B, modify_things=[56]),
            },
        }
    }

    # Available buttons. Their order fixes the layout of every action's button list.
    buttons = [
        vizdoom.Button.MOVE_FORWARD,
        vizdoom.Button.MOVE_BACKWARD,
        vizdoom.Button.MOVE_RIGHT,
        vizdoom.Button.MOVE_LEFT,
        vizdoom.Button.TURN_LEFT,
        vizdoom.Button.TURN_RIGHT,
        vizdoom.Button.ATTACK,
        vizdoom.Button.SPEED,
    ]

    # Button pairs that are never pressed together within one action.
    opposite_button_pairs = [
        (vizdoom.Button.MOVE_FORWARD, vizdoom.Button.MOVE_BACKWARD),
        (vizdoom.Button.MOVE_RIGHT, vizdoom.Button.MOVE_LEFT),
        (vizdoom.Button.TURN_LEFT, vizdoom.Button.TURN_RIGHT),
    ]

    def __init__(self, scenario, variant, obs_type='image', frameskip=4):
        """Configure (but do not start) a ViZDoom game.

        The game itself is initialized lazily on the first reset.

        :param scenario: Key of ``scenarios`` (e.g. ``'basic'``)
        :param variant: Variant key within that scenario (e.g. ``'baseline'``)
        :param obs_type: Observation type; only ``'image'`` is supported
        :param frameskip: Passed to ``make_action`` as the number of tics each
            action is held for
        :raises gym.error.Error: On an unknown scenario, variant or obs_type
        """
        if scenario not in self.scenarios:
            raise error.Error("Unsupported scenario: {}".format(scenario))
        if variant not in self.scenarios[scenario]:
            raise error.Error("Unsupported scenario variant: {}".format(variant))

        # Generate config (extend from baseline).
        config = {}
        config.update(self.scenarios[scenario]['baseline'])
        config.update(self.scenarios[scenario][variant])
        self._config = config

        self._vizdoom = vizdoom.DoomGame()
        self._vizdoom.set_doom_scenario_path(os.path.join(ASSET_PATH, config['scenario']))
        self._vizdoom.set_doom_map(config.get('map', 'MAP01'))
        self._vizdoom.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
        self._vizdoom.set_screen_format(vizdoom.ScreenFormat.BGR24)
        self._vizdoom.set_mode(vizdoom.Mode.PLAYER)
        # Must agree with the resolution/format configured above.
        self._width = 640
        self._height = 480
        self._depth = 3

        # Entity visibility.
        self._vizdoom.set_render_hud(False)
        self._vizdoom.set_render_minimal_hud(False)
        self._vizdoom.set_render_crosshair(False)
        self._vizdoom.set_render_weapon(False)
        self._vizdoom.set_render_decals(False)
        self._vizdoom.set_render_particles(False)
        self._vizdoom.set_render_effects_sprites(False)
        self._vizdoom.set_render_messages(False)
        self._vizdoom.set_render_corpses(False)
        self._vizdoom.set_window_visible(False)
        self._vizdoom.set_sound_enabled(False)

        # Rewards.
        self._vizdoom.set_living_reward(config.get('living_reward', 1))
        self._vizdoom.set_death_penalty(config.get('death_penalty', 100))

        # Duration.
        self._vizdoom.set_episode_timeout(config.get('episode_timeout', 2100))

        # Generate action space from buttons: every on/off combination, in
        # itertools.product order, minus those pressing both buttons of an
        # opposite pair. Index pairs are resolved once rather than per combination.
        for button in self.buttons:
            self._vizdoom.add_available_button(button)
        opposite_indices = [
            (self.buttons.index(a), self.buttons.index(b))
            for a, b in self.opposite_button_pairs
        ]
        self._action_button_map = [
            list(combination)
            for combination in itertools.product([False, True], repeat=len(self.buttons))
            if not any(combination[i] and combination[j] for i, j in opposite_indices)
        ]
        self.action_space = spaces.Discrete(len(self._action_button_map))

        if obs_type == 'image':
            self.observation_space = spaces.Box(low=0, high=255, shape=(self._height, self._width, self._depth))
        else:
            raise error.Error("Unrecognized observation type: {}".format(obs_type))

        self._scenario = scenario
        self._variant = variant
        self._obs_type = obs_type
        self._frameskip = frameskip
        self._initialized = False
        # Path of the WAD produced by the last sampler run (owned by this env).
        self._temporary_scenario = None
        self._seed()

    def __getstate__(self):
        """Pickle only the constructor arguments; the game is rebuilt on load."""
        return {
            'scenario': self._scenario,
            'variant': self._variant,
            'obs_type': self._obs_type,
            'frameskip': self._frameskip,
        }

    def __setstate__(self, state):
        """Rebuild the environment from the arguments saved by ``__getstate__``."""
        self.__init__(**state)

    def _seed(self, seed=None):
        """Seed the environment RNG (used by samplers) and ViZDoom.

        :param seed: Optional seed; a random one is chosen when ``None``
        :return: ``[seed]`` as produced by ``gym.utils.seeding.np_random``
        """
        self.np_random, seed = seeding.np_random(seed)
        # Reduce to 32 bits before handing the seed to ViZDoom.
        self._vizdoom.set_seed(seed % 2**32)
        return [seed]

    def _get_observation(self):
        """Return the current screen buffer.

        When ViZDoom reports no state (e.g. once the episode has finished), a
        black frame of the observation shape is returned instead.
        """
        state = self._vizdoom.get_state()
        if self._obs_type == 'image':
            if not state:
                # uint8 matches the 8-bit-per-channel BGR24 frames, so consumers
                # never receive a float64 array on the terminal step.
                return np.zeros([self._height, self._width, self._depth], dtype=np.uint8)
            return state.screen_buffer
        raise NotImplementedError

    def _remove_temporary_scenario(self):
        """Delete the WAD produced by the previous sampler run, if any (best effort)."""
        if self._temporary_scenario:
            try:
                os.remove(self._temporary_scenario)
            except OSError:
                # Already removed or not removable; nothing else depends on it.
                pass
            self._temporary_scenario = None

    def _reset(self):
        """Start a new episode, re-sampling the scenario WAD when configured.

        :return: Initial observation
        """
        # Sample scenario when configured.
        sampler = self._config.get('sampler', None)
        if sampler:
            self._remove_temporary_scenario()
            self._temporary_scenario = sampler(self, self._config)
            # NOTE(review): on later resets this changes the path of an already
            # initialized game; presumably ViZDoom applies it on new_episode() — confirm.
            self._vizdoom.set_doom_scenario_path(self._temporary_scenario)
        if not self._initialized:
            self._vizdoom.init()
            self._initialized = True
        self._vizdoom.new_episode()
        return self._get_observation()

    def _close(self):
        """Shut down the ViZDoom game and delete the last sampled WAD.

        Tolerates a partially constructed instance (``__init__`` raised early),
        since gym may call close on such an instance.
        """
        game = getattr(self, '_vizdoom', None)
        if game is not None and getattr(self, '_initialized', False):
            game.close()
        self._initialized = False
        if getattr(self, '_temporary_scenario', None):
            self._remove_temporary_scenario()

    def _get_state_variables(self):
        """Return the game variables exposed through ``info`` and shaped rewards."""
        return {
            'health': self._vizdoom.get_game_variable(vizdoom.GameVariable.HEALTH),
            'frags': self._vizdoom.get_game_variable(vizdoom.GameVariable.FRAGCOUNT),
        }

    def _step(self, action):
        """Apply one discrete action for ``frameskip`` tics.

        :param action: Index into the discrete action space
        :return: ``(observation, reward, terminal, info)``. *info* holds the
            post-step ``health`` and ``frags``. *reward* is the ViZDoom
            scenario reward when ``config['reward']`` is ``'reward'`` (the
            default); otherwise it is the change in that named *info* variable.
        """
        previous_info = self._get_state_variables()
        pressed = self._action_button_map[action]
        scenario_reward = self._vizdoom.make_action(pressed, self._frameskip)
        terminal = self._vizdoom.is_episode_finished() or self._vizdoom.is_player_dead()
        observation = self._get_observation()
        info = self._get_state_variables()

        reward_key = self._config.get('reward', 'reward')
        if reward_key == 'reward':
            reward = scenario_reward
        else:
            # Shaped reward: change of the chosen game variable over this step.
            reward = info[reward_key] - previous_info[reward_key]
        return observation, reward, terminal, info

    def get_keys_to_action(self):
        """Keyboard mapping for interactive play (placeholder: only no-op)."""
        # TODO.
        return {
            (): 0,
        }