[ENHANCEMENT] Insert NVTX ranges into generated CUDA code #443

Open · LeiWang1999 opened this issue Jun 30, 2022 · 0 comments
Labels: enhancement (New feature or request)

LeiWang1999 (Contributor) commented:
Motivation

I'm currently profiling the CUDA code generated by nnfusion. To make the profiles easier to interpret, I suggest using the NVIDIA Tools Extension (NVTX): by inserting nvtxRangePush at the start and nvtxRangePop at the end of each generated op function, we can see exactly which low-level kernels cuDNN invokes for that op.
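
(Note: nvtxRangePush/nvtxRangePop come from the NVTX header. I'm assuming the NVTX v2 API bundled with CUDA here, i.e. #include <nvToolsExt.h> plus linking with -lnvToolsExt; NVTX v3 is header-only via #include <nvtx3/nvToolsExt.h>.) Here is the generated convolution wrapper with the range inserted: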

```cpp
// Node name:	Convolution_109
// Description:	Convolution
// Input:
//	- name: Parameter_108_0	type: float	shape: Shape{1, 3, 224, 224}
//	- name: Constant_0_0	type: float	shape: Shape{64, 3, 7, 7}
// Output:
//	- name: Convolution_109_0	type: float	shape: Shape{1, 64, 112, 112}
void Convolution_float_float_float_cuda_lib_Convolution_109(cudnnHandle_t cudnn_handle, float* input0, float* input1, float* input2, float* output0)
{
#ifdef _NVTX_RANGE_DEBUG_
    nvtxRangePush(__FUNCTION__);
#endif
    cudnnTensorDescriptor_t tensor_desc_0;
    CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&tensor_desc_0));
    CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(tensor_desc_0, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 3, 224, 224));
    cudnnTensorDescriptor_t tensor_desc_1;
    CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&tensor_desc_1));
    CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(tensor_desc_1, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 64, 112, 112));
    cudnnFilterDescriptor_t filter_desc;
    CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(&filter_desc));
    CUDNN_SAFE_CALL(cudnnSetFilter4dDescriptor(filter_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 64, 3, 7, 7));
    cudnnConvolutionDescriptor_t conv_desc;
    CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
    CUDNN_SAFE_CALL(cudnnSetConvolution2dDescriptor(conv_desc, 3, 3, 2, 2, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));

    static bool selected_algo = false;
    static cudnnConvolutionFwdAlgo_t conv_fwd_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;

    if (!selected_algo) {
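        // First call only: benchmark the available forward algorithms and
        // cache the first successful (i.e. fastest) result in the statics above.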
        int num_algos;
        int max_algos = 0;
        // Alternative: cudnnGetConvolutionForwardAlgorithm_v7 returns
        // heuristically ranked algorithms without benchmarking them.
        CUDNN_SAFE_CALL(
            cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
        std::vector<cudnnConvolutionFwdAlgoPerf_t> results(max_algos);
        CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm(cudnn_handle,
                                                tensor_desc_0,
                                                filter_desc,
                                                conv_desc,
                                                tensor_desc_1,
                                                static_cast<int>(results.size()),
                                                &num_algos,
                                                results.data()));
        results.resize(num_algos);
        for (size_t i = 0; i != results.size(); ++i) {
            cudnnConvolutionFwdAlgoPerf_t const& result = results[i];
            if (result.status == CUDNN_STATUS_SUCCESS) {
                conv_fwd_algo = result.algo;
                break;
            }
        }
        selected_algo = true;
    }
    const float alpha = 1.0;
    const float beta = 0.0;  // only used by the plain cudnnConvolutionForward path (commented out below)
    static void *workspace_ptr_0 = NULL;
    static size_t workspace_size_in_bytes = 0;
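    // Allocate the cuDNN workspace once on the first call and reuse it on
    // later calls (it is never freed).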
    if (!workspace_ptr_0)
    {
        CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle, tensor_desc_0, filter_desc, conv_desc, tensor_desc_1, conv_fwd_algo, &workspace_size_in_bytes));
        CUDA_SAFE_CALL(cudaMalloc(&workspace_ptr_0, workspace_size_in_bytes));
    }
    const float alpha2 = 0.0;  // zero alpha2 makes the fused call below ignore its z input

    cudnnTensorDescriptor_t bias_desc;
    CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc));
    CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 64, 1, 1));

    cudnnActivationDescriptor_t relu_desc;
    CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&relu_desc));
    CUDNN_SAFE_CALL(cudnnSetActivationDescriptor(relu_desc, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));

    CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward(cudnn_handle, &alpha, tensor_desc_0, input0, filter_desc, input1, conv_desc, conv_fwd_algo, workspace_ptr_0, workspace_size_in_bytes, &alpha2, tensor_desc_1, output0, bias_desc, input2, relu_desc, tensor_desc_1, output0));
    
    // CUDNN_SAFE_CALL(cudnnConvolutionForward(cudnn_handle, &alpha, tensor_desc_0, input0,filter_desc, input1, conv_desc, conv_fwd_algo, workspace_ptr_0, workspace_size_in_bytes, &beta, tensor_desc_1, output0));
    CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc));
    CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(relu_desc));
    CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(tensor_desc_0));
    CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(tensor_desc_1));
    CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(filter_desc));
    CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
#ifdef _NVTX_RANGE_DEBUG_
    nvtxRangePop();
#endif
}
```
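
One possible refinement (my sketch, not part of nnfusion's current codegen): emit an RAII guard instead of a bare push/pop pair, so the range is closed on every exit path even if the generated function later grows early returns. The guard class and the example function here are hypothetical:

```cpp
#include <nvToolsExt.h> // NVTX v2 API; link with -lnvToolsExt

// Pushes an NVTX range on construction and pops it on destruction,
// so every return path of the annotated function closes the range.
struct NvtxRangeGuard
{
    explicit NvtxRangeGuard(const char* name) { nvtxRangePush(name); }
    ~NvtxRangeGuard() { nvtxRangePop(); }
    NvtxRangeGuard(const NvtxRangeGuard&) = delete;
    NvtxRangeGuard& operator=(const NvtxRangeGuard&) = delete;
};

// Illustrative usage in a generated function body:
void Convolution_example()
{
#ifdef _NVTX_RANGE_DEBUG_
    NvtxRangeGuard nvtx_range(__FUNCTION__);
#endif
    // ... generated cuDNN calls as above ...
}
```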

With the range in place, the profiler groups the kernels invoked by this op function under its NVTX range:

```
==13690== NVTX result:
==13690==   Thread "<unnamed>" (id = 2074996736)
==13690==     Domain "<unnamed>"
==13690==       Range "Convolution_float_float_float_cuda_lib_Convolution_109"
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
          Range:  100.00%  897.61ms       105  8.5487ms  47.299us  891.85ms  Convolution_float_float_float_cuda_lib_Convolution_109
 GPU activities:   68.65%  5.0048ms       107  46.773us  46.271us  49.568us  void implicit_convolve_sgemm<float, float, int=128, int=5, int=5, int=3, int=3, int=3, int=1, bool=0, bool=1, bool=1>(int, int, int, float const *, int, float*, float const *, kernel_conv_params, __int64, int, float, float, int, float const *, float const *, bool, int, int)
                   12.93%  942.42us        81  11.634us  11.168us  13.152us  void gemv2N_kernel<int, int, float2, float2, float2, int=128, int=8, int=4, int=4, int=1, bool=0, cublasGemvParams<cublasGemvTensorStridedBatched<float2 const >, cublasGemvTensorStridedBatched<float2>, float2>>(float2 const )
                   10.88%  793.05us        81  9.7900us  9.1840us  14.688us  void fft2d_c2r_32x32<float, bool=1, bool=0, unsigned int=0, bool=0, bool=0>(float*, float2 const *, int, int, int, int, int, int, int, int, int, float, float, cudnn::reduced_divisor, bool, float*, float*, int2, int, int)
                    5.73%  418.05us        81  5.1610us  4.6080us  14.080us  void fft2d_r2c_32x32<float, bool=0, unsigned int=0, bool=0>(float2*, float const *, int, int, int, int, int, int, int, int, int, cudnn::reduced_divisor, bool, int2, int, int)
                    0.59%  42.752us         1  42.752us  42.752us  42.752us  void explicit_convolve_sgemm<float, int, int=1024, int=5, int=5, int=3, int=3, int=3, int=0, bool=0>(int, int, int, float const *, int, float const *, int, float*, kernel_conv_params, __int64, int, __int64, int, float, float, int, float const *, float const *)
                    0.53%  38.720us         1  38.720us  38.720us  38.720us  volta_scudnn_128x32_relu_medium_nn_v1
                    0.39%  28.511us         1  28.511us  28.511us  28.511us  void cudnn::cnn::im2col4d_kernel<float, long>(cudnn::cnn::im2col4d_params, cudnnConvolutionStruct, cudnnTensor4dStruct, float const *, cudnnTensor4dStruct*)
                    0.23%  16.448us         1  16.448us  16.448us  16.448us  void fft2d_r2c_32x32<float, bool=0, unsigned int=1, bool=1>(float2*, float const *, int, int, int, int, int, int, int, int, int, cudnn::reduced_divisor, bool, int2, int, int)
                    0.07%  5.4400us         1  5.4400us  5.4400us  5.4400us  cask_cudnn::computeOffsetsKernel(cask_cudnn::ComputeOffsetsParams)
      API calls:  100.00%  2.4452ms       355  6.8880us  4.0000us  57.998us  cudaLaunchKernel
```
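
The summary above is nvprof output. On newer CUDA toolkits, where nvprof is deprecated, Nsight Systems reports the same NVTX ranges, e.g. `nsys profile --stats=true ./app` (illustrative command; `./app` stands for the generated binary).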