import json
import pdb
from os.path import join
import trajectory.utils as utils
import trajectory.datasets as datasets
from trajectory.search import (
    beam_plan,
    make_prefix,
    extract_actions,
    update_context,
)

class Parser(utils.Parser):
    dataset: str = 'halfcheetah-medium-expert-v2'
    config: str = 'config.offline'
#######################
######## setup ########
#######################
args = Parser().parse_args('plan')
#######################
####### models ########
#######################
dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath,
        'data_config.pkl')

gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath,
        epoch=args.gpt_epoch, device=args.device)
#######################
####### dataset #######
#######################
env = datasets.load_environment(args.dataset)
renderer = utils.make_renderer(args)
timer = utils.timer.Timer()
discretizer = dataset.discretizer
discount = dataset.discount
observation_dim = dataset.observation_dim
action_dim = dataset.action_dim
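
## value function used to score candidate sequences during beam search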
value_fn = lambda x: discretizer.value_fn(x, args.percentile)
preprocess_fn = datasets.get_preprocess_fn(env.name)
#######################
###### main loop ######
#######################
observation = env.reset()
total_reward = 0
## observations for rendering
rollout = [observation.copy()]
## previous (tokenized) transitions for conditioning transformer
context = []
T = env.max_episode_steps
for t in range(T):

    observation = preprocess_fn(observation)

    if t % args.plan_freq == 0:
        ## concatenate previous transitions and current observations to input to model
        prefix = make_prefix(discretizer, context, observation, args.prefix_context)

        ## sample sequence from model beginning with `prefix`
        sequence = beam_plan(
            gpt, value_fn, prefix,
            args.horizon, args.beam_width, args.n_expand, observation_dim, action_dim,
            discount, args.max_context_transitions, verbose=args.verbose,
            k_obs=args.k_obs, k_act=args.k_act, cdf_obs=args.cdf_obs, cdf_act=args.cdf_act,
        )

    else:
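        ## not a replanning step: reuse the remainder of the previous plan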
        sequence = sequence[1:]

    ## [ horizon x transition_dim ] convert sampled tokens to continuous trajectory
    sequence_recon = discretizer.reconstruct(sequence)

    ## [ action_dim ] index into sampled trajectory to grab first action
    action = extract_actions(sequence_recon, observation_dim, action_dim, t=0)

    ## execute action in environment
    next_observation, reward, terminal, _ = env.step(action)

    ## update return
    total_reward += reward
    score = env.get_normalized_score(total_reward)

    ## update rollout observations and context transitions
    rollout.append(next_observation.copy())
    context = update_context(context, discretizer, observation, action, reward, args.max_context_transitions)

    print(
        f'[ plan ] t: {t} / {T} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f} | '
        f'time: {timer():.2f} | {args.dataset} | {args.exp_name} | {args.suffix}\n'
    )

    ## visualization
    if t % args.vis_freq == 0 or terminal or t == T:

        ## save current plan
        renderer.render_plan(join(args.savepath, f'{t}_plan.mp4'), sequence_recon, env.state_vector())

        ## save rollout thus far
        renderer.render_rollout(join(args.savepath, f'rollout.mp4'), rollout, fps=80)

    if terminal: break

    observation = next_observation

## save result as a json file
json_path = join(args.savepath, 'rollout.json')
json_data = {'score': score, 'step': t, 'return': total_reward, 'term': terminal, 'gpt_epoch': gpt_epoch}
json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True)