-
Notifications
You must be signed in to change notification settings - Fork 143
/
Copy pathrun_agent.py
35 lines (25 loc) · 1.25 KB
/
run_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from argparse import ArgumentParser
import pickle
from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
from agent import MineRLAgent, ENV_KWARGS
def main(model, weights):
    """Load a pickled model description plus weights and run the agent forever.

    Builds the ``HumanSurvival`` environment from ``ENV_KWARGS``, reads the
    policy and pi-head keyword arguments out of the pickled ``model`` file,
    constructs a ``MineRLAgent`` from them, loads ``weights`` into it, and
    then steps and renders the environment in an endless loop.

    Args:
        model: Path to the pickled model file. It is unpickled and must
            contain ``["model"]["args"]["net"]["args"]`` and
            ``["model"]["args"]["pi_head_opts"]`` (with a ``"temperature"``
            entry).
        weights: Path handed to ``agent.load_weights``.

    The function does not return normally; it only ends via an exception
    (for example ``KeyboardInterrupt``), at which point the environment is
    closed.
    """
    env = HumanSurvival(**ENV_KWARGS).make()
    print("---Loading model---")
    # SECURITY: unpickling can execute arbitrary code; only load model files
    # that come from a trusted source.
    # The context manager closes the file handle as soon as loading is done.
    with open(model, "rb") as model_file:
        agent_parameters = pickle.load(model_file)
    model_args = agent_parameters["model"]["args"]
    policy_kwargs = model_args["net"]["args"]
    pi_head_kwargs = model_args["pi_head_opts"]
    # Force the temperature to a plain float, whatever type was pickled.
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    agent = MineRLAgent(env, policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs)
    agent.load_weights(weights)

    print("---Launching MineRL enviroment (be patient)---")
    try:
        obs = env.reset()
        while True:
            minerl_action = agent.get_action(obs)
            # NOTE(review): the episode-termination flag is ignored, so the
            # loop keeps stepping after an episode ends -- confirm whether a
            # reset is wanted here.
            obs, _reward, _done, _info = env.step(minerl_action)
            env.render()
    finally:
        # The loop only exits through an exception or Ctrl-C; release the
        # environment's resources instead of leaving them behind.
        env.close()
if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory and are forwarded
    # to main() unchanged.
    # The text is passed as ``description``; ArgumentParser's first positional
    # parameter is ``prog``, which would put the sentence in the usage line as
    # the program name.
    parser = ArgumentParser(description="Run pretrained models on MineRL environment")
    parser.add_argument("--weights", type=str, required=True, help="Path to the '.weights' file to be loaded.")
    parser.add_argument("--model", type=str, required=True, help="Path to the '.model' file to be loaded.")
    args = parser.parse_args()
    main(args.model, args.weights)