Single sentences #189

Open · wants to merge 6 commits into master
7 changes: 7 additions & 0 deletions .gitignore
@@ -1 +1,8 @@
*.t7
*.svg
*.dot
*.png

# temp files
_*

6 changes: 6 additions & 0 deletions data/dict1/input.txt
@@ -0,0 +1,6 @@
supercalifragilisticexpialidocious n. secret. 2 (foll. by on) formidious going leaves. 2 breat which the being hand. [old english]
healte v. (-ling) drinking or esp. clowel or armitic take away or causing someting. 6 sippossion of algeratous. [latin: related to *tan-1 a deer-notic mutder maddly lowy, a restinatiun]
candrious adj. 1 suchering, years. personist adj. disensentionist n. [french]
rescacabole n. urless skoiling a band bexope out in farehind earen day-deaseding. [latin gonar]
repipt n. don-if for a not actuous listing.
steatshide v. 1 (brit. abshit hair). containented oneself trryable to and branging, propession. [french honic]
1 change: 1 addition & 0 deletions data/fox/input.txt
@@ -0,0 +1 @@
The quick brown fox jumps over the lazy dog
4 changes: 4 additions & 0 deletions data/multiline/input.txt
@@ -0,0 +1,4 @@
The quick brown fox jumps over the lazy dog.
A stitch in time saves nine.
A rolling stone gathers no moss.
The early bird catches the worm.
1 change: 1 addition & 0 deletions data/simple/input.txt
@@ -0,0 +1 @@
abcde cab bad ace add ebb deed dead cede
56 changes: 55 additions & 1 deletion train.lua
@@ -108,6 +108,11 @@ end
local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes)
local vocab_size = loader.vocab_size -- the number of distinct characters
local vocab = loader.vocab_mapping
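-- invert the vocab mapping (char -> index) into (index -> char) so sampled tensors can be decoded back to text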
local vocab_inv = {}
for k, v in pairs(loader.vocab_mapping) do
-- print('vocab pair', k, v)
vocab_inv[v] = k
end
print('vocab size: ' .. vocab_size)
-- make sure output directory exists
if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end
@@ -212,6 +217,24 @@ function eval_split(split_index, max_batches)
return loss
end

function sampleToString(sample)
local sample_copy = sample:clone():int()
local str = ''
if sample_copy:dim() == 2 then
for j=1, sample_copy:size()[1] do
for i=1, sample_copy:size()[2] do
str = str .. vocab_inv[sample_copy[j][i]]
end
str = str .. '|'
end
elseif sample_copy:dim() == 1 then
for i=1, sample_copy:size()[1] do
str = str .. vocab_inv[sample_copy[i]]
end
end
return str
end
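-- (example: with a hypothetical mapping vocab_inv = {'a','b','c'},
-- sampleToString(torch.IntTensor{{1,2,3},{3,2,1}}) returns 'abc|cba|';
-- rows of a 2D batch are decoded in order, each followed by '|')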

-- do fwd/bwd and return loss, grad_params
local init_state_global = clone_list(init_state)
function feval(x)
@@ -222,6 +245,7 @@ function feval(x)

------------------ get minibatch -------------------
local x, y = loader:next_batch(1)
-- print('x', x:size(), sampleToString(x), 'y', sampleToString(y))
if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU
-- have to convert to float because integers can't be cuda()'d
x = x:float():cuda()
@@ -235,9 +259,28 @@ function feval(x)
local rnn_state = {[0] = init_state_global}
local predictions = {} -- softmax outputs
local loss = 0
local resetBPerT = {} -- batch rows reset at each timestep, remembered so the backward pass can reuse them
for t=1,opt.seq_length do
-- for each batch row at this timestep, check whether the char is a newline; if so, reset that row's state
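-- (x may already live on the GPU here; the int copy keeps the vocab_inv lookups as plain Lua numbers)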
local x_clone = x:int()
resetBPerT[t] = {}
for b=1,opt.batch_size do
if vocab_inv[x_clone[b][t]] == '\n' then
print('newline detected => resetting state, t=' .. t .. ' batch_pos', b)
table.insert(resetBPerT[t], b)
-- zero row b of every state tensor entering step t, so the
-- character after the newline starts from a blank state
for g=1,#init_state do
local gstate = rnn_state[t-1][g]
local narrowed_state = gstate:narrow(1, b, 1)
narrowed_state:zero()
end
end
end
clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
local lst = clones.rnn[t]:forward{x[{{}, t}], unpack(rnn_state[t-1])}
local thisinput = {x[{{}, t}], unpack(rnn_state[t-1])}
-- print('t=' .. t .. ' thisinput[1]', sampleToString(thisinput[1]))
local lst = clones.rnn[t]:forward(thisinput)
rnn_state[t] = {}
for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
predictions[t] = lst[#lst] -- last element is the prediction
@@ -251,6 +294,16 @@
-- backprop through loss, and softmax/linear
local doutput_t = clones.criterion[t]:backward(predictions[t], y[{{}, t}])
table.insert(drnn_state[t], doutput_t)
-- if any rows were reset at this timestep, zero their gradient rows as well
for _, b in ipairs(resetBPerT[t]) do
print('backprop, found newline, resetting state for t=' .. t .. ' b=' .. b)
local tstate = drnn_state[t]
for g=1,#tstate do -- every state gradient plus the output gradient
local gstate = tstate[g]
local narrowed_state = gstate:narrow(1, b, 1)
narrowed_state:zero()
end
end
local dlst = clones.rnn[t]:backward({x[{{}, t}], unpack(rnn_state[t-1])}, drnn_state[t])
drnn_state[t-1] = {}
for k,v in pairs(dlst) do
@@ -331,6 +384,7 @@ for i = 1, iterations do
print('loss is exploding, aborting.')
break -- halt
end

end
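For reference, a minimal standalone sketch of the per-row state reset this diff performs (illustrative only; batch_size, rnn_size, and b below are made-up names, not from the PR):

    require 'torch'

    -- a stand-in hidden state: one row per batch element
    local batch_size, rnn_size = 4, 8
    local state = torch.randn(batch_size, rnn_size)

    -- suppose row b just emitted a '\n'
    local b = 2
    -- narrow(1, b, 1) returns a one-row view into the tensor;
    -- zeroing it resets only that sequence and leaves other rows intact
    state:narrow(1, b, 1):zero()

Because narrow returns a view rather than a copy, zeroing it mutates the underlying storage, which is why the diff can reset rnn_state in place.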