
Refine accuracy_op CUDA kernel #4097

Merged

typhoonzero merged 4 commits into PaddlePaddle:develop on Sep 18, 2017

Conversation

typhoonzero (Contributor)

Fix #4096

        break;
      }
    }
  }
  *accuracy = static_cast<float>(correct) / static_cast<float>(N);
  atomicAdd(&total, count);
Contributor:

LOL. It seems too complicated to me. It looks correct, but I really don't have enough experience to review it. Please ask @hedaoyuan to review it, or rewrite it based on a high-level library such as thrust.

@dzhwinter (Contributor)

If you are interested in thrust, I will give a demo later.

  *accuracy = static_cast<float>(correct) / static_cast<float>(N);
  atomicAdd(&total, count);
  __syncthreads();
  if (threadIdx.x == 0) {
@hedaoyuan (Contributor), Sep 14, 2017:

It seems that when num_samples is greater than 4096 there will be multiple blocks, and then this result may be incorrect. Consider using only one block for the calculation.
Also, try to avoid atomicAdd: the __shared__ int total; can be replaced by __shared__ int total[block_size]; with a reduction added at the end, as in the sketch below.
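
A minimal sketch of that pattern, assuming a single-block launch and a power-of-two BlockSize (the kernel name and signature here are illustrative, not the PR's actual code):

template <int BlockSize>
__global__ void AccuracySketch(const int N, const int D, const int* Xdata,
                               const int* labeldata, float* accuracy) {
  __shared__ int total[BlockSize];
  int count = 0;
  // Each thread counts matches over a strided subset of the N samples.
  for (int i = threadIdx.x; i < N; i += BlockSize) {
    for (int j = 0; j < D; ++j) {
      if (Xdata[i * D + j] == labeldata[i]) {
        ++count;
        break;
      }
    }
  }
  total[threadIdx.x] = count;
  __syncthreads();
  // Tree reduction in shared memory; this replaces the atomicAdd.
  for (int stride = BlockSize / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      total[threadIdx.x] += total[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *accuracy = static_cast<float>(total[0]) / static_cast<float>(N);
  }
}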

typhoonzero (Contributor Author):

Thanks very much for the advice, I'll try to fix it.

typhoonzero (Contributor Author):

Done.

      const int pred = Xdata[row * D + col];
      if (pred == label) {
        ++correct;
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
Contributor:

Using template <int BlockSize> is better than using the PADDLE_CUDA_NUM_THREADS macro.

@typhoonzero (Contributor Author), Sep 15, 2017:

I would prefer PADDLE_CUDA_NUM_THREADS; it's a constexpr.

With template <int BlockSize> we have to pass BlockSize twice when calling the kernel, like SomeKernel<BlockSize><<<1, BlockSize>>>(), whereas PADDLE_CUDA_NUM_THREADS is always the same constant value.
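
For illustration only (the kernel names below are hypothetical), the two styles being compared:

// Style 1: template parameter; each call site spells the block size twice.
template <int BlockSize>
__global__ void SomeKernelA(const int* in, int n);
// launch: SomeKernelA<256><<<1, 256>>>(in, n);

// Style 2: a shared constexpr; the kernel body and the launch configuration
// both read the same PADDLE_CUDA_NUM_THREADS constant.
__global__ void SomeKernelB(const int* in, int n);
// launch: SomeKernelB<<<1, PADDLE_CUDA_NUM_THREADS>>>(in, n);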

typhoonzero (Contributor Author):

Well, I think we can open an issue to discuss this and then write some CUDA development docs.

Contributor:

Using the PADDLE_CUDA_NUM_THREADS macro also means writing it twice (once inside the kernel and once in the launch configuration).

  __shared__ int total[PADDLE_CUDA_NUM_THREADS];

  // support only 1 block
  for (int i = threadIdx.x; i < (N); i += blockDim.x * gridDim.x) {
Contributor:

i += blockDim.x * gridDim.x -> i += BlockSize
Using BlockSize is better than blockDim.x, because BlockSize is a compile-time value (see the snippet below).
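
That is, assuming a single-block launch of BlockSize threads, the loop would reduce to:

// gridDim.x == 1 for a single-block launch, so the stride is simply the
// compile-time block size.
for (int i = threadIdx.x; i < N; i += BlockSize) {
  // per-sample work ...
}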

@dzhwinter (Contributor)

I tried to write a demo to show how thrust could be used in this case. I found that it's also quite painful to write.

#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

#include <iostream>

// Maps a (prediction, label) tuple to 1 on a match, 0 otherwise.
struct IsCorrect {
  __host__ __device__ int operator()(const thrust::tuple<int, int>& t) const {
    return thrust::get<0>(t) == thrust::get<1>(t) ? 1 : 0;
  }
};

void accuracy(const int* data, const int* label, int length, float* acc) {
  *acc = 0.f;
  // Copy predictions and labels to the device.
  thrust::device_vector<int> data_device(data, data + length);
  thrust::device_vector<int> label_device(label, label + length);
  // Zip the two sequences so each element is a (prediction, label) tuple.
  auto first = thrust::make_zip_iterator(
      thrust::make_tuple(data_device.begin(), label_device.begin()));
  auto last = thrust::make_zip_iterator(
      thrust::make_tuple(data_device.end(), label_device.end()));
  // Map each tuple to 0/1 and sum the matches on the device.
  thrust::device_vector<int> correct(length);
  thrust::transform(first, last, correct.begin(), IsCorrect());
  int result = thrust::reduce(correct.begin(), correct.end());
  if (result != 0) *acc = static_cast<float>(result) / length;
}
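
A hypothetical call site, just to show the expected behavior:

int main() {
  int data[4] = {1, 2, 3, 4};
  int label[4] = {1, 0, 3, 4};
  float acc = 0.f;
  accuracy(data, label, 4, &acc);
  std::cout << "accuracy = " << acc << std::endl;  // 3 of 4 match: 0.75
  return 0;
}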

Related documentation: zip_iterator

@typhoonzero (Contributor Author)

@dzhwinter Thanks for the very useful code! Both ways are OK, with or without thrust. I also like thrust::reduce; it's very simple to use.

@hedaoyuan (Contributor) left a comment:

LGTM

@typhoonzero typhoonzero merged commit 8580dce into PaddlePaddle:develop Sep 18, 2017
@typhoonzero typhoonzero deleted the refine_accuracy_op branch December 22, 2017 05:45