Multi-GPU #2870

Merged 9 commits into BVLC:master on Aug 13, 2015

Conversation

shelhamer
Member

This is my packaging of #2114 for merge. I figured @cypof @thatguymike and company had made plenty of revisions and that I could help.

This PR is ready to use for data parallel training of networks but

  • parallel IO is only coordinated for lmdb / leveldb through a DataReader
  • it cannot be satisfactorily tested with the current design

both of which are resolved for merge by #2903.

@cypof @thatguymike @longjon @jeffdonahue please take a look.
@cdoersch could you fire up your parallel training test again?

Reviews and testing by the community are welcome!

@thatguymike
Contributor

All of my quick sanity tests are passing. Despite knowing that by default this is weak scaling, i.e. the batch size specified in train_val.prototxt is multiplied by the number of GPUs you choose to run on, I forgot that when validating accuracy graphs. I still fear that is going to bite users.
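
As a concrete illustration of that weak scaling (hypothetical numbers, not taken from this PR): with batch_size: 64 in train_val.prototxt and '-gpu 0,1,2,3', each of the 4 GPUs runs a batch of 64 per iteration, so the effective batch size is 4 * 64 = 256.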

shelhamer force-pushed the multi_gpu branch 2 times, most recently from ba35568 to e46996b (August 6, 2015 21:32)
@cypof
Member

cypof commented Aug 6, 2015

OK, training works for me. The thread launch code is much better without the fields, that's great.

@shelhamer
Member Author

Thanks for testing, @thatguymike and @cypof. My short test worked, so once we hear from @cdoersch about the ec2 test I think this is ready to merge.

}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start(&timer, &timing);
}
const bool display = param_.display() && iter_ % param_.display() == 0;
Contributor

'timer.Start();' must be added here to restart the timer, otherwise the timing of the gradients at line 266 may be incorrect.

Member Author

Added to line 224 before forward + backward, thanks.
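
A rough sketch of that placement, for context (the timing instrumentation was later stripped from this PR; names follow the hunk above and the Solver<Dtype>::Step loop of this era, and the surrounding iter_size accumulation is omitted):

  // restart the timer so the next reading covers only the forward/backward
  // pass, not the on_start callbacks above
  timer.Start();
  Dtype loss = net_->ForwardBackward(bottom_vec);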

@cdoersch
Contributor

cdoersch commented Aug 7, 2015

Training seems to be working fine on ec2.

shelhamer force-pushed the multi_gpu branch 2 times, most recently from f165d86 to 2b51a08 (August 7, 2015 21:28)
@shelhamer
Member Author

After discussion with @longjon we decided the timing code is too intrusive to bundle in this change. I have stripped it but archived the branch with timing at shelhamer/caffe:time-multi_gpu. It could be re-introduced in a future PR.

@@ -211,7 +228,9 @@ void Solver<Dtype>::Step(int iters) {
losses[idx] = loss;
}
if (display) {
LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss;
if (Caffe::root_solver()) {
Contributor

Probably a bit late to comment on this, and to my mind not necessary for merge, but these conditional LOG(INFO) calls could be made a bit more compact using LOG_IF, e.g. LOG_IF(INFO, Caffe::root_solver()) << "Iteration..."
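
A sketch of the suggested form using glog's LOG_IF macro (the PR as shown keeps the explicit if instead):

  // logs only on the root solver, equivalent to wrapping LOG(INFO) in
  // if (Caffe::root_solver()) { ... }
  LOG_IF(INFO, Caffe::root_solver())
      << "Iteration " << iter_ << ", loss = " << smoothed_loss;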


template<typename Dtype>
Params<Dtype>::Params(shared_ptr<Solver<Dtype> > root_solver)
: size_(total_size<Dtype>(root_solver->net()->params())),
Contributor

This call to params() and the two other calls below should be replaced with learnable_params() after #2866, I think? (I was debating whether the public params() method should just be removed, or if params() should just return learnable_params_, or...)

Member Author

Agreed. Making the switch seems to have no effect, though; I get the same test failures before and after.
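
For reference, a before/after sketch of the accessor switch being discussed (the two other call sites mentioned above would change the same way):

  // before: all parameters of the net
  size_(total_size<Dtype>(root_solver->net()->params())),
  // after #2866: only the learnable parameters
  size_(total_size<Dtype>(root_solver->net()->learnable_params())),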

shelhamer force-pushed the multi_gpu branch 2 times, most recently from 80dbdaa to dd3e064 (August 8, 2015 03:02)
@shelhamer
Member Author

@cypof @thatguymike it turns out #2114 was not rigorously checking solver updates; see #2114 (comment). Fixing the test net targets reveals that all the LeastSquaresUpdate multi-GPU tests fail. I still need to peer at this more closely to see if the issue is with multi-GPU itself or a subtlety of RNG in the tests, but this is now blocked on figuring it out. That is, the multiple solvers could be drawing targets that are not equivalent to the sequential order.

Apart from the tests, my experiments checking parallel training on real nets are making progress, so there's hope.

#2866 is not the problem, as the same failures show up in the multi-GPU branch from before the latest rebase once the test is fixed. This can be seen in shelhamer/caffe:old-multi_gpu.

@shelhamer
Member Author

I'm fairly positive this is a test artifact due to the random Gaussian targets. The multiple solvers can't reproduce random draws equivalent to the single solver sequence:

  • The worker solvers don't inherit the root seed in the current test, since the seed is set in the singleton and not in the solver param (test_gradient_based_solver.cpp:183); they make their own nondeterministic draws (see the sketch after this list).
  • If the test is rewritten to make the workers inherit the seed, they'll all draw the same targets and still be wrong w.r.t. the single-solver sequence.
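
A minimal sketch of that distinction, assuming the usual Caffe seeding calls (simplified, not the actual test code; kSeed is a placeholder):

  // seeding the process-wide singleton: only the solver running in this
  // thread is affected, so worker solvers make their own draws
  Caffe::set_random_seed(kSeed);

  // seeding through the solver param: every solver constructed from this
  // SolverParameter would seed its RNG from it
  SolverParameter param;
  param.set_random_seed(kSeed);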

The solution seems to be making the solver tests take fixed external data, such as the sample_data.h5 file used in the HDF5DataLayerTest. See #2887.

shelhamer force-pushed the multi_gpu branch 2 times, most recently from bb75c36 to 186d453 (August 8, 2015 21:26)
@shelhamer
Member Author

This is now based on #2887, but the multi-GPU solver tests still fail. I believe this is because DataReader only supports lmdb and leveldb, so the hdf5 inputs are identical across the WorkerSolvers. This should be checked, and then DataReader could be extended to hdf5. However, this does raise issues with the DataReader design, since layers that do and do not support it will behave differently under parallelism.

cypof and others added 9 commits August 9, 2015 15:13
- Interrupt the thread before waiting on join
- Provide a method for looping threads to exit on demand
- CHECK if start and stop succeed instead of returning an error
- Make sure each solver accesses a different subset of the data
- Sequential reading of DB for performance
- Prefetch a configurable amount of data to host memory
- Distribute data to solvers in round-robin way for determinism
- Parallelize batches among GPUs and tree-reduce the gradients
- The effective batch size scales with the number of devices
- Batch size is multiplied by the number of devices
- Split batches between GPUs, and tree-reduce the gradients
- Detect machine topology (twin-GPU boards, P2P connectivity)
- Track device in syncedmem (thanks @thatguymike)
- Insert a callback in the solver for minimal code change
- Accept list for gpu flag of caffe tool, e.g. '-gpu 0,1' or '-gpu all'.
  Run on default GPU if no ID given.
- Add multi-GPU solver test
- Deterministic architecture for reproducible runs
- Start with distant nodes in broadcast
- Fix outside loop to loop for full tree depth
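
As a usage note for the gpu flag mentioned in the commit list above (illustrative commands; the solver path is a placeholder):

  # data-parallel training on GPUs 0 and 1; the prototxt batch size applies
  # per GPU, so the effective batch size doubles
  caffe train -solver solver.prototxt -gpu 0,1

  # or run on every available device
  caffe train -solver solver.prototxt -gpu all
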
@cypof
Member

cypof commented Aug 11, 2015

I was off yesterday, but looking at it now.

@shelhamer
Member Author

Everyone, see #2903 for the rigorously tested and passing multi-GPU branch. @ronghanghu has developed a parallel data layer solution.

ronghanghu merged commit 8771d0f into BVLC:master on Aug 13, 2015
@ronghanghu
Member

Merged in #2903

hido added a commit to chainer/chainer that referenced this pull request Aug 25, 2015
PR for Multi-GPU has been merged into the master branch of Caffe.
BVLC/caffe#2870
shelhamer deleted the multi_gpu branch on August 25, 2015 23:56