
AbstractRecurrent:clearState() clearing too much #395

Open

FabbyD opened this issue Feb 17, 2017 · 2 comments

FabbyD commented Feb 17, 2017

Example:

require 'rnn' -- provides nn.Sequencer, nn.LSTM and the other recurrent modules

local net = nn.Sequencer(
  nn.Sequential()
    :add(nn.LSTM(100,100))
    :add(nn.Linear(100,100))
    :add(nn.LSTM(100,100))
  )

local inputs = {}
for i=1,10 do
  table.insert(inputs, torch.randn(100))
end

net:forward(inputs) -- This should create 10 clones of my network

net:clearState()

The last line clears all 10 clones of the nn.Sequential. Each of these clones in turn clears all of the nn.LSTM clones. Since nn.LSTM manages its clones internally, the same 10 LSTM clones end up being cleared 10 times.

By adding a few prints in AbstractRecurrent:clearState() we get something like this:

clearState nn.Recursor @ nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.LSTM(100 -> 100)
  (2): nn.Linear(100 -> 100)
  (3): nn.LSTM(100 -> 100)
}	
clearState nn.LSTM(100 -> 100)	
  cleared clone 1 in 0.00020408630371094	
  cleared clone 2 in 0.00022411346435547	
  cleared clone 3 in 0.0002291202545166	
  cleared clone 4 in 0.00021004676818848	
  cleared clone 5 in 0.00021100044250488	
  cleared clone 6 in 0.00018811225891113	
  cleared clone 7 in 0.00021791458129883	
  cleared clone 8 in 0.0002140998840332	
  cleared clone 9 in 0.00021195411682129	
  cleared clone 10 in 0.00020408630371094	
clearState nn.LSTM(100 -> 100)	
  cleared clone 1 in 0.0001978874206543	
  cleared clone 2 in 0.00060486793518066	
  cleared clone 3 in 0.00049901008605957	
  cleared clone 4 in 0.0002589225769043	
  cleared clone 5 in 0.00022697448730469	
  cleared clone 6 in 0.00019097328186035	
  cleared clone 7 in 0.00020694732666016	
  cleared clone 8 in 0.00022196769714355	
  cleared clone 9 in 0.00023078918457031	
  cleared clone 10 in 0.00024318695068359	
  cleared clone 1 in 0.0056848526000977	   <-- The first nn.Sequential clone
...	
clearState nn.LSTM(100 -> 100)	
  cleared clone 1 in 0.00015807151794434	
  cleared clone 2 in 0.00019478797912598	
  cleared clone 3 in 0.00017786026000977	
  cleared clone 4 in 0.00020194053649902	
  cleared clone 5 in 0.00017094612121582	
  cleared clone 6 in 0.00017809867858887	
  cleared clone 7 in 0.00016403198242188	
  cleared clone 8 in 0.00015807151794434	
  cleared clone 9 in 0.00016117095947266	
  cleared clone 10 in 0.00016188621520996	
clearState nn.LSTM(100 -> 100)	
  cleared clone 1 in 0.00016117095947266	
  cleared clone 2 in 0.00016403198242188	
  cleared clone 3 in 0.00015997886657715	
  cleared clone 4 in 0.00016307830810547	
  cleared clone 5 in 0.00016498565673828	
  cleared clone 6 in 0.00015807151794434	
  cleared clone 7 in 0.00015902519226074	
  cleared clone 8 in 0.00016093254089355	
  cleared clone 9 in 0.00016188621520996	
  cleared clone 10 in 0.00016617774963379	
  cleared clone 10 in 0.0038068294525146   <-- The last nn.Sequential clone

This might not seem like a big deal, but it means #clones x #clones x 2 (there are 2 LSTMs in this example) calls to clearState. When dealing with longer sequences like documents, this can take a very long time to finish. I sometimes have sequences of 10k inputs (I'm experimenting with stuff...), which means 10k*10k calls, each taking ~0.0002 seconds, or roughly 5.5 hours just to run clearState() before saving the model to disk.
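
As a sanity check on that figure, redoing the arithmetic above in Lua (the counts and the ~0.0002 s per call come from the numbers quoted above):

-- back-of-the-envelope check of the estimate above
local steps = 10000               -- sequence length, i.e. number of clones
local calls = steps * steps       -- every outer clone re-clears every inner clone
local secondsPerCall = 0.0002     -- rough per-clone time taken from the prints
print(calls * secondsPerCall / 3600) -- ~5.5 hours, per LSTM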

Because nn.LSTM manages its clones internally and is contained inside the nn.Sequential, the same clones are cleared again and again, as I explained at the beginning. Is there a way I could clear those LSTMs only once, effectively reducing the number of calls from O(n^2) to O(n)?
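
In the meantime, the kind of guard I have in mind looks something like this (only a sketch, not part of the rnn API; the _clearStateGen field and the generation counter are made up for the example, and it assumes that skipping a repeated clearState call on the same nn.LSTM instance within one pass is safe):

-- sketch: let each nn.LSTM instance clear itself at most once per pass
local originalClearState = nn.LSTM.clearState -- the inherited AbstractRecurrent:clearState
local generation = 0

function nn.LSTM:clearState()
  if self._clearStateGen == generation then
    return self -- this exact instance was already cleared during this pass
  end
  self._clearStateGen = generation
  return originalClearState(self)
end

-- start a new pass before each top-level clear
generation = generation + 1
net:clearState()

Whether this catches all of the redundant work depends on whether the repeated calls hit the same LSTM tables or distinct clones that only share their internal clone list, so I see it as an illustration of the idea rather than an actual fix.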

FabbyD commented Feb 20, 2017

I just found out that this also seems to happen for a few other methods that iterate through clones with AbstractRecurrent:includingSharedClones(f). To reproduce the issue, all you need to do is wrap an nn.Container containing any nn.AbstractRecurrent module with nn.Recursor.
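
For example, a minimal setup of that shape (the sizes and the number of steps are arbitrary; this is only meant to show the wrapping described above):

require 'rnn'

-- an nn.Container (nn.Sequential) holding an nn.AbstractRecurrent module (nn.LSTM),
-- wrapped in nn.Recursor
local rec = nn.Recursor(
  nn.Sequential()
    :add(nn.LSTM(10,10))
)

for step = 1, 5 do
  rec:forward(torch.randn(10)) -- each step creates another shared clone
end

rec:clearState() -- the same inner LSTM clones get cleared once per outer clone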

FabbyD commented Mar 6, 2017

Actually, why is clearState not removing all clones? After all, clones are mostly used to keep those extra output and gradInput buffers. Calling clearState IMO means: "Remove all buffers from memory. I will reallocate afterwards if necessary."
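
Something along these lines is what I have in mind (a rough sketch only; sharedClones is my guess at the internal field name, and I am assuming the clones are safe to drop and rebuild lazily on the next forward):

-- rough sketch: drop the step clones instead of clearing each one of them
function nn.AbstractRecurrent:clearState()
  self.sharedClones = {}               -- forget the per-step clones entirely (field name assumed)
  return nn.Container.clearState(self) -- still clear the buffers of the contained module(s)
end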
