Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine accuracy_op CUDA kernel #4097

Merged
merged 4 commits into from
Sep 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions paddle/operators/accuracy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;

__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
const int* Xdata, const int* labelData,
float* accuracy) {
int correct = 0;
for (int row = 0; row < N; row++) {
const int label = labelData[row];
for (int col = 0; col < D; col++) {
const int pred = Xdata[row * D + col];
if (pred == label) {
++correct;
template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
const int* labeldata, float* accuracy) {
int count = 0;
__shared__ int total[BlockSize];

// support only 1 block
for (int i = threadIdx.x; i < (N); i += BlockSize) {
for (int j = 0; j < D; ++j) {
if (Xdata[i * D + j] == labeldata[i]) {
++count;
break;
}
}
}
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
total[threadIdx.x] = count;
__syncthreads();

// reduce the count with init value 0, and output accuracy.
int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
if (threadIdx.x == 0) {
*accuracy = static_cast<float>(result) / static_cast<float>(N);
}
}

template <typename T>
Expand All @@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}

AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
label_data, accuracy_data);
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
num_samples, infer_width, inference_data, label_data, accuracy_data);
}
};

Expand Down
5 changes: 5 additions & 0 deletions paddle/platform/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ namespace platform {
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// Default thread count per block(or block size).
// TODO(typhoonzero): need to benchmark against setting this value
// to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;

// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);

Expand Down
9 changes: 5 additions & 4 deletions python/paddle/v2/framework/tests/test_accuracy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
class TestAccuracyOp(OpTest):
def setUp(self):
self.op_type = "accuracy"
infer = np.random.randint(0, 2, (32, 1)).astype("int")
label = np.random.randint(0, 2, (32, )).astype("int")
n = 8192
infer = np.random.randint(0, 2, (n, 1)).astype("int")
label = np.random.randint(0, 2, (n, )).astype("int")
self.inputs = {'Inference': infer, "Label": label}
num_correct = 0
for rowid in xrange(32):
for rowid in xrange(n):
for ele in infer[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {'Accuracy': [num_correct / 32.0]}
self.outputs = {'Accuracy': [num_correct / float(n)]}

def test_check_output(self):
self.check_output()
Expand Down