# slot-attention.yaml

# @package _global_

defaults:
  - _base

batch_size: 64
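
# Hydra notes: "# @package _global_" places these keys at the config root,
# "defaults: - _base" merges a shared base config, each _target_ names the
# class that hydra.utils.instantiate builds, and ${...} interpolations are
# resolved from the dataset config at compose time.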

trainer:
  _target_: models.slot_attention.trainer.SlotAttentionTrainer
  steps: 500_000
  use_warmup_lr: true
  warmup_steps: 10_000
  use_exp_decay: true
  exp_decay_rate: 0.5
  exp_decay_steps: 100_000
  optimizer_config:
    alg: Adam
    lr: 0.0004
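
# A hedged sketch of the learning-rate schedule the flags above suggest,
# assuming (as in the original Slot Attention paper) linear warmup multiplied
# by exponential decay; the exact composition lives in SlotAttentionTrainer:
#
#   lr(t) = 0.0004 * min(t / 10_000, 1) * 0.5 ** (t / 100_000)
#
# so the rate ramps up over the first 10k steps and halves every 100k steps,
# ending near 0.0004 * 0.5 ** 5 = 1.25e-5 at step 500k.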

model:
  _target_: models.slot_attention.model.SlotAttentionAE
  name: slot-attention
  num_slots: ${dataset.max_num_objects}
  latent_size: 64
  encoder_params:
    channels: [32, 32, 32, 32]
    kernels: [5, 5, 5, 5]
    paddings: [2, 2, 2, 2]
    strides: [1, 1, 1, 1]
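  # Worked out from the numbers above: every encoder conv uses kernel 5,
  # padding 2, stride 1, so spatial size is preserved at each layer,
  #   H_out = (H + 2*2 - 5) / 1 + 1 = H,
  # and the encoder runs at full input resolution throughout.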
  decoder_params:
    conv_transposes: false
    channels: [32, 32, 32, 4]
    kernels: [5, 5, 5, 3]
    strides: [1, 1, 1, 1]
    paddings: [2, 2, 2, 1]
    output_paddings: [0, 0, 0, 0]
    activations: [relu, relu, relu, null]
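  # Same arithmetic for the decoder: plain stride-1 convs (conv_transposes:
  # false) with "same" padding (kernel 5 / pad 2, then kernel 3 / pad 1)
  # keep spatial size fixed, so output_paddings are all 0. The 4 output
  # channels are presumably 3 RGB channels plus a per-slot alpha mask, as in
  # the standard Slot Attention autoencoder.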
  attention_iters: 3
  mlp_size: 128
  eps: 1e-8
  h_broadcast: ${dataset.height}
  w_broadcast: ${dataset.width}
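  # h_broadcast/w_broadcast equal the dataset resolution, consistent with a
  # spatial broadcast decoder: each slot is tiled to a full H x W grid before
  # the convs above decode it, which is why no upsampling is needed.
  # attention_iters: 3 matches the number of slot attention routing
  # iterations used in the original paper.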