Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ak/fix train feeding #5

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,29 @@

def generate(decoder, prime_str='A', predict_len=100, temperature=0.8, cuda=False):
hidden = decoder.init_hidden(1)
prime_input = Variable(char_tensor(prime_str).unsqueeze(0))
prime_input = Variable(char_tensor(prime_str).unsqueeze(0), volatile=True)

if cuda:
hidden = hidden.cuda()
prime_input = prime_input.cuda()
predicted = prime_str

# Use priming string to "build up" hidden state
for p in range(len(prime_str) - 1):
_, hidden = decoder(prime_input[:,p], hidden)

inp = prime_input[:,-1]

_, hidden = decoder(prime_input, hidden)

inp = prime_input[0,-1].unsqueeze(0)

for p in range(predict_len):
output, hidden = decoder(inp, hidden)

# Sample from the network as a multinomial distribution
output_dist = output.data.view(-1).div(temperature).exp()
top_i = torch.multinomial(output_dist, 1)[0]

# Add predicted character to string and use as next input
predicted_char = all_characters[top_i]
predicted += predicted_char
inp = Variable(char_tensor(predicted_char).unsqueeze(0))
inp = Variable(char_tensor(predicted_char).unsqueeze(0), volatile=True)
if cuda:
inp = inp.cuda()

Expand Down
18 changes: 8 additions & 10 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,20 @@ def __init__(self, input_size, hidden_size, output_size, model="gru", n_layers=1

self.encoder = nn.Embedding(input_size, hidden_size)
if self.model == "gru":
self.rnn = nn.GRU(hidden_size, hidden_size, n_layers)
self.rnn = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
elif self.model == "lstm":
self.rnn = nn.LSTM(hidden_size, hidden_size, n_layers)
self.rnn = nn.LSTM(hidden_size, hidden_size, n_layers, batch_first=True)
self.decoder = nn.Linear(hidden_size, output_size)

def forward(self, input, hidden):
    """Run a batch of index sequences through the network in one call.

    Args:
        input: index tensor, shape=(batch_size, seq_size).
        hidden: recurrent state handed straight to ``self.rnn``
            (presumably the value returned by ``init_hidden(batch_size)``
            -- confirm against the caller).

    Returns:
        A ``(output, hidden)`` pair where ``output`` has
        shape=(batch_size, seq_size, output_size) and ``hidden`` is the
        state returned by ``self.rnn``.
    """
    # The recurrent layer is constructed with batch_first=True, so the
    # embedded (batch, seq, hidden) tensor is passed through unchanged;
    # the old per-timestep ``.view(1, batch_size, -1)`` reshape is gone.
    encoded = self.encoder(input)
    output, hidden = self.rnn(encoded, hidden)
    # The linear layer acts on the last dimension only, so no flattening
    # of the batch/sequence dimensions is needed before decoding.
    output = self.decoder(output)
    return output, hidden

def init_hidden(self, batch_size):
Expand Down
14 changes: 8 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def random_training_set(chunk_len, batch_size):
inp = torch.LongTensor(batch_size, chunk_len)
target = torch.LongTensor(batch_size, chunk_len)
for bi in range(batch_size):
start_index = random.randint(0, file_len - chunk_len)
start_index = random.randint(0, file_len - chunk_len - 1)
end_index = start_index + chunk_len + 1
chunk = file[start_index:end_index]
inp[bi] = char_tensor(chunk[:-1])
Expand All @@ -49,20 +49,22 @@ def random_training_set(chunk_len, batch_size):
return inp, target

def train(inp, target):
    """Run one optimisation step on a batch and return its loss.

    Args:
        inp: input indices, shape=(batch_size, seq_size).
        target: target indices, shape=(batch_size, seq_size).

    Returns:
        The scalar loss for the batch, read from ``loss.data[0]``.

    Relies on the module-level ``decoder``, ``decoder_optimizer``,
    ``criterion`` and ``args`` objects.
    """
    hidden = decoder.init_hidden(args.batch_size)
    if args.cuda:
        hidden = hidden.cuda()
    decoder.zero_grad()

    # Feed the whole sequence in a single call instead of looping over
    # time steps; the model returns (batch_size, seq_size, output_size).
    output, hidden = decoder(inp, hidden)

    # Flatten batch and time so every position is scored as one sample.
    # NOTE(review): assumes ``criterion`` averages over its inputs, which is
    # why the old division by ``args.chunk_len`` is dropped -- confirm.
    loss = criterion(output.view(-1, output.size(-1)), target.view(-1))

    loss.backward()
    decoder_optimizer.step()

    return loss.data[0]

def save():
save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'
Expand Down