-
Notifications
You must be signed in to change notification settings - Fork 167
/
finetune.py
412 lines (355 loc) · 12.3 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
import datetime
from functools import partial
import os
from absl import app, flags, logging
import flax
from flax.traverse_util import flatten_dict
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from ml_collections import config_flags, ConfigDict
import optax
import tensorflow as tf
import tqdm
import wandb
from octo.data.dataset import make_single_dataset
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_callbacks import (
RolloutVisualizationCallback,
SaveCallback,
ValidationCallback,
VisualizationCallback,
)
from octo.utils.train_utils import (
check_config_diff,
create_optimizer,
format_name_with_config,
merge_params,
process_text,
Timer,
TrainState,
)
# Optional dependency: if the `jax_smi` package is importable, start its
# tracking as soon as this module is loaded. An ImportError (package not
# installed) is deliberately ignored so the script runs without it.
try:
    from jax_smi import initialise_tracking  # type: ignore

    initialise_tracking()
except ImportError:
    pass
# Command-line interface of the script, built on absl flags.
FLAGS = flags.FLAGS

flags.DEFINE_string(name="name", default="experiment", help="Experiment name.")
flags.DEFINE_bool(
    name="debug", default=False, help="Debug config (no wandb logging)"
)

# The default hyperparameter file is resolved relative to this script's own
# directory, so it does not depend on the current working directory.
_script_dir = os.path.dirname(__file__)
default_config_file = os.path.join(_script_dir, "configs/finetune_config.py")

# lock_config=False leaves the loaded config unlocked; main() assigns to
# FLAGS.config.optimizer.frozen_keys after the flags have been parsed.
config_flags.DEFINE_config_file(
    name="config",
    default=default_config_file,
    help_string="File path to the training hyperparameter configuration.",
    lock_config=False,
)
def main(_):
    """Finetune a pretrained Octo model on a single dataset.

    The steps, in order:

    1. Build a 1D device mesh for data parallelism and hide GPUs from
       TensorFlow (TensorFlow is only used for data loading here).
    2. Initialise Weights & Biases logging.
    3. Load the pretrained model and derive the finetuning model config from
       its config (delete ``config_delete_keys``, apply ``update_config``).
    4. Build the training data iterator and draw one example batch.
    5. Create a model from the derived config and merge the pretrained
       parameters into it.
    6. Create the optimizer and the ``TrainState``.
    7. Set up checkpoint saving, validation, visualization and (optionally)
       rollout callbacks.
    8. Run the training loop with periodic logging, evaluation and saving.

    All hyperparameters are read from ``FLAGS.config``; ``FLAGS.name`` and
    ``FLAGS.debug`` control the run name and whether wandb is disabled.

    Args:
        _: Unused argument supplied by ``app.run``.

    Raises:
        AssertionError: If the train batch size or the visualization eval
            batch size is not divisible by the number of devices.
    """
    initialize_compilation_cache()
    devices = jax.devices()
    # NOTE: the message body is kept flush-left so the emitted log text is
    # exactly the text of the f-string below.
    logging.info(
        f"""
Octo Finetuning Script
======================
Pretrained model: {FLAGS.config.pretrained_path}
Finetuning Dataset: {FLAGS.config.dataset_kwargs.name}
Data dir: {FLAGS.config.dataset_kwargs.data_dir}
Task Modality: {FLAGS.config.modality}
Finetuning Mode: {FLAGS.config.finetuning_mode}
# Devices: {jax.device_count()}
Batch size: {FLAGS.config.batch_size} ({FLAGS.config.batch_size // len(devices) } per device)
# Steps: {FLAGS.config.num_steps}
"""
    )

    #########
    #
    # Setup Jax Data Parallelism
    #
    #########

    assert (
        FLAGS.config.batch_size % len(devices) == 0
    ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
    assert (
        FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0
    ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"

    # create a 1D mesh with a single axis named "batch"
    # (built from the same device list the divisibility checks above used)
    mesh = Mesh(devices, axis_names="batch")
    # Our batches will be data-parallel sharded -- each device will get a slice of the batch
    dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
    # Our model will be replicated across devices (we are only doing data parallelism, not model parallelism)
    replicated_sharding = NamedSharding(mesh, PartitionSpec())

    # prevent tensorflow from using GPU memory since it's only used for data loading
    tf.config.set_visible_devices([], "GPU")

    #########
    #
    # Setup WandB
    #
    #########

    name = format_name_with_config(
        FLAGS.name,
        FLAGS.config.to_dict(),
    )
    # The run id is the formatted name plus a wall-clock timestamp; it is also
    # used below as the last component of the checkpoint directory.
    wandb_id = "{name}_{time}".format(
        name=name,
        time=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
    )
    wandb.init(
        config=FLAGS.config.to_dict(),
        id=wandb_id,
        name=name,
        mode="disabled" if FLAGS.debug else None,
        **FLAGS.config.wandb,
    )

    #########
    #
    # Load Pretrained model + optionally modify config
    #
    #########

    pretrained_model = OctoModel.load_pretrained(
        FLAGS.config.pretrained_path,
        step=FLAGS.config.pretrained_step,
    )
    # Work on a flattened view of the pretrained config: every key is a tuple
    # of path components, which makes sub-tree deletion a prefix test.
    flat_config = flax.traverse_util.flatten_dict(
        pretrained_model.config, keep_empty_nodes=True
    )
    delete_prefixes = flax.traverse_util.flatten_dict(
        FLAGS.config.get("config_delete_keys", ConfigDict()).to_dict()
    )
    for d_key in delete_prefixes:
        for c_key in list(flat_config.keys()):
            # Compare whole path components rather than the dot-joined
            # string, so that deleting ("a", "b") removes "a.b" and
            # everything below it but leaves a sibling such as "a.bc" alone.
            if c_key[: len(d_key)] == d_key:
                del flat_config[c_key]

    config = ConfigDict(flax.traverse_util.unflatten_dict(flat_config))
    config.update(FLAGS.config.get("update_config", ConfigDict()))
    config = config.to_dict()
    check_config_diff(config, pretrained_model.config)

    #########
    #
    # Setup Data Loader
    #
    #########

    # create text processor
    if config["text_processor"] is None:
        text_processor = None
    else:
        text_processor = ModuleSpec.instantiate(config["text_processor"])()

    def process_batch(batch):
        """Run ``process_text`` on the batch and drop its "dataset_name" entry."""
        batch = process_text(batch, text_processor)
        del batch["dataset_name"]
        return batch

    dataset = make_single_dataset(
        FLAGS.config.dataset_kwargs,
        traj_transform_kwargs=FLAGS.config.traj_transform_kwargs,
        frame_transform_kwargs=FLAGS.config.frame_transform_kwargs,
        train=True,
    )
    train_data_iter = (
        dataset.repeat()
        .unbatch()
        .shuffle(FLAGS.config.shuffle_buffer_size)
        .batch(FLAGS.config.batch_size)
        .iterator()
    )
    train_data_iter = map(process_batch, train_data_iter)
    # The first batch is consumed here; it is used to initialise the model
    # and to record the batch spec in wandb, and is not fed to train_step.
    example_batch = next(train_data_iter)

    #########
    #
    # Load Pretrained Model
    #
    #########

    rng = jax.random.PRNGKey(FLAGS.config.seed)
    rng, init_rng = jax.random.split(rng)
    model = OctoModel.from_config(
        config,
        example_batch,
        text_processor,
        rng=init_rng,
        dataset_statistics=dataset.dataset_statistics,
    )
    # Combine the freshly initialised parameters with the pretrained ones
    # (exact merge rules live in merge_params), then release the pretrained
    # model object.
    merged_params = merge_params(model.params, pretrained_model.params)
    model = model.replace(params=merged_params)
    del pretrained_model

    #########
    #
    # Setup Optimizer and Train State
    #
    #########

    params = model.params
    # Fall back to the frozen keys stored in the model config when the
    # finetuning config does not specify any.
    if FLAGS.config.optimizer.frozen_keys is None:
        FLAGS.config.optimizer.frozen_keys = model.config["optimizer"]["frozen_keys"]

    tx, lr_callable, param_norm_callable = create_optimizer(
        params,
        **FLAGS.config.optimizer.to_dict(),
    )
    train_state = TrainState.create(
        model=model,
        tx=tx,
        rng=rng,
    )

    #########
    #
    # Save all metadata
    #
    #########

    if FLAGS.config.save_dir is not None:
        save_dir = tf.io.gfile.join(
            FLAGS.config.save_dir,
            FLAGS.config.wandb.project,
            FLAGS.config.wandb.group or "",
            wandb_id,
        )
        wandb.config.update(dict(save_dir=save_dir), allow_val_change=True)
        logging.info("Saving to %s", save_dir)
        save_callback = SaveCallback(save_dir)

        # Add window_size to top of config, to make eval easier
        # NOTE(review): `model` is rebound here after `train_state` was
        # already created from the previous model object, while checkpoints
        # are written from `train_state` -- confirm that window_size actually
        # reaches the saved config.
        new_config = ConfigDict(model.config)
        new_config["window_size"] = example_batch["observation"][
            "timestep_pad_mask"
        ].shape[1]
        model = model.replace(config=new_config)

        # Save finetuning config since it's not saved by SaveCallback, i.e. as part of model.save_pretrained()
        with tf.io.gfile.GFile(
            tf.io.gfile.join(save_dir, "finetune_config.json"), "w"
        ) as config_file:
            config_file.write(FLAGS.config.to_json_best_effort())
    else:
        save_dir = None
        save_callback = SaveCallback(None)
        logging.warning("save_dir not passed in, not saving checkpoints")

    # Record (shape, dtype) of every leaf of the example batch in wandb.
    # jax.tree_util.tree_map is used instead of the deprecated top-level
    # jax.tree_map alias, which newer JAX releases no longer provide.
    example_batch_spec = jax.tree_util.tree_map(
        lambda arr: (arr.shape, str(arr.dtype)), example_batch
    )
    wandb.config.update(
        dict(example_batch_spec=example_batch_spec), allow_val_change=True
    )

    #########
    #
    # Define loss, train_step, and eval_step
    #
    #########

    def loss_fn(params, batch, rng, train=True):
        """Compute the action-head loss for one batch.

        Binds ``model.module`` to ``params`` (with ``rng`` as the "dropout"
        rng), runs the transformer on the batch's observation and task, and
        evaluates the "action" head's loss.

        Returns:
            The ``(action_loss, action_metrics)`` pair produced by the action
            head's ``loss`` method.
        """
        bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
        transformer_embeddings = bound_module.octo_transformer(
            batch["observation"],
            batch["task"],
            batch["observation"]["timestep_pad_mask"],
            train=train,
        )
        action_loss, action_metrics = bound_module.heads["action"].loss(
            transformer_embeddings,  # action head knows to pull out the "action" readout_key
            batch["action"],
            batch["observation"]["timestep_pad_mask"],
            batch["action_pad_mask"],
            train=train,
        )
        return action_loss, action_metrics

    # Data parallelism
    # Model is replicated across devices, data is split across devices
    @partial(
        jax.jit,
        in_shardings=[replicated_sharding, dp_sharding],
    )
    def train_step(state: TrainState, batch):
        """Run one optimisation step and return ``(new_state, info)``.

        ``info`` holds the metrics returned by ``loss_fn`` plus the gradient
        norm, update norm, parameter norm and current learning rate.
        """
        rng, dropout_rng = jax.random.split(state.rng)
        (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            state.model.params, batch, dropout_rng, train=True
        )
        grad_norm = optax.global_norm(grads)
        # These updates are only used to report their norm; the parameters
        # themselves are updated by state.apply_gradients below.
        updates, _ = state.tx.update(grads, state.opt_state, state.model.params)
        update_norm = optax.global_norm(updates)
        info.update(
            {
                "grad_norm": grad_norm,
                "update_norm": update_norm,
                "param_norm": param_norm_callable(state.model.params),
                "learning_rate": lr_callable(state.step),
            }
        )
        new_state = state.apply_gradients(grads=grads, rng=rng)
        return new_state, info

    #########
    #
    # Build validation & visualization callbacks
    #
    #########

    # Map the configured task modality to the conditioning modes to evaluate;
    # any unrecognised modality falls back to "base".
    if FLAGS.config.modality == "image_conditioned":
        modes_to_evaluate = ["image_conditioned"]
    elif FLAGS.config.modality == "text_conditioned":
        modes_to_evaluate = ["text_conditioned"]
    elif FLAGS.config.modality == "multimodal":
        modes_to_evaluate = ["image_conditioned", "text_conditioned"]
    else:
        modes_to_evaluate = ["base"]

    dataset_kwargs_list = [FLAGS.config.dataset_kwargs]

    val_callback = ValidationCallback(
        loss_fn=loss_fn,
        process_batch_fn=process_batch,
        text_processor=text_processor,
        val_dataset_kwargs_list=dataset_kwargs_list,
        dataset_kwargs=FLAGS.config,
        modes_to_evaluate=modes_to_evaluate,
        **FLAGS.config.val_kwargs,
    )

    viz_callback = VisualizationCallback(
        text_processor=text_processor,
        val_dataset_kwargs_list=dataset_kwargs_list,
        dataset_kwargs=FLAGS.config,
        modes_to_evaluate=modes_to_evaluate,
        **FLAGS.config.viz_kwargs,
    )

    #########
    #
    # Optionally build visualizers for sim env evals
    #
    #########

    if "rollout_kwargs" in FLAGS.config:
        rollout_callback = RolloutVisualizationCallback(
            text_processor=text_processor,
            unnormalization_statistics=dataset.dataset_statistics["action"],
            **FLAGS.config.rollout_kwargs.to_dict(),
        )
    else:
        rollout_callback = None

    #########
    #
    # Train loop
    #
    #########

    def wandb_log(info, step):
        """Flatten the nested ``info`` dict with "/" separators and log it."""
        wandb.log(flatten_dict(info, sep="/"), step=step)

    timer = Timer()
    for i in tqdm.tqdm(
        range(0, int(FLAGS.config.num_steps)),
        total=int(FLAGS.config.num_steps),
        dynamic_ncols=True,
    ):
        timer.tick("total")

        with timer("dataset"):
            batch = next(train_data_iter)

        with timer("train"):
            train_state, update_info = train_step(train_state, batch)

        timer.tock("total")

        # NOTE(review): metrics are logged to wandb at step `i`, while the
        # callbacks below receive `i + 1` -- confirm this offset is intended.
        if (i + 1) % FLAGS.config.log_interval == 0:
            update_info = jax.device_get(update_info)
            wandb_log(
                {"training": update_info, "timer": timer.get_average_times()}, step=i
            )

        if (i + 1) % FLAGS.config.eval_interval == 0:
            logging.info("Evaluating...")

            with timer("val"):
                val_metrics = val_callback(train_state, i + 1)
                wandb_log(val_metrics, step=i)

            with timer("visualize"):
                viz_metrics = viz_callback(train_state, i + 1)
                wandb_log(viz_metrics, step=i)

            if rollout_callback is not None:
                with timer("rollout"):
                    rollout_metrics = rollout_callback(train_state, i + 1)
                    wandb_log(rollout_metrics, step=i)

        if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None:
            logging.info("Saving checkpoint...")
            save_callback(train_state, i + 1)
# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
    app.run(main)