Support broadcast in fused_softmax kernel #8321

Merged · 14 commits · Jun 20, 2022
4 changes: 2 additions & 2 deletions oneflow/core/cuda/softmax.cuh
@@ -712,7 +712,7 @@ template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
-  if (cols <= 1024) {
+  if (cols < 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
} else {
@@ -1288,7 +1288,7 @@ template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
-  if (cols <= 1024) {
Contributor: This line probably doesn't need to change — I remember that changing the backward pass made it slightly slower than before.

a4ee3b1

I only changed the forward pass here. This is the measured result.

+  if (cols < 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols);
} else {
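The only behavioral change in softmax.cuh is the boundary case cols == 1024: the forward pass now routes it to the block-level implementation instead of the warp-level one (one warp per row, with the row held in registers). A minimal standalone sketch of the before/after dispatch decision — kWarp and kBlock are illustrative labels, not the real kernel entry points:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

enum class Impl { kWarp, kBlock };

// Before this PR, cols <= 1024 selected the warp kernel; after it, cols == 1024
// moves to the block kernel, which the author reports measured faster for the
// forward pass.
Impl SelectBefore(int64_t cols) { return cols <= 1024 ? Impl::kWarp : Impl::kBlock; }
Impl SelectAfter(int64_t cols) { return cols < 1024 ? Impl::kWarp : Impl::kBlock; }

int main() {
  for (int64_t cols : {512, 1024, 2048}) {
    std::printf("cols=%4lld  before=%s  after=%s\n", (long long)cols,
                SelectBefore(cols) == Impl::kWarp ? "warp" : "block",
                SelectAfter(cols) == Impl::kWarp ? "warp" : "block");
  }
  return 0;
}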
258 changes: 169 additions & 89 deletions oneflow/user/kernels/fused_scale_mask_softmax.cu
@@ -16,64 +16,88 @@ limitations under the License.
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/cuda/softmax.cuh"
#include "oneflow/core/ep/cuda/cuda_stream.h"

#include "oneflow/user/kernels/fused_scale_mask_softmax.cuh"
namespace oneflow {

template<typename SRC, typename DST>
struct ScaleMaskLoad {
ScaleMaskLoad(const SRC* src, const bool* mask, int64_t row_size, SRC fill, SRC scale)
: src(src), mask(mask), row_size(row_size), fill(fill), scale(scale) {}
template<int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
cuda::softmax::Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset);
cuda::softmax::Pack<bool, N> mask_pack;
mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
dst[i] = static_cast<DST>(fill);
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
}
}
}
const SRC* src;
const bool* mask;
int64_t row_size;
SRC fill;
SRC scale;
};
namespace {

template<typename SRC, typename DST>
struct ScaleMaskStore {
ScaleMaskStore(DST* dst, const bool* mask, int64_t row_size, DST fill, DST scale)
: dst(dst), mask(mask), row_size(row_size), fill(fill), scale(scale) {}
template<int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
cuda::softmax::Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
cuda::softmax::Pack<bool, N> mask_pack;
mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
pack.elem[i] = fill;
} else {
pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(scale);
}
}
*(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
const bool* mask;
int64_t row_size;
DST fill;
DST scale;
};
template<typename T, typename ComputeType, typename MASK, size_t num_dims>
void LaunchBroadcastForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask,
const int64_t elem_cnt, const int64_t rows, const int64_t cols,
const float fill, const float scale, const int64_t* input_dims,
const int64_t* mask_dims) {
NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims);
NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims);
fused_scale_mask_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;
params.src_index_helper = input_index_helper;
params.mask_index_helper = mask_index_helper;
params.mask_dims = mask_dims;
params.row_size = cols;
params.fill = fill;
params.scale = scale;
fused_scale_mask_softmax::BroadcastScaleMaskLoad<T, ComputeType, MASK, num_dims, int32_t> load(
x, mask, params);
cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
stream, load, store, rows, cols)));
}
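// How the broadcast load finds its mask element (a hedged reading of
// BroadcastScaleMaskLoad): the input's linear offset is decomposed into an ND
// index via src_index_helper, every coordinate where mask_dims[i] == 1 is
// clamped to 0, and mask_index_helper re-linearizes the result. Worked example
// with input dims (2, 4, 8, 8) and mask dims (2, 1, 8, 8): input offset 100
// -> index (0, 1, 4, 4) -> mask index (0, 0, 4, 4) -> mask offset 36, shared
// by all four heads at that (batch, row, col) position.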

-template<typename T>
+template<typename T, typename ComputeType, typename MASK>
void LaunchElementwiseForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask,
const int64_t rows, const int64_t cols, const float fill,
const float scale) {
oneflow::fused_scale_mask_softmax::ElementwiseMaskSoftmaxParams params;
params.row_size = cols;
params.fill = fill;
params.scale = scale;
fused_scale_mask_softmax::ElementwiseScaleMaskLoad<T, ComputeType, MASK> load(x, mask, params);
cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
stream, load, store, rows, cols)));
}
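// Design note: when the mask shape matches the input shape exactly, the
// broadcast dims simplify to a single dimension and this elementwise path is
// taken; the mask can then be indexed with the same offset as the input,
// skipping the ND index arithmetic of the broadcast path above.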

template<typename T, typename ComputeType, typename MASK, size_t num_dims>
void LaunchBroadcastBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx,
const MASK* mask, const int64_t elem_cnt, const int64_t rows,
const int64_t cols, const float fill, const float scale,
const int64_t* input_dims, const int64_t* mask_dims) {
NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims);
NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims);
fused_scale_mask_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;
params.src_index_helper = input_index_helper;
params.mask_index_helper = mask_index_helper;
params.mask_dims = mask_dims;
params.row_size = cols;
params.fill = fill;
params.scale = scale;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
fused_scale_mask_softmax::BroadcastScaleMaskStore<ComputeType, T, MASK, num_dims, int32_t> store(
dx, mask, params);
OF_CUDA_CHECK((
cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), decltype(store),
ComputeType>(stream, load_y, load_dy, store, rows, cols)));
}

template<typename T, typename ComputeType, typename MASK>
void LaunchElementwiseBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx,
const MASK* mask, const int64_t rows, const int64_t cols,
const float fill, const float scale) {
fused_scale_mask_softmax::ElementwiseMaskSoftmaxParams params;
params.row_size = cols;
params.fill = fill;
params.scale = scale;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
fused_scale_mask_softmax::ElementwiseScaleMaskStore<ComputeType, T, MASK> store(dx, mask, params);
OF_CUDA_CHECK((
cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), decltype(store),
ComputeType>(stream, load_y, load_dy, store, rows, cols)));
}
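// Note on the backward fill value: the gradient of the fused op is
//   dx_i = mask_i ? scale * y_i * (dy_i - sum_j dy_j * y_j) : 0,
// since positions replaced by mask_fill_value in the forward pass receive no
// gradient. That is why the grad kernels here pass fill = 0 to the mask store,
// independent of the forward op's mask_fill_value attribute.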

constexpr int32_t kMaxNumDims = 5;

template<typename T, typename MASK>
class FusedScaleMaskSoftmaxKernel final : public user_op::OpKernel {
public:
FusedScaleMaskSoftmaxKernel() = default;
@@ -85,33 +85,49 @@ class FusedScaleMaskSoftmaxKernel final : public user_op::OpKernel {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const float mask_fill_value = ctx->Attr<float>("mask_fill_value");
const float scale_value = ctx->Attr<float>("scale_value");
const ShapeView& x_shape = x->shape();
const ShapeView& mask_shape = mask->shape();
CHECK_GE(x_shape.NumAxes(), 2);
const int64_t elem_cnt = x_shape.elem_cnt();
const int64_t cols = x_shape.At(x_shape.NumAxes() - 1);
const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1);
const size_t num_input_dims = x_shape.NumAxes();
const int64_t* input_dims = x_shape.ptr();
const size_t num_mask_dims = mask_shape.NumAxes();
const int64_t* mask_dims = mask_shape.ptr();
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
ScaleMaskLoad<T, ComputeType> load(x->dptr<T>(), mask->dptr<bool>(), cols,
ctx->Attr<float>("mask_fill_value"),
ctx->Attr<float>("scale_value"));
cuda::softmax::DirectStore<ComputeType, T> store(y->mut_dptr<T>(), cols);
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), load, store, rows, cols)));

size_t simplified_num_dims = 0;
int64_t simplified_input_dims[kMaxNumDims];
int64_t simplified_mask_dims[kMaxNumDims];
fused_scale_mask_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims,
mask_dims, &simplified_num_dims,
simplified_input_dims, simplified_mask_dims);
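    // Hedged example of the simplification (assuming standard dimension-merging
    // semantics): x (16, 8, 128, 128) with mask (16, 1, 128, 128) merges the two
    // trailing matching dims into (16, 8, 16384) / (16, 1, 16384), giving
    // simplified_num_dims == 3; if mask and x shapes are identical, everything
    // merges and simplified_num_dims == 1, taking the elementwise path below.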
if (simplified_num_dims == 1) {
LaunchElementwiseForwardKernel<T, ComputeType, MASK>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(),
mask->dptr<MASK>(), rows, cols, mask_fill_value, scale_value);
}
#define DEFINE_ONE_ELIF(dims) \
else if (simplified_num_dims == dims) { \
LaunchBroadcastForwardKernel<T, ComputeType, MASK, dims>( \
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(), \
mask->dptr<MASK>(), elem_cnt, rows, cols, mask_fill_value, scale_value, \
simplified_input_dims, simplified_mask_dims); \
}
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
#undef DEFINE_ONE_ELIF
else {
UNIMPLEMENTED();
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_FUCED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("fused_scale_mask_softmax") \
.SetCreateFn<FusedScaleMaskSoftmaxKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));

REGISTER_FUCED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(half)
REGISTER_FUCED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(float)
REGISTER_FUCED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(double)
#undef REGISTER_FUCED_SCALE_MASK_SOFTMAX_CUDA_KERNEL

-template<typename T>
+template<typename T, typename MASK>
class FusedScaleMaskSoftmaxGradKernel final : public user_op::OpKernel {
public:
FusedScaleMaskSoftmaxGradKernel() = default;
@@ -124,31 +124,71 @@ class FusedScaleMaskSoftmaxGradKernel final : public user_op::OpKernel {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0);
user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
const float scale_value = ctx->Attr<float>("scale_value");
const float mask_fill_value = static_cast<float>(0.0);
const ShapeView& dy_shape = dy->shape();
const ShapeView& mask_shape = mask->shape();
CHECK_GE(dy_shape.NumAxes(), 2);
const int64_t elem_cnt = dy_shape.elem_cnt();
const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1);
const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1);
const int64_t* input_dims = dy_shape.ptr();
const size_t num_input_dims = dy_shape.NumAxes();
const int64_t* mask_dims = mask_shape.ptr();
const size_t num_mask_dims = mask_shape.NumAxes();

using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y->dptr<T>(), cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy->dptr<T>(), cols);
ScaleMaskStore<ComputeType, T> store(dx->mut_dptr<T>(), mask->dptr<bool>(), cols,
static_cast<T>(0.0), ctx->Attr<float>("scale_value"));
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), load_y, load_dy, store, rows, cols)));

size_t simplified_num_dims = 0;
int64_t simplified_input_dims[kMaxNumDims];
int64_t simplified_mask_dims[kMaxNumDims];
fused_scale_mask_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims,
mask_dims, &simplified_num_dims,
simplified_input_dims, simplified_mask_dims);
if (simplified_num_dims == 1) {
LaunchElementwiseBackwardKernel<T, ComputeType, MASK>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), y->dptr<T>(), dy->dptr<T>(),
dx->mut_dptr<T>(), mask->dptr<MASK>(), rows, cols, mask_fill_value, scale_value);
}
#define DEFINE_ONE_ELIF(dims) \
else if (simplified_num_dims == dims) { \
LaunchBroadcastBackwardKernel<T, ComputeType, MASK, dims>( \
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), y->dptr<T>(), dy->dptr<T>(), \
dx->mut_dptr<T>(), mask->dptr<MASK>(), elem_cnt, rows, cols, mask_fill_value, scale_value, \
simplified_input_dims, simplified_mask_dims); \
}
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
#undef DEFINE_ONE_ELIF
else {
UNIMPLEMENTED();
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_FUCED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(dtype) \
REGISTER_USER_KERNEL("fused_scale_mask_softmax_grad") \
.SetCreateFn<FusedScaleMaskSoftmaxGradKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
} // namespace

#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(dtype, mask_dtype) \
REGISTER_USER_KERNEL("fused_scale_mask_softmax") \
.SetCreateFn<FusedScaleMaskSoftmaxKernel<dtype, mask_dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value) \
&& (user_op::HobDataType("mask", 0) == GetDataType<mask_dtype>::value));

REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(half, bool)
REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(float, bool)
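// Hypothetical extension (not part of this PR): since the kernel is now also
// templated on the mask dtype, another mask type would only need one more
// instantiation here, assuming the op definition allowed it, e.g.
// REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(float, uint8_t).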
#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL

#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(dtype, mask_dtype) \
REGISTER_USER_KERNEL("fused_scale_mask_softmax_grad") \
.SetCreateFn<FusedScaleMaskSoftmaxGradKernel<dtype, mask_dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value) \
&& (user_op::HobDataType("mask", 0) == GetDataType<mask_dtype>::value));

REGISTER_FUCED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(half)
REGISTER_FUCED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(float)
REGISTER_FUCED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(double)
#undef REGISTER_FUCED_SCALE_MASK_SOFTMAX_GRAD_KERNEL
REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(half, bool)
REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(float, bool)
#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL

} // namespace oneflow
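For reference, a CPU sketch of the forward semantics assembled from the load/store functors above — one row of y = softmax(mask ? x * scale : mask_fill_value). The scale and fill values below are illustrative choices, not defaults taken from the op definition:

#include <cmath>
#include <cstdio>

void FusedScaleMaskSoftmaxRow(const float* x, const bool* mask, float scale,
                              float fill, int cols, float* y) {
  // Scale-and-mask load, then a numerically stable row softmax.
  float max_v = -INFINITY;
  for (int i = 0; i < cols; ++i) {
    y[i] = mask[i] ? x[i] * scale : fill;
    if (y[i] > max_v) { max_v = y[i]; }
  }
  float sum = 0.f;
  for (int i = 0; i < cols; ++i) {
    y[i] = std::exp(y[i] - max_v);
    sum += y[i];
  }
  for (int i = 0; i < cols; ++i) { y[i] /= sum; }
}

int main() {
  const float x[4] = {8.f, 8.f, 8.f, 8.f};
  const bool mask[4] = {true, true, false, false};
  float y[4];
  FusedScaleMaskSoftmaxRow(x, mask, /*scale=*/0.125f, /*fill=*/-10000.f, 4, y);
  // Prints ~0.5 for the two unmasked positions and ~0 for the masked ones.
  for (int i = 0; i < 4; ++i) { std::printf("y[%d] = %f\n", i, y[i]); }
  return 0;
}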