# train.py — character-level LSTM name-generation training script.
# (GitHub page residue — navigation text and a line-number ladder from the
# web scrape — removed; it was not part of the source file.)
from __future__ import unicode_literals, print_function, division
import random
import csv
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import cPickle
from model import LSTM
def inputTensor(input_idx):
    """Wrap a list of character indices as a LongTensor Variable.

    Args:
        input_idx: list[int] of vocabulary indices for one name.

    Returns:
        autograd.Variable holding a 1-D LongTensor of the indices.
    """
    return autograd.Variable(torch.LongTensor(input_idx))
def targetTensor(input_idx, char2idx):
    """Build the training target for one name.

    The target is the input sequence shifted left by one position and
    terminated with the EOS token, so each step predicts the next char.

    Args:
        input_idx: list[int] of vocabulary indices (not modified).
        char2idx: mapping that must contain the 'EOS' key.

    Returns:
        autograd.Variable holding a 1-D LongTensor of target indices.
    """
    shifted = input_idx[1:] + [char2idx['EOS']]
    return autograd.Variable(torch.LongTensor(shifted))
def train(model, criterion, input, target):
    """Run one backward pass for a single name sequence.

    Zeroes the model's gradients, computes the loss on one sequence and
    backpropagates it. The optimizer step is the caller's responsibility.

    Args:
        model: network with initHidden() and forward(input, hidden)
            returning (output, hidden).
        criterion: loss module, e.g. nn.NLLLoss().
        input: 1-D LongTensor of input character indices.
        target: 1-D LongTensor of target character indices.

    Returns:
        float: loss averaged over the sequence length.
    """
    hidden = model.initHidden()
    model.zero_grad()
    output, _ = model(input, hidden)
    # NOTE: the original also computed torch.max(output.data, 1) here, but
    # the result was never used — dead work removed.
    loss = criterion(output, target)
    loss.backward()
    # loss.item() extracts the Python float; the old loss.data[0] raises
    # IndexError on 0-dim loss tensors in torch >= 0.5.
    return loss.item() / input.size(0)
def read_csv(filname):
    """Read name strings from the fifth column of a CSV file.

    Skips the header row and returns the value of column index 4 of every
    remaining row.

    Args:
        filname: path to the CSV file.

    Returns:
        list[str]: one name string per data row.
    """
    # newline='' is required by the csv module docs; decoding is done by the
    # text-mode open. The original used the Python-2-only reader.next() and
    # row[4].decode('utf-8'), both of which fail on Python 3.
    with open(filname, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [row[4] for row in reader]
def main():
    """Train the character-level LSTM name model end to end.

    Reads names from data/names/names.csv, builds the character vocabulary,
    saves it to dic.p, trains for a fixed number of epochs, and saves the
    model weights to model.pt.
    """
    names_str = read_csv(filname='data/names/names.csv')
    # Build the character vocabulary plus an explicit end-of-sequence token.
    all_char_str = set([char for name in names_str for char in name])
    char2idx = {char: i for i, char in enumerate(all_char_str)}
    char2idx['EOS'] = len(char2idx)
    # Save char dictionary. Local import: stdlib pickle replaces the
    # Python-2-only cPickle; the with-block closes the file handle that the
    # original pickle.dump(..., open(...)) call leaked.
    import pickle
    with open("dic.p", "wb") as f:
        pickle.dump(char2idx, f)
    names_idx = [[char2idx[char_str] for char_str in name_str]
                 for name_str in names_str]
    # build model
    model = LSTM(input_dim=len(char2idx), embed_dim=100, hidden_dim=128)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters())
    n_iters = 5
    # 'epoch' instead of the original 'iter', which shadowed the builtin.
    for epoch in range(1, n_iters + 1):
        # data shuffle
        random.shuffle(names_idx)
        total_loss = 0
        for name_idx in names_idx:
            input = inputTensor(name_idx)
            target = targetTensor(name_idx, char2idx)
            loss = train(model, criterion, input, target)
            total_loss += loss
            # Step INSIDE the sample loop: train() zeroes the gradients on
            # every call, so the original once-per-epoch step applied only
            # the last sample's gradients and discarded all the others.
            optimizer.step()
        print(epoch, "/", n_iters)
        print("loss {:.4}".format(float(total_loss / len(names_idx))))
    # save trained model
    torch.save(model.state_dict(), "model.pt")
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()