Skip to content

Commit

Permalink
Created function that process observations #4
Browse files Browse the repository at this point in the history
  • Loading branch information
drkostas committed Dec 12, 2022
1 parent e7f5cde commit ae2c30e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,5 @@ MalmoPlatform
# Custom
.DS_Store
tmp*
logs
logs
minex86
60 changes: 42 additions & 18 deletions demo_no_RL.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
from pathlib import Path
import random
import json

from yaml_config_wrapper import Configuration
from RLcraft import MalmoMazeEnv
Expand All @@ -22,10 +23,32 @@ def get_args():
type=int,
required=False,
default=0,
help="number of gpus to use for trianing")
help="number of gpus to use for training")
args = parser.parse_args()
return args

def process_obs(np_obs, info):
    """Merge the rendered frame with the latest telemetry reported in *info*.

    Args:
        np_obs: Frame returned by ``env.step``. It is stored untouched under
            the ``'rgb'`` key (the original author's note describes it as a
            numpy array of shape (height, width, 3)).
        info: Object exposing an ``observations`` sequence whose items carry
            a JSON document in their ``text`` attribute. Only the most recent
            item (index -1) is read.

    Returns:
        A dict with the keys ``'rgb'``, ``'floor'``, ``'time'``, ``'xpos'``,
        ``'ypos'``, ``'zpos'``, ``'yaw'`` and ``'xp'``, in that order. When
        ``info.observations`` is empty, every key except ``'rgb'`` maps to
        ``None`` so callers always receive the same dict shape.

    Raises:
        KeyError: If the latest observation lacks one of the expected fields.
        json.JSONDecodeError: If the latest observation text is not valid JSON.
    """
    # Output key -> field name inside the JSON observation.
    field_by_key = {
        'floor': 'floor10x10',
        'time': 'TotalTime',
        'xpos': 'XPos',
        'ypos': 'YPos',
        'zpos': 'ZPos',
        'yaw': 'Yaw',  # where the player is facing
        'xp': 'XP',
    }
    obs = {'rgb': np_obs}

    observations = info.observations
    if not observations:
        # No observation available for this step: indexing [-1] would raise
        # IndexError, so report the telemetry as missing instead.
        obs.update(dict.fromkeys(field_by_key))
        return obs

    # info is not a plain dict; the payload is JSON text on the latest entry.
    latest = json.loads(observations[-1].text)
    for key, field in field_by_key.items():
        obs[key] = latest[field]
    return obs

def main():
""" Run a the game with a random agent. """
Expand All @@ -37,7 +60,8 @@ def main():
c_general = c.get_config('general')[0]
c_tuner = c.get_config('tuner')[0]
# Load the values from the config
run_config = c_tuner['config']['env_config']
run_config = c_tuner['config']
env_config = run_config['env_config']
c_general = c_general['config']

run = True
Expand All @@ -46,20 +70,20 @@ def main():
print("Generating new seed ...")
maze_seed = random.randint(1, 9999)
print("Loading environment ...")
xml = Path(run_config["mission_file"]).read_text()
env = MalmoMazeEnv(
xml=xml,
width=run_config["width"],
height=run_config["height"],
millisec_per_tick=run_config["millisec_per_tick"],
mission_timeout_ms=run_config['mission_timeout_ms'],
step_reward=run_config['step_reward'],
win_reward=run_config['win_reward'],
lose_reward=run_config['lose_reward'],
action_space=run_config['action_space'],
client_port=run_config['client_port'],
time_wait=run_config['time_wait'],
max_loop=run_config['max_loop'])
width=env_config["width"],
height=env_config["height"],
mazeseed=maze_seed,
xml=env_config["mission_file"],
millisec_per_tick=env_config['millisec_per_tick'],
max_loop=c_general['max_loop'],
mission_timeout_ms=c_general['mission_timeout_ms'],
step_reward=c_general['step_reward'],
win_reward=c_general['win_reward'],
lose_reward=c_general['lose_reward'],
action_space=c_general['action_space'],
client_port=env_config['client_port'],
time_wait=c_general['time_wait'])
print("Resetting environment ...")
print(env.reset())
print("The world is loaded.")
Expand All @@ -69,10 +93,10 @@ def main():
while not done:
action = env.action_space.sample()
# Actions: 0 -> move (frwd), 1 -> right, 2 -> left
obs, reward, done, info = env.step(action)
np_obs, reward, done, info = env.step(action)
observations = process_obs(np_obs, info)
done = True
print(len(obs))
# obs is a numpy array of shape (height, width, 3)
print(observations)
env.render()
user_choice = input(
"Enter 'N' to exit, 'Y' to run new episode [Y/n]: ").lower()
Expand Down

0 comments on commit ae2c30e

Please sign in to comment.