Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a batch norm inference kernel. #3309

Merged
merged 6 commits into from
Aug 7, 2017

Conversation

qingqing01
Copy link
Contributor

Fix #929

@qingqing01
Copy link
Contributor Author

qingqing01 commented Aug 7, 2017

cudnn lib有bug,在cudnn 5.1上 n > 1024时出错,可以使用下面代码验证: (Translation: the cuDNN library has a bug — on cuDNN 5.1 it fails when n > 1024; this can be verified with the following code:)

#include <cuda.h>
#include <cudnn.h>
#include <iostream>
#include <sstream>
#include <fstream>

#include <stdio.h>

// Stringification helpers: TOSTR_ turns its argument into a string literal;
// TOSTR macro-expands the argument first, then stringifies the expansion.
#define TOSTR_(s)   #s
#define TOSTR(s)    TOSTR_(s)
// Human-readable cuDNN version string (e.g. "5.1.10") assembled from the
// version macros defined in cudnn.h.
#define CUDNN_VERSION_STR  TOSTR(CUDNN_MAJOR) "." TOSTR (CUDNN_MINOR) "." TOSTR(CUDNN_PATCHLEVEL)

// Report a fatal error: print the message together with the file:line of the
// failure, reset the CUDA device, and terminate the process.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// is safe to use inside an unbraced if/else (the original bare-brace form
// produced `{...};`, which breaks `if (x) FatalError(s); else ...`).
#define FatalError(s) do {                                             \
    std::stringstream _where, _message;                                \
    _where << __FILE__ << ':' << __LINE__;                             \
    _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__;\
    std::cerr << _message.str() << "\nAborting...\n";                  \
    cudaDeviceReset();                                                 \
    exit(EXIT_FAILURE);                                                \
} while (0)

// Abort with a descriptive message if a cuDNN call did not succeed.
// The argument is evaluated exactly once into a local: the original expanded
// `status` twice, so on failure the cuDNN call itself was re-executed inside
// cudnnGetErrorString(status). do/while(0) makes the macro a single statement.
#define checkCUDNN(status) do {                                        \
    cudnnStatus_t _status = (status);                                  \
    if (_status != CUDNN_STATUS_SUCCESS) {                             \
      std::stringstream _error;                                        \
      _error << "CUDNN failure\nError: " << cudnnGetErrorString(_status); \
      FatalError(_error.str());                                        \
    }                                                                  \
} while (0)

// Abort with a descriptive message if a CUDA runtime call did not succeed.
// The argument is evaluated exactly once into a local: the original expanded
// `status` twice, so on failure the runtime call was re-executed inside
// cudaGetErrorString(status). do/while(0) makes the macro a single statement.
#define checkCUDA(status) do {                                         \
    cudaError_t _status = (status);                                    \
    if (_status != cudaSuccess) {                                      \
      std::stringstream _error;                                        \
      _error << "Cuda failure\nError: " << cudaGetErrorString(_status); \
      FatalError(_error.str());                                        \
    }                                                                  \
} while (0)



#include <sys/time.h>
#include <unistd.h>


void create(float** h_v, float** d_v, int n) {
  *h_v = (float *)malloc(n * sizeof(float));
  checkCUDA(cudaMalloc(d_v, n * sizeof(float)));
  for(int i = 0; i < n; i++)
    (*h_v)[i] = 1.0f;
  checkCUDA(cudaMemcpy(*d_v, *h_v, n * sizeof(float), cudaMemcpyHostToDevice));
}

// Standalone repro for a cuDNN 5.1 limitation: spatial batch-norm inference
// fails when the batch size n exceeds 1024. Runs
// cudnnBatchNormalizationForwardInference on an all-ones input with n = 1025
// and relies on the checkCUDNN/checkCUDA macros to abort on any failure.
int main(int argc, char *argv[]) {

  int version = (int)cudnnGetVersion();
  printf("cudnnGetVersion() : %d , CUDNN_VERSION from cudnn.h : %d (%s)\n",
      version, CUDNN_VERSION, CUDNN_VERSION_STR);
  // Fix: check the return value — the file's convention wraps every CUDA
  // runtime call, and cudaSetDevice was previously unchecked.
  checkCUDA(cudaSetDevice(0));

  /* Input dimensions: n > 1024 triggers the cuDNN 5.1 failure. */
  int n, c, h, w;
  n = 1025;
  c = 512;
  h = 1;
  w = 1;

  /* Handles */
  cudnnHandle_t cudnnHandle;
  cudnnTensorDescriptor_t ioDesc, bnDesc;

  /* Create handles and descriptors */
  checkCUDNN( cudnnCreate(&cudnnHandle));
  checkCUDNN( cudnnCreateTensorDescriptor(&ioDesc));
  checkCUDNN( cudnnCreateTensorDescriptor(&bnDesc));

  /* Constant configuration: float data, spatial (per-channel) batch norm. */
  cudnnDataType_t dataType = CUDNN_DATA_FLOAT;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;

  /* Initialize input/output buffers and per-channel BN parameters
   * (all filled with 1.0f by create()). */
  float* h_input;
  float* d_input;
  float* h_scale;
  float* d_scale;
  float* h_bias;
  float* d_bias;
  float* h_estimated_mean;
  float* d_estimated_mean;
  float* h_estimated_var;
  float* d_estimated_var;

  float* h_output;
  float* d_output;

  create(&h_input, &d_input, n * c * h * w);
  create(&h_output, &d_output, n * c * h * w);
  create(&h_scale, &d_scale, c);
  create(&h_bias, &d_bias, c);
  create(&h_estimated_mean, &d_estimated_mean, c);
  create(&h_estimated_var, &d_estimated_var, c);

  /* Fully-packed NCHW strides for the explicit-stride descriptor call. */
  const int stride_w = 1;
  const int stride_h = w * stride_w;
  const int stride_c = h * stride_h;
  const int stride_n = c * stride_c;

  printf("set cudnn tensor\n");
  checkCUDNN(cudnnSetTensor4dDescriptorEx(ioDesc, dataType, n, c,
      h, w, stride_n, stride_c, stride_h, stride_w));
  /* BN parameter tensor is 1 x C x 1 x 1 for spatial mode. */
  checkCUDNN(cudnnSetTensor4dDescriptorEx(bnDesc, dataType, 1, c,
      1, 1, c, 1, 1, 1));

  /* y = alpha * BN(x) + beta * y, i.e. plain overwrite of the output. */
  float alpha, beta;
  alpha = 1.0f;
  beta = 0.0f;
  double epsilon = 1E-5;

  checkCUDNN(cudnnBatchNormalizationForwardInference(cudnnHandle,
                                          mode,
                                          &alpha,
                                          &beta,
                                          ioDesc,
                                          d_input,
                                          ioDesc,
                                          d_output,
                                          bnDesc,
                                          d_scale,
                                          d_bias,
                                          d_estimated_mean,
                                          d_estimated_var,
                                          epsilon));
  /* Blocking copy back to the host also synchronizes, so any asynchronous
   * kernel failure surfaces here. */
  checkCUDA(cudaMemcpy(h_output, d_output, (n * c * h * w) * sizeof(float),
      cudaMemcpyDeviceToHost));

  /* Release host and device buffers. */
  free(h_input);
  free(h_output);
  free(h_scale);
  free(h_bias);
  free(h_estimated_mean);
  free(h_estimated_var);
  checkCUDA(cudaFree(d_input));
  checkCUDA(cudaFree(d_output));
  checkCUDA(cudaFree(d_scale));
  checkCUDA(cudaFree(d_bias));
  checkCUDA(cudaFree(d_estimated_mean));
  checkCUDA(cudaFree(d_estimated_var));

  /* Destroy handles */
  checkCUDNN( cudnnDestroyTensorDescriptor(ioDesc) );
  checkCUDNN( cudnnDestroyTensorDescriptor(bnDesc) );
  checkCUDNN( cudnnDestroy(cudnnHandle) );
  return 0;
}

size_t height,
size_t width) {
dim3 block(256, 1);
dim3 grid(1, batchSize);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grid(batchSize, 1) is better. The maximum x-dimension of a grid is 2^31 - 1, while the maximum y-dimension is only 65535.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

size_t channel,
size_t height,
size_t width) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blockIdx.x * blockDim.x can be removed, since blockIdx.x is always equal to 0 here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const int num = channel * height * width;
const int batch = blockIdx.y;
for (int i = tid; i < num; i += blockDim.x) {
const int c = (i / (height * width)) % channel;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The % channel can be removed, since i / (height * width) is always smaller than channel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

movingVar,
EPS);
if (batchSize > 1024) {
// there is a bug in cudnn library when the batch size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some places say this is a limitation of CUDNN, not bug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modify the comments.

Copy link
Contributor

@hedaoyuan hedaoyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wangkuiyi wangkuiyi merged commit 81c3136 into PaddlePaddle:develop Aug 7, 2017
@qingqing01 qingqing01 deleted the bn_infer branch March 7, 2018 12:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants