Hi, the softmax backward pass with mode SOFTMAX_LOG is wrong; please double-check the formula.
My CUDA implementation for this mode is below; it can serve as a reference for the formula.
    __global__ void SoftmaxBwdPerInstance(const float* top_diff, const float* top_dat,
                                          float* bot_diff, const int channel,
                                          const int height, const int width,
                                          const int spatial_dim)
    {
        // Partial sums, one per thread. The tree reduction below assumes
        // blockDim.x is a power of two and at most 256.
        __shared__ float mid_result[256];

        int batch_id = blockIdx.x;
        int tid      = threadIdx.x;

        // Each thread block handles one batch instance; each thread strides
        // over the channels and accumulates the sum of top_diff.
        float channel_dot = 0.f;
        for(int i = tid; i < channel; i += blockDim.x)
        {
            for(int h = 0; h < height; h++)
            {
                for(int w = 0; w < width; w++)
                {
                    float value = top_diff[(batch_id * channel + i) * spatial_dim + h * width + w];
                    // float_sum is a helper that is not shown in this issue.
                    channel_dot = float_sum(channel_dot, value);
                }
            }
        }

        mid_result[tid] = channel_dot;
        __syncthreads();

        // Tree reduction of the per-thread partial sums in shared memory.
        for(int i = (blockDim.x >> 1); i > 0; i >>= 1)
        {
            if(tid < i)
            {
                mid_result[tid] += mid_result[tid + i];
            }
            __syncthreads();
        }
        channel_dot = mid_result[0];

        // bot_diff = top_diff - sum(top_diff) * exp(top_dat),
        // where top_dat holds the log-softmax output.
        for(int i = tid; i < channel; i += blockDim.x)
        {
            for(int h = 0; h < height; h++)
            {
                for(int w = 0; w < width; w++)
                {
                    int s   = h * width + w;
                    int idx = (batch_id * channel + i) * spatial_dim + s;
                    bot_diff[idx] = top_diff[idx] - channel_dot * expf(top_dat[idx]);
                }
            }
        }
    }
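For anyone checking the formula: with $y_i = x_i - \log \sum_j e^{x_j}$ (the log-softmax output, `top_dat` in the kernel) and upstream gradient $dy$ (`top_diff`), the chain rule gives

$$dx_i = dy_i - e^{y_i} \sum_j dy_j$$

A minimal host-side sketch of that formula for a single instance follows. It is plain C++ and should also compile as CUDA host code. The function name `log_softmax_backward_ref` is hypothetical and not part of any library. It assumes the sum runs over every element of the instance, as the kernel above does.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference for the backward pass of log-softmax over one instance.
//   dy: gradient w.r.t. the log-softmax output (top_diff in the kernel)
//   y : the log-softmax output itself          (top_dat in the kernel)
// Returns dx[i] = dy[i] - exp(y[i]) * sum_j dy[j].
// The name log_softmax_backward_ref is hypothetical, not a library function.
std::vector<float> log_softmax_backward_ref(const std::vector<float>& dy,
                                            const std::vector<float>& y)
{
    assert(dy.size() == y.size());

    // Accumulate in double to limit rounding error in the reference.
    double dy_sum = 0.0;
    for(float v : dy)
    {
        dy_sum += v;
    }

    std::vector<float> dx(dy.size());
    for(std::size_t i = 0; i < dy.size(); ++i)
    {
        dx[i] = static_cast<float>(dy[i] - dy_sum * std::exp(static_cast<double>(y[i])));
    }
    return dx;
}
```

To compare against a GPU result, copy `bot_diff` for one instance back to the host and check it element-wise against this function with a small tolerance. A quick sanity check: because the softmax probabilities sum to one, the entries of `dx` should sum to approximately zero.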
Thank you for your feedback! We will address it in the next release.
Resolved in 2.1.0
ce1adon