train.py
'''
@Author: Zhao
@Date: 2021.05.29
@Description: train an A2C agent with the help of experts
'''
import gym
import gym_sokoban
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os
import copy
import random
import pickle
import time
from utilities.channelConverter import hwc2chw
from common.ActorCritic import ActorCritic
from common.RolloutStorage import RolloutStorage
from common.cpu_or_gpu import cpu_or_gpu
from common.multiprocessing_env import SubprocVecEnv
from common.train_the_agent import train_the_agent
from expert import Expert


def train(args, wandb_session):
    ##################### logger, printer and other initial settings ################
    #env_seed = [random.randint(0, 100) for i in range(args.num_envs)]
    def make_env():
        # factory returning a thunk, as required by SubprocVecEnv
        def _thunk():
            #env = gym.make('Curriculum-Sokoban-v2', data_path = args.map_file, seed = i)
            env = gym.make('Curriculum-Sokoban-v2', data_path=args.map_file)
            return env
        return _thunk

    ##################### initialize agent, optimizer, etc. #########################
    # env choice: either the normal one or the one with early termination
    #env_list = [make_env(i) for i in env_seed]
    env_list = [make_env() for i in range(args.num_envs)]
    envs = SubprocVecEnv(env_list)

    state_shape = (3, 80, 80)
    # The full action space has 9 actions, but the pushing action space is much smaller
    # than the moving one, so we use the pushing action space: 4 + 1 = 5.
    #num_actions = envs.action_space.n
    num_actions = 5
    actor_critic = ActorCritic(state_shape, num_actions=num_actions)
    expert_actor_critic = ActorCritic(state_shape, num_actions=num_actions)  # model for loading the pre-trained expert that provides q-values
    rollout = RolloutStorage(args.rolloutStorage_size, args.num_envs, state_shape)
    optimizer = optim.RMSprop(actor_critic.parameters(), lr=args.lr, eps=args.eps, alpha=args.alpha)

    # according to the args, decide where the models live: GPU or CPU
    Variable, actor_critic, rollout = cpu_or_gpu(args.GPU, actor_critic, rollout)

    ##################### load existing model #####################
    if args.mode == 'scratch':
        from common.train_the_agent_scratch import train_the_agent
        train_the_agent(envs, args.num_envs, Variable, state_shape, actor_critic, optimizer, rollout, args, wandb_session)  # train and save the model
    else:
        # initialize the expert
        expert = Expert(mode=args.mode, pre_trained_path=args.pre_trained_path, expert_model=expert_actor_critic)
        from common.train_the_agent import train_the_agent
        train_the_agent(expert, envs, args.num_envs, Variable, state_shape, actor_critic, optimizer, rollout, args, wandb_session)  # train and save the model
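

# ---------------------------------------------------------------------------
# Hypothetical entry point (illustrative sketch, not part of the original
# script). The flag names below mirror the attributes that train() reads from
# `args` (map_file, num_envs, rolloutStorage_size, lr, eps, alpha, GPU, mode,
# pre_trained_path); the default values and the wandb project name are
# assumptions, and train_the_agent() may read additional fields not listed here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import wandb

    parser = argparse.ArgumentParser(description='Train an A2C agent with the help of experts')
    parser.add_argument('--map_file', type=str, required=True, help='path to the Sokoban curriculum maps')
    parser.add_argument('--num_envs', type=int, default=16, help='number of parallel environments')
    parser.add_argument('--rolloutStorage_size', type=int, default=5, help='rollout length per update')
    parser.add_argument('--lr', type=float, default=7e-4, help='RMSprop learning rate')
    parser.add_argument('--eps', type=float, default=1e-5, help='RMSprop epsilon')
    parser.add_argument('--alpha', type=float, default=0.99, help='RMSprop alpha')
    parser.add_argument('--GPU', action='store_true', help='run the models on GPU')
    parser.add_argument('--mode', type=str, default='scratch', help="'scratch' or an expert-assisted mode")
    parser.add_argument('--pre_trained_path', type=str, default=None, help='path to the pre-trained expert model')
    cli_args = parser.parse_args()

    wandb_session = wandb.init(project='PbRSS', config=vars(cli_args))
    train(cli_args, wandb_session)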