-
Notifications
You must be signed in to change notification settings - Fork 1
/
ego_gmm_pre_bc.yaml
41 lines (36 loc) · 1.02 KB
/
ego_gmm_pre_bc.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# @package _global_

# Experiment config: selects the `ego_gmm` option of the `model` config group
# and overrides its learning rate, token/rollout sampling settings, trainer
# limits, and data-loader sizes.
# NOTE(review): nesting was restored from a flattened copy of this file;
# confirm that both *_rollout_sampling groups belong directly under
# `model_config` (and not under `token_processor`).

defaults:
  # - override /trainer: ddp
  - override /model: ego_gmm

model:
  model_config:
    # Floats are written with an explicit decimal point so that YAML 1.1
    # loaders (e.g. PyYAML) resolve them as floats, not strings.
    lr: 5.0e-4
    token_processor:
      map_token_sampling:  # open-loop
        num_k: 1  # for k nearest neighbors
        temp: 1.0  # uniform sampling
      agent_token_sampling:  # closed-loop
        num_k: 1  # for k nearest neighbors
        temp: 1.0
    validation_rollout_sampling:
      # criterium is one of {topk_prob, topk_prob_sampled_with_dist}
      criterium: topk_prob
      num_k: 3  # for k most likely
      temp_mode: 1.0e-3
      temp_cov: 1.0e-3
    training_rollout_sampling:
      # criterium is one of {topk_prob, topk_prob_sampled_with_dist}
      criterium: topk_prob
      # for k nearest neighbors; set to -1 to turn off closed-loop training
      num_k: -1
      temp_mode: 1.0e-3
      temp_cov: 1.0e-3

ckpt_path: null
# ckpt_path: CKPT_FOR_RESUME.ckpt  # to resume training

trainer:
  limit_train_batches: 1.0
  limit_val_batches: 0.1
  check_val_every_n_epoch: 1
  max_epochs: 64

data:
  train_batch_size: 10
  val_batch_size: 10
  test_batch_size: 10
  num_workers: 10