---
# hparams.yaml
# Default hyperparameters for a fresh training run.
# The `&DEFAULT` anchor exposes this mapping so that other sections can pull
# in every key with a merge key (`<<: *DEFAULT`) and override selectively.
train_new: &DEFAULT
  dataset: omniglot  # omniglot / miniimagenet / quickdraw
  meta_learner: maml  # maml / reptile / anil
  model: small  # model architecture, specified in /models/models.py
  optimizer: adam
  # Exponent floats are written with an explicit mantissa dot: YAML 1.1
  # loaders (e.g. PyYAML) resolve a bare `1e-3` as the string "1e-3",
  # not as a float.
  learning_rate: 1.0e-3
  lr_finetune: 1.0e-1
  index: 1  # NOTE(review): meaning not documented here; confirm with consumer
  num_epochs: 50
  n_way: 5
  freeze: 2  # used for anil only, number of layers to freeze
  meta_batch_size: 32  # number of tasks in training batch
  k_support: 5  # k shot for support set
  k_query: 15  # k shot for query set
  n_inner_iter: 5  # number of iterations in inner loop
  seed: 121  # random seed
  saving_gradient_steps: true  # save gradient steps to get which directions to plot
  loss_plotting: true  # whether to plot loss landscape
  plot_progress: true  # whether to plot progress during training
  modelpath: ./models  # base path to model state_dict
  trajpath: ./models  # base path to weights trajectories
  modelname: modelname.pt
  weightstrajfilename: weights_trajectories.npy  # path to numpy array of weight trajectories
  plot_gridsize: 20  # number of points in plot meshgrid
  last_n_traj_points: 20
  # Canonical lowercase booleans; same value as the former `False`.
  fix_extractor: false
  fix_head: false
# Section `pretrained`: starts from every key of the `DEFAULT` anchor (merge
# key below); the keys listed after it are explicit overrides and take
# precedence over the merged-in values.
pretrained:
  <<: *DEFAULT
  dataset: omniglot
  # random seed
  seed: 121
  # whether to plot loss landscape
  loss_plotting: true
  # number of points in plot meshgrid
  plot_gridsize: 20
  last_n_traj_points: 20