AbstractRecurrent:clearState() clearing too much #395
Comments
I just found out that this actually seems to happen for a few other methods that iterate through clones.
Actually, why is clearState not removing all clones? After all, clones are mostly used to keep those extra output and gradInput buffers. Calling clearState IMO means: "Remove all buffers from memory. I will reallocate afterwards if necessary."
Example:

The last line will clear all 10 clones of nn.Sequential. Every one of these clones will also clear all the nn.LSTM clones. Since the nn.LSTM module takes care of its clones internally, we are clearing the same 10 clones 10 times.

By adding a few prints in AbstractRecurrent:clearState() we get something like this:

This might not seem like a big deal, but it means #clones x #clones x 2 (there are 2 LSTMs in this example) calls to clearState. When dealing with longer sequences like documents, this can take a very long time to finish. I sometimes have sequences of 10k inputs (I'm experimenting with stuff...), which means 10k * 10k calls taking ~0.0002 seconds each, which is roughly 5.5 hours just to run clearState() before saving the model to disk.
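The redundant clearing can be sketched with a minimal Python model (hypothetical stand-ins for the Torch modules; this is not the actual rnn library code):

```python
# Hypothetical model of the clearState recursion described above.
# Class and method names loosely mirror the Torch modules.

calls = 0  # count clearState work done on LSTM clones

class LSTM:
    """Recurrent module that manages its own clones internally."""
    def __init__(self, n_clones):
        self.n_clones = n_clones

    def clear_state(self):
        global calls
        # AbstractRecurrent:clearState() walks every internal clone
        for _ in range(self.n_clones):
            calls += 1  # each clone's buffers are released here

class Sequential:
    """Container whose clearState recurses into every child module."""
    def __init__(self, children):
        self.children = children

    def clear_state(self):
        for child in self.children:
            child.clear_state()

n = 10  # sequence length / number of clones
seq = Sequential([LSTM(n), LSTM(n)])

# The 10 clones of the Sequential all recurse into the same shared
# LSTMs, so the LSTM clones get cleared once per Sequential clone.
for _ in range(n):      # one pass per Sequential clone
    seq.clear_state()

print(calls)  # 10 passes x 2 LSTMs x 10 clones = 200 clearing steps
```

The quadratic factor comes from the outer loop over Sequential clones multiplied by the inner loop over each LSTM's own clones.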
Because nn.LSTM manages its clones internally and is contained inside the nn.Sequential, the same clones are being cleared again and again, as I explained at the beginning. Is there a way I could clear those LSTMs only once, effectively reducing the number of calls from O(n^2) to O(n)?
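One way the quadratic blow-up could be avoided is to deduplicate modules by identity before clearing, so each recurrent module releases its clones exactly once. A hypothetical sketch (not an actual rnn-library API):

```python
# Hypothetical fix: clear each distinct module at most once by
# tracking object identity. Not actual Torch/rnn code.

calls = 0

class LSTM:
    def __init__(self, n_clones):
        self.n_clones = n_clones

    def clear_state(self):
        global calls
        calls += self.n_clones  # release every internal clone's buffers

def clear_once(modules):
    """Clear each distinct module exactly once, keyed on identity."""
    seen = set()
    for m in modules:
        if id(m) not in seen:
            seen.add(id(m))
            m.clear_state()

n = 10
lstm1, lstm2 = LSTM(n), LSTM(n)

# The n Sequential clones all share the same two LSTM instances, so a
# flattened module list contains each LSTM n times...
flattened = [lstm1, lstm2] * n

clear_once(flattened)   # ...but each one is cleared only once
print(calls)            # 2 LSTMs x 10 clones = 20 steps instead of 200
```

With deduplication the cost is proportional to the number of distinct modules times their clone count, i.e. O(n) rather than O(n^2) in the sequence length.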