
Multi-GPU support in GPUPredictor. #3738

Merged: 8 commits merged into dmlc:master on Oct 24, 2018

Conversation

canonizer (Contributor) commented Sep 28, 2018

  • GPUPredictor is multi-GPU
  • removed DeviceMatrix, as it has been made obsolete by using HostDeviceVector in DMatrix

Closes #3756

auto& offsets = *out_offsets;
offsets.resize(devices.Size() + 1);
offsets[0] = 0;
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
Member

Sorry for dropping by. Just to be safe, it might be better to save the current device before spawning threads that can change it. Otherwise, subsequent code could potentially access memory on the wrong device.

class SaveCudaContext {

Contributor Author

cudaSetDevice() changes the device of the calling thread only, so it does not alter the device in a different thread.

If only a single device is used (and no threads are spawned), cudaSetDevice() will be called with that device.

Otherwise, the code using a GPU is responsible for setting the device being used. This means calling cudaSetDevice() in public methods, in shards, and in OpenMP loops (with one iteration per device). Private methods can assume that the right device has already been set.
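For illustration, a minimal sketch of that pattern, with one OpenMP iteration per device and each thread setting its own device before touching it (the function name and shard details here are hypothetical, not the actual GPUPredictor code):

#include <cuda_runtime.h>
#include <omp.h>
#include <vector>

// Hypothetical per-device loop: cudaSetDevice() only affects the calling
// thread, so each OpenMP thread can safely select its own device.
void RunOnAllDevices(const std::vector<int>& devices) {
  int n = static_cast<int>(devices.size());
#pragma omp parallel for schedule(static, 1) if (n > 1)
  for (int i = 0; i < n; ++i) {
    cudaSetDevice(devices[i]);  // affects this thread only
    // ... launch kernels / copy memory for shard i on devices[i] ...
  }
}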

Member

Just had a discussion with @trivialfis about this. My view is that @canonizer's approach is fine for now, although it would be good to look at better ways of managing the global multi-GPU state.

We have already had one difficult-to-find bug because the active device was not what was expected. Perhaps there is a way to manage this so that we can explicitly prevent kernels from being called with the incorrect active device index.
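For illustration, one way to make the active device explicit is an RAII guard along the lines of the SaveCudaContext mentioned above; this is only a hedged sketch, not xgboost's actual implementation:

#include <cuda_runtime.h>

// Hypothetical guard: remembers the caller's current device, switches to the
// requested one, and restores the original device when the scope ends.
class DeviceGuard {
 public:
  explicit DeviceGuard(int device) {
    cudaGetDevice(&saved_);
    cudaSetDevice(device);
  }
  ~DeviceGuard() { cudaSetDevice(saved_); }

 private:
  int saved_ {0};
};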

hcho3 (Collaborator) commented Sep 28, 2018

@canonizer Can you add a test for multi-GPU prediction? I am about to add a multi-GPU slave worker to the Jenkins CI server. The multi-GPU tests will run as a separate task from the single-GPU tests.

RAMitchell (Member) left a comment


As mentioned by @hcho3, it would be good if you could add explicit multi-GPU testing. We now have multi-GPU machines running on Jenkins.

You will also need to rebase this due to some minor conflicts with my recent dmatrix changes.

@@ -143,19 +100,21 @@ struct DevicePredictionNode {

struct ElementLoader {
bool use_shared;
size_t* d_row_ptr;
Entry* d_data;
const size_t* d_row_ptr;
Member

Would be good to use span instead of raw pointers here.

Contributor Author

Done.

auto begin_ptr = d_data + d_row_ptr[ridx];
auto end_ptr = d_data + d_row_ptr[ridx + 1];
Entry* previous_middle = nullptr;
auto begin_ptr = d_data + d_row_ptr[ridx] - entry_start;
Member

Can we generalise our other binary search code and use that here?

Contributor Author

Yes, but probably in a different pull request.

If you want it in this pull request, feel free to comment on this, and I'll get back to it on Thursday.

Member

No need to rush into cleaning up. We can do it later.
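For reference, a generic device-side binary search of the kind being discussed might look like the following; this is only a hedged sketch, not code from this pull request:

#include <cuda_runtime.h>

// Hypothetical generic device-side binary search: returns the first position
// in [first, last) whose element compares greater than value (an upper bound),
// assuming the range is sorted in ascending order.
template <typename IterT, typename ValueT>
__device__ IterT UpperBound(IterT first, IterT last, ValueT value) {
  while (first < last) {
    IterT middle = first + (last - first) / 2;
    if (*middle <= value) {
      first = middle + 1;
    } else {
      last = middle;
    }
  }
  return first;
}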

@@ -225,14 +184,15 @@ __device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
template <int BLOCK_THREADS>
__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
float* d_out_predictions, size_t* d_tree_segments,
int* d_tree_group, size_t* d_row_ptr,
Entry* d_data, size_t tree_begin,
int* d_tree_group, const size_t* d_row_ptr,
Member

We might as well upgrade all of these raw pointers to spans.

Contributor Author

Done.


hcho3 (Collaborator) commented Oct 10, 2018

Any updates on multi-GPU tests?

canonizer (Contributor Author)

Added a multi-GPU test for GPUPredictor and addressed reviewers' comments.

auto begin_ptr = d_data + d_row_ptr[ridx];
auto end_ptr = d_data + d_row_ptr[ridx + 1];
Entry* previous_middle = nullptr;
auto begin_ptr = d_data.begin() + d_row_ptr[ridx] - entry_start;
Member

Changing to

Suggested change
auto begin_ptr = d_data.begin() + d_row_ptr[ridx] - entry_start;
auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);

should pass the multi-GPU test. (With a bounds-checked span, the parentheses matter: adding the global offset d_row_ptr[ridx] before subtracting entry_start would step past the end of the shard's span.)

Member

I think you can make this change yourself now that you are a member.

auto end_ptr = d_data + d_row_ptr[ridx + 1];
Entry* previous_middle = nullptr;
auto begin_ptr = d_data.begin() + d_row_ptr[ridx] - entry_start;
auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1] - entry_start;
Member

Suggested change
auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1] - entry_start;
auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);

And this one. :)

trivialfis (Member)

Okay, I fixed a bug that was caught by my profiling script, but I can't reproduce the one on Jenkins.

hcho3 (Collaborator) commented Oct 23, 2018

Could it be an out-of-memory error? Let me run it on my end using the same instance type.

trivialfis (Member)

@hcho3 Thanks! I ran cuda-memcheck and the sanitizer with no luck so far. The test requires only a very small amount of memory.

hcho3 (Collaborator) commented Oct 23, 2018

I compiled and ran this pull request on my p2.8xlarge instance and got the same error. I will run it through cuda-gdb to see if it helps.

hcho3 (Collaborator) commented Oct 23, 2018

@trivialfis @canonizer I got this backtrace by running testxgboost through gdb:

#0  __memmove_avx_unaligned () at ../sysdeps/x86_64/multiarch/memcpy-avx-unaligned.S:136
#1  0x00007ffff5958bdf in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007ffff5a123ee in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007ffff5a126ed in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007ffff5a139b6 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00007ffff59256be in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#6  0x00007ffff59259d8 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#7  0x00007ffff5a761d5 in cuMemcpy () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#8  0x000000000088f022 in cudart::driverHelper::memcpyDispatch(void*, void const*, unsigned long, cudaMemcpyKind, bool) ()
#9  0x000000000086f896 in cudart::cudaApiMemcpy(void*, void const*, unsigned long, cudaMemcpyKind) ()
#10 0x0000000000891ed8 in cudaMemcpy ()
#11 0x0000000000779b52 in xgboost::predictor::GPUPredictor::DeviceOffsets () at /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:234
#12 0x00007ffff712d43e in ?? () from /usr/lib/x86_64-linux-gnu/libgomp.so.1
#13 0x00007ffff6cf26ba in start_thread (arg=0x7ffff25d7700) at pthread_create.c:333
#14 0x00007ffff6a2841d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109

Maybe this line is problematic?

// copy the last element from every shard
dh::safe_cuda(cudaMemcpy(&offsets[shard + 1],
                         data.DevicePointer(device) + data.DeviceSize(device) - 1,
                         sizeof(size_t), cudaMemcpyDefault));

hcho3 (Collaborator) commented Oct 23, 2018

I added some diagnostic output before the cudaMemcpy line:

[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[0 + 1], 0x7c136c0200 + 2 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[1 + 1], 0x7c13ec0200 + 2 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[2 + 1], 0x7c13ac0200 + 2 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[3 + 1], 0x7c132c0200 + 2 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[4 + 1], 0x7c142c0200 + 2 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:06] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[5 + 1], 0x7c18380000 + 1 - 1, sizeof(size_t), cudaMemcpyDefault));
[08:44:07] /home/ubuntu/xgboost/src/predictor/gpu_predictor.cu:233: cudaMemcpy(&offsets[6 + 1], 0 + 0 - 1, sizeof(size_t), cudaMemcpyDefault));

The last cudaMemcpy call fails because the device pointer for shard 6 is null.

model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;

int n_row = 5;
Collaborator


Changing this line to

Suggested change
int n_row = 5;
int n_row = 8;

gets rid of the segmentation fault.

The p2.8xlarge instance has 8 GPUs, so with a 5-row matrix some of the GPUs were getting 0 rows. We should handle this edge case, either by restricting the number of devices when too few rows are given or by correctly handling zero-row shards.
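For illustration, a minimal sketch of the second option (skipping the device copy for zero-size shards); the function below is a simplified, sequential stand-in for the actual DeviceOffsets loop, and the signature is hypothetical:

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Hypothetical, simplified offsets computation: a shard that received zero
// rows may have a null device pointer, so carry the previous offset forward
// instead of issuing a cudaMemcpy for it.
void ComputeShardOffsets(const std::vector<const size_t*>& shard_row_ptrs,
                         const std::vector<size_t>& shard_sizes,
                         std::vector<size_t>* out_offsets) {
  auto& offsets = *out_offsets;
  offsets.assign(shard_row_ptrs.size() + 1, 0);
  for (size_t shard = 0; shard < shard_row_ptrs.size(); ++shard) {
    if (shard_sizes[shard] == 0) {
      offsets[shard + 1] = offsets[shard];  // empty shard: nothing to copy
      continue;
    }
    // copy the last element of this shard's row pointer array from the device
    cudaMemcpy(&offsets[shard + 1],
               shard_row_ptrs[shard] + shard_sizes[shard] - 1,
               sizeof(size_t), cudaMemcpyDefault);
  }
}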

Collaborator

@trivialfis How should we handle this edge case?

Member

@hcho3 cuda-gdb doesn't work for me with NCCL, on either Fedora or Ubuntu. :(
GPUSet::All() has an optional parameter specifying the number of rows.

Member

Done.

* Reinitialize shards when GPUSet is changed.
* Tests range of data.
trivialfis (Member)

@hcho3 This again comes down to changing parameters. We need to handle the situation where n_gpus is limited, and hence changed, by the n_rows of the input data.

trivialfis (Member)

@canonizer, @hcho3, @RAMitchell I tried to overcome it with a check of whether the GPUSet has changed; if so, all DeviceShards are rebuilt. It might not be a nice solution, so suggestions are welcome.
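For illustration, a rough sketch of such a check; the class and member names below are hypothetical, not the actual GPUPredictor members:

#include <vector>

// Hypothetical shard management: rebuild the per-device shards only when the
// set of devices actually changes (e.g. when n_gpus gets clamped by n_rows).
struct DeviceShard {
  int device_id;
  // ... per-device prediction state ...
};

class ShardedPredictor {
 public:
  void ConfigureShards(const std::vector<int>& devices) {
    if (devices == devices_) return;  // unchanged device set: keep the shards
    devices_ = devices;
    shards_.clear();
    for (int d : devices_) {
      shards_.push_back(DeviceShard{d});  // re-initialize one shard per device
    }
  }

 private:
  std::vector<int> devices_;
  std::vector<DeviceShard> shards_;
};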

codecov-io commented Oct 23, 2018

Codecov Report

Merging #3738 into master will decrease coverage by 0.03%.
The diff coverage is 60%.


@@             Coverage Diff              @@
##             master    #3738      +/-   ##
============================================
- Coverage     52.09%   52.06%   -0.04%     
- Complexity      196      203       +7     
============================================
  Files           181      181              
  Lines         14341    14358      +17     
  Branches        489      495       +6     
============================================
+ Hits           7471     7475       +4     
- Misses         6636     6645       +9     
- Partials        234      238       +4
Impacted Files Coverage Δ Complexity Δ
src/common/span.h 98.63% <ø> (ø) 0 <0> (ø) ⬇️
src/gbm/gbtree.cc 18.67% <0%> (ø) 0 <0> (ø) ⬇️
src/objective/multiclass_obj.cu 93.75% <100%> (+0.41%) 0 <0> (ø) ⬇️
src/objective/hinge.cu 82.35% <100%> (ø) 0 <0> (ø) ⬇️
src/objective/regression_obj.cu 87.46% <100%> (ø) 0 <0> (ø) ⬇️
src/common/host_device_vector.h 75% <0%> (-5%) 0% <0%> (ø)
.../src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java 84.13% <0%> (-2.59%) 41% <0%> (+7%)
src/common/host_device_vector.cc 63.88% <0%> (-1.39%) 0% <0%> (ø)
...oost4j/scala/spark/params/LearningTaskParams.scala 81.08% <0%> (-1.28%) 0% <0%> (ø)
.../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala 74.9% <0%> (-0.86%) 0% <0%> (ø)
... and 3 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update abf2f66...96fe214.

RAMitchell (Member) left a comment

LGTM

trivialfis (Member)

@hcho3 Hi, could you take another look before I merge it?

out_gpair->Reshard(GPUSet::Empty());
preds.Reshard(GPUSet::Empty());
// out_gpair->Reshard(GPUSet::Empty());
// preds.Reshard(GPUSet::Empty());
Collaborator

Why are we commenting out lines? If they are not needed, we should just remove them.

Member

Done.

TEST(gpu_predictor, Test) {
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));

// gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "1")}, {});
Collaborator

Remove this line too.

Member

Thanks for cleaning up for me.

hcho3 (Collaborator) commented Oct 24, 2018

Thanks everyone!

@hcho3 hcho3 merged commit 2a59ff2 into dmlc:master Oct 24, 2018
CodingCat pushed a commit to CodingCat/xgboost that referenced this pull request Oct 25, 2018
* Multi-GPU support in GPUPredictor.

- GPUPredictor is multi-GPU
- removed DeviceMatrix, as it has been made obsolete by using HostDeviceVector in DMatrix

* Replaced pointers with spans in GPUPredictor.

* Added a multi-GPU predictor test.

* Fix multi-gpu test.

* Fix n_rows < n_gpus.

* Reinitialize shards when GPUSet is changed.
* Tests range of data.

* Remove commented code.

* Remove commented code.
alois-bissuel pushed a commit to criteo-forks/xgboost that referenced this pull request Dec 4, 2018
lock bot locked as resolved and limited the conversation to collaborators on Jan 22, 2019