Fix softmax block forward with small element size #14475

Merged Feb 9, 2023 · 5 commits · Changes from all commits
171 changes: 83 additions & 88 deletions onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh
@@ -1,19 +1,19 @@

/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// The code below is mostly copied from Pytorch SoftMax.cuh

@@ -23,7 +23,6 @@
namespace onnxruntime {
namespace cuda {

const int max_threads = 1024;

dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
@@ -45,33 +44,28 @@
return dim3(static_cast<unsigned int>(block_size));
}
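The body of SoftMax_getBlockSize is collapsed in this view. For orientation, a sketch of the heuristic as it appears in the PyTorch SoftMax.cuh this file is copied from (an assumption; the collapsed lines may differ slightly): pick the smallest power of two covering dim_size / ILP, clamp to max_threads, and never go below one warp.

// Sketch only; assumed to mirror PyTorch's SoftMax.cuh, not the exact collapsed code.
uint64_t block_size = 1;
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < max_block_size) block_size *= 2;
// Launch at least a single warp; the kernels below assume it.
block_size = std::max(block_size, static_cast<uint64_t>(GPU_WARP_SIZE));
// Example: dim_size = 1000, ILP = 4 -> max_block_size = 250 -> block_size = 256.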


////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////


template <typename T, typename AccumT>
struct MaxFloat {
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
return ::max(max, (AccumT)v);
}
};

template <typename T, typename AccumT>
struct AddFloat {
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + (AccumT)v;
}
};

template <typename T, typename AccumT>
struct SumExpFloat {
__device__ __forceinline__ SumExpFloat(AccumT v)
: max_k(v) {}

__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + std::exp((AccumT)v - max_k);
@@ -80,12 +74,23 @@
const AccumT max_k;
};

// One block has N (warps_per_block) warps; one warp has M (WARP_SIZE) threads.
// 1. All threads in the block write their value into shared memory.
// 2. Reduce all data to the first warp. Only the threads of warp-0 are used: each thread in warp-0 reads the value at
//    its own lane from every warp and reduces them. For example, thread-0 reduces the first element of every warp and
//    writes the result into data0's slot.
// Shared memory
// -----------------------------------------------------------------------------------------------------------------------
// | data0 | data1 | data2 | .... | dataM | ... | dataM*2 | ... |
// -----------------------------------------------------------------------------------------------------------------------
// | | | |
// -------------------warp-0----------------------------------warp-1----------------------------------warp-2--------------
// 3. Thread-0 reduces the per-lane results stored by warp-0, writes the final value into data0's slot, and all
//    threads return data0.

template <template <typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val,
const Reduction<AccumT>& r,
AccumT defaultVal) {
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();

@@ -96,19 +101,12 @@
AccumT warpVal = defaultVal;

// First warp will perform per-warp reductions for the remaining warps
if (threadIdx.x < GPU_WARP_SIZE) {
int warps_per_block = blockDim.x / GPU_WARP_SIZE;
for (int i = 0; i < warps_per_block; ++i) {
warpVal = r(warpVal, smem[i * GPU_WARP_SIZE + threadIdx.x]);
}
smem[threadIdx.x] = warpVal;
}

__syncthreads();
@@ -117,7 +115,8 @@
AccumT blockVal = defaultVal;

if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < GPU_WARP_SIZE; ++i) {
blockVal = r(blockVal, smem[i]);
}
smem[0] = blockVal;
@@ -128,29 +127,29 @@
return smem[0];
}
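As a concrete illustration of the two stages, here is a small host-side reference using max as the reduction (a hypothetical test helper, not part of this PR; it assumes the value count is a multiple of warp_size):

#include <algorithm>
#include <limits>
#include <vector>

// Stage 1: one accumulator per lane, fed by the matching lane of every warp.
// Stage 2: a single "thread-0" loop folds the warp_size accumulators together.
float BlockReduceMaxReference(const std::vector<float>& vals, int warp_size) {
  int warps = static_cast<int>(vals.size()) / warp_size;
  std::vector<float> lane_acc(warp_size, -std::numeric_limits<float>::infinity());
  for (int lane = 0; lane < warp_size; ++lane)
    for (int w = 0; w < warps; ++w)
      lane_acc[lane] = std::max(lane_acc[lane], vals[w * warp_size + lane]);
  float block_val = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < warp_size; ++i)
    block_val = std::max(block_val, lane_acc[i]);
  return block_val;
}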


template <template <typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT ilpReduce(int shift,
T* data,
int size,
const Reduction<T, AccumT>& r,
AccumT defaultVal) {
using LoadT = aligned_vector<T, ILP>;
AccumT threadVal = defaultVal;
int offset = threadIdx.x;

// shift and do 1
if (shift > 0) {
data -= shift;
size += shift;
if (threadIdx.x >= shift && threadIdx.x < size) {
threadVal = r(threadVal, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}

if (size <= 0) return threadVal;

int last = size % (ILP * blockDim.x);

T v[ILP];
@@ -176,14 +175,12 @@
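The two additions above (the threadIdx.x < size guard and the size <= 0 early return) are the heart of this fix. A worked example with hypothetical numbers showing the failure they prevent:

// Before the fix, with blockDim.x = 32, ILP = 4, size = 2, shift = 2:
//   data -= 2; size += 2;              // size = 4, valid offsets are 2..3
//   the old guard (threadIdx.x >= shift) lets threads 2..31 read data[threadIdx.x],
//   so threads 4..31 read past the end of the row.
//   size -= 32;                        // size = -28
//   last = size % (ILP * blockDim.x);  // negative, so the tail logic breaks
// After the fix, only threads 2..3 read, and the early return on size <= 0
// skips the vectorized main loop entirely for such small rows.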
/**
* This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
*/
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__device__ __forceinline__ void WriteFpropResultsVectorized(int size,
const int shift,
scalar_t* input,
outscalar_t* output,
Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
using LoadT = aligned_vector<scalar_t, ILP>;
using StoreT = aligned_vector<outscalar_t, ILP>;

@@ -195,14 +192,16 @@
output -= shift;
size += shift;

if (threadIdx.x >= shift && threadIdx.x < size) {
output[offset] = epilogue(input[offset]);
}
size -= blockDim.x;
input += blockDim.x;
output += blockDim.x;
}

if (size <= 0) return;

const int last = size % (ILP * blockDim.x);

scalar_t in_v[ILP];
@@ -229,17 +228,14 @@
}
}
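Both write paths rely on aligned_vector to turn ILP scalar accesses into one wide load/store. A minimal standalone sketch of the pattern (assuming onnxruntime's aligned_vector is the usual alignas(sizeof(T) * ILP) array wrapper, and that n is a multiple of ILP with both pointers suitably aligned):

template <int ILP, typename T>
__device__ void VectorizedCopy(T* dst, const T* src, int n) {
  using VecT = aligned_vector<T, ILP>;
  for (int i = threadIdx.x * ILP; i < n; i += blockDim.x * ILP) {
    VecT v = *reinterpret_cast<const VecT*>(src + i);  // one vector load of ILP elements
    *reinterpret_cast<VecT*>(dst + i) = v;             // one vector store of ILP elements
  }
}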


/**
* This will apply the Epilogue with non-vectorized reads & writes for the general case
*/
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__device__ __forceinline__ void WriteFpropResults(int classes,
scalar_t* input,
outscalar_t* output,
Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
int offset = threadIdx.x;

int last = classes % (ILP * blockDim.x);
@@ -265,22 +261,22 @@
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue>
__global__ void softmax_block_forward(outscalar_t* output, scalar_t* input, int classes,
int input_stride, int output_stride) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);

using LoadT = aligned_vector<scalar_t, ILP>;
using StoreT = aligned_vector<outscalar_t, ILP>;

// each block handles a sample in the mini-batch
input += blockIdx.x * input_stride;
output += blockIdx.x * output_stride;

const int input_align_bytes = ILP * sizeof(scalar_t);
const int output_align_bytes = ILP * sizeof(outscalar_t);

const int shift = ((uint64_t)input) % input_align_bytes / sizeof(scalar_t);
const int output_shift = ((uint64_t)output) % output_align_bytes / sizeof(outscalar_t);

// find the max
accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
@@ -303,24 +299,23 @@
}
}
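The switch from the fixed 16-byte ALIGN_BYTES to ILP * sizeof(scalar_t) means the shift is now measured against the width of the vector that is actually loaded. Hypothetical arithmetic:

// scalar_t = half, ILP = 4 -> input_align_bytes = 8.
// For an input pointer whose address is 8 mod 16:
//   new: shift = (8 % 8) / sizeof(half) = 0   -> no peeled prologue needed
//   old: shift = (8 % 16) / sizeof(half) = 4  -> needless peeling, and with a
//        small `classes` this is exactly the path the guards above had to fix.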

template <typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: max_input(max_input), logsum(std::log(sum)) {}

__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>((AccumT)input - max_input - logsum);
}

const AccumT max_input;
const AccumT logsum;
};

template <typename T, typename AccumT, typename OutT>
struct SoftMaxForwardEpilogue {
__device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: max_input(max_input), sum(sum) {}

__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>(std::exp((AccumT)input - max_input) / sum);
@@ -330,5 +325,5 @@
const AccumT sum;
};
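Both epilogues are the numerically stable softmax forms: SoftMaxForwardEpilogue computes exp(x - max) / sum, and LogSoftMaxForwardEpilogue computes (x - max) - log(sum). A quick numeric sanity check with made-up values:

// inputs {1, 2, 3}: max = 3, sum = e^-2 + e^-1 + e^0 ≈ 1.5032
// SoftMaxForwardEpilogue(3, 1.5032)(3)    -> exp(0) / 1.5032 ≈ 0.6652
// LogSoftMaxForwardEpilogue(3, 1.5032)(3) -> 0 - log(1.5032) ≈ -0.4076
// exp(-0.4076) ≈ 0.6652, so the two outputs are consistent.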

} // namespace cuda
} // namespace onnxruntime