DataParallelTable with Element-Research RNN library #233

Open
mxh opened this issue Mar 5, 2016 · 3 comments

@mxh

mxh commented Mar 5, 2016

I was trying to use a DataParallelTable with an RNN from https://github.com/Element-Research/rnn to train on 4 GPUs. This didn't work out of the box, because rnn adds an nn.Module:forget method that recursively calls forget on all submodules.

Without DataParallelTable, you make a module with memory forget its stored state by calling model:forget(). With DataParallelTable, that call only reaches self.modules[1] and is not propagated to the GPUs. I worked around this by calling model.impl:exec(function(m) m:forget() end) instead.
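To make the failure mode concrete, here is a toy simulation in plain Lua (no Torch required). Memory and FakeDPT are hypothetical stand-ins, not the real rnn or cunn classes. The split between modules[1] and the per-GPU replicas is a simplification of the behavior described above: the generic recursive forget only reaches modules[1], while the impl:exec workaround reaches every replica.

```lua
-- Toy simulation in plain Lua; Memory and FakeDPT are hypothetical
-- stand-ins for an rnn memory module and nn.DataParallelTable.

-- A module with memory: step() records a time step, forget() clears it.
local Memory = {}
Memory.__index = Memory
function Memory.new() return setmetatable({ steps = 0 }, Memory) end
function Memory:step() self.steps = self.steps + 1; return self end
function Memory:forget() self.steps = 0; return self end

-- Simplified DataParallelTable: modules[1] is the copy that the generic
-- recursive forget() sees; the per-GPU replicas are only reachable
-- through impl:exec(closure).
local function FakeDPT(nGPU)
   local replicas = {}
   for i = 1, nGPU do replicas[i] = Memory.new() end
   local dpt = { modules = { Memory.new() }, replicas = replicas }
   dpt.impl = {
      exec = function(_, closure)
         local results = {}
         for i, m in ipairs(replicas) do results[i] = closure(m) end
         return results
      end,
   }
   -- rnn-style forget: recurses over self.modules only
   function dpt:forget()
      for _, m in ipairs(self.modules) do m:forget() end
      return self
   end
   return dpt
end

local model = FakeDPT(4)
model.impl:exec(function(m) m:step() end)   -- forward pass on every "GPU"

model:forget()                               -- naive call
print(model.replicas[1].steps)               -- 1: the replica still remembers

model.impl:exec(function(m) m:forget() end) -- the workaround
print(model.replicas[1].steps)               -- 0
```

The closure passed to impl:exec runs once per replica. That is why the workaround resets state that the recursive loop over self.modules never touches.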

This might be useful to someone else. Is this something you want to either document or make explicit?

@soumith
Member

soumith commented Mar 5, 2016

Since :forget is not part of the standard nn.Module API, this is exactly the behavior we would expect.
@nicholas-leonard can you figure this out and see if things can be made smoother?

@nicholas-leonard
Member

@mxh Could you open a pull request to add this functionality to the rnn library? A unit test would also be nice, as I have little experience with DataParallelTable. I can help you with the details.

@achalddave

What is the right way to implement new methods for nn.DataParallelTable? The rnn library defines the forget method (and others) here: https://github.com/Element-Research/rnn/blob/ef98a97b16f55f598830293435af32b509ffc5bd/Module.lua#L4 . Is the right approach to check whether the module is of type nn.DataParallelTable and call .impl:exec instead of the for loop? Or is there a simpler way that works for both normal Containers and DataParallelTables?
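One option that avoids type checks is plain method overriding. Keep the generic recursive forget on the base class, and give the DataParallelTable class its own forget that goes through impl:exec. A container that happens to hold a DataParallelTable then dispatches to the override automatically. The sketch below shows the pattern in plain Lua. Module, Memory, Sequential and DPT are hypothetical stand-ins, not the real nn/rnn/cunn classes. In Torch the analogous (untested) change would be to define nn.DataParallelTable:forget after require 'cunn' and require 'rnn'. It would rely on the same model.impl:exec call used in the workaround earlier in this thread.

```lua
-- Plain-Lua sketch of the override pattern. Module, Memory, Sequential and
-- DPT are hypothetical stand-ins for nn.Module, an rnn memory module,
-- nn.Sequential and nn.DataParallelTable.

local function class(parent)
   local c = {}
   c.__index = c
   if parent then setmetatable(c, { __index = parent }) end
   return c
end

-- Base class: generic recursive forget, in the style of rnn's Module:forget.
local Module = class()
function Module:forget()
   if self.modules then
      for _, m in ipairs(self.modules) do m:forget() end
   end
   return self
end

local Memory = class(Module)
function Memory.new() return setmetatable({ steps = 0 }, Memory) end
function Memory:forget() self.steps = 0; return self end

local Sequential = class(Module)
function Sequential.new(...) return setmetatable({ modules = { ... } }, Sequential) end

-- DPT keeps modules[1] plus per-GPU replicas reachable via impl:exec.
local DPT = class(Module)
function DPT.new(makeReplica, nGPU)
   local replicas = {}
   for i = 1, nGPU do replicas[i] = makeReplica() end
   local self = setmetatable({ modules = { makeReplica() }, replicas = replicas }, DPT)
   self.impl = {
      exec = function(_, closure)
         for _, r in ipairs(replicas) do closure(r) end
      end,
   }
   return self
end

-- The override: forget modules[1] as usual, then every replica.
function DPT:forget()
   Module.forget(self)
   self.impl:exec(function(m) m:forget() end)
   return self
end

local dpt = DPT.new(Memory.new, 4)
local model = Sequential.new(Memory.new(), dpt)
model.modules[1].steps = 2
for _, r in ipairs(dpt.replicas) do r.steps = 3 end

-- Sequential uses the generic loop, which dispatches to DPT:forget.
model:forget()
```

With this design, the loop in the base class stays unchanged. Only the class that owns the replicas needs to know about impl:exec.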
