-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
26 lines (19 loc) · 1.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from core.rl_solve_bu import CaptioningSolver
from core.rl_model_bu import CaptionGenerator
from core.utils import load_coco_data
def main():
    """Build the caption model and solver, then run training.

    Loads the COCO ``train`` and ``val`` splits from ``./data`` via
    ``load_coco_data``. It constructs a ``CaptionGenerator`` over the
    training split's ``word_to_idx`` vocabulary. The model and both splits go
    to a ``CaptioningSolver``, whose ``train()`` is then called.

    The solver is given a ``pretrained_model`` path, a separate
    ``model_path`` for its own output, and ``print_bleu=True`` together with
    the validation split.

    NOTE(review): the ``rl_*`` module names and the ``rl_all_att`` output
    directory suggest this is RL fine-tuning of an ``all_att`` checkpoint.
    Confirm against ``core.rl_solve_bu``.
    """
    data_dir = './data'

    # Training split; its vocabulary mapping is read before anything else
    # is loaded.
    train_data = load_coco_data(data_path=data_dir, split='train')
    vocab = train_data['word_to_idx']

    # Validation split. Per the original author's note, it is used to print
    # BLEU scores every epoch.
    val_data = load_coco_data(data_path=data_dir, split='val')

    model_config = {
        'dim_feature': [[121, 1536], [36, 2048]],
        'dim_embed': 1000,
        'dim_hidden': 1000,
        'n_time_step': 26,
        'prev2out': True,
        'dim_att_hid': 512,
        'ctx2out': True,
        'alpha_c': 0,
        'selector': True,
        'dropout': True,
    }
    model = CaptionGenerator(vocab, **model_config)

    solver_config = {
        'n_epochs': 25,
        'batch_size': 50,
        'update_rule': 'adam',
        'learning_rate': 5e-6,
        'print_every': 500,
        'save_every': 1,
        'image_path': './image/',
        'pretrained_model': './model/all_att/model-11',
        'model_path': './model/rl_all_att/',
        'n_batches': 6000,
        'print_bleu': True,
        'log_path': './log/',
    }
    solver = CaptioningSolver(model, train_data, val_data, **solver_config)
    solver.train()
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()