diff --git a/.gitignore b/.gitignore index a4690ba442520..bc745ba02be03 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ paddle/phi/api/profiler/__init__.py python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py paddle/phi/kernels/fusion/cutlass/conv2d/generated/* python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py + +# these files are auto-generated by memory_efficient_fmha_variable +autogen* diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 45e4d763e6bd9..71decdfc47b1b 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -169,6 +169,8 @@ bool PaddleTensorToDenseTensor(const PaddleTensor &pt, input_ptr = t->mutable_data(ddim, place); } else if (pt.dtype == PaddleDType::FLOAT16) { input_ptr = t->mutable_data(ddim, place); + } else if (pt.dtype == PaddleDType::BFLOAT16) { + input_ptr = t->mutable_data(ddim, place); } else { LOG(ERROR) << "unsupported feed type " << pt.dtype; return false; @@ -1226,6 +1228,9 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, } else if (type == framework::proto::VarType::FP16) { GetFetchOne(fetch, output); output->dtype = PaddleDType::FLOAT16; + } else if (type == framework::proto::VarType::BF16) { + GetFetchOne(fetch, output); + output->dtype = PaddleDType::BFLOAT16; } else { LOG(ERROR) << "unknown type, only support float32, float16, int64 and " "int32 now."; @@ -1766,6 +1771,8 @@ AnalysisPredictor::GetInputTypes() { input_type[name] = paddle_infer::DataType::FLOAT32; } else if (dtype == paddle::framework::proto::VarType::FP16) { input_type[name] = paddle_infer::DataType::FLOAT16; + } else if (dtype == paddle::framework::proto::VarType::BF16) { + input_type[name] = paddle_infer::DataType::BFLOAT16; } else if (dtype == paddle::framework::proto::VarType::INT64) { input_type[name] = paddle_infer::DataType::INT64; } else if (dtype == paddle::framework::proto::VarType::INT32) { @@ -1819,6 +1826,8 @@ AnalysisPredictor::GetOutputTypes() { output_type[name] = paddle_infer::DataType::FLOAT32; } else if (dtype == paddle::framework::proto::VarType::FP16) { output_type[name] = paddle_infer::DataType::FLOAT16; + } else if (dtype == paddle::framework::proto::VarType::BF16) { + output_type[name] = paddle_infer::DataType::BFLOAT16; } else if (dtype == paddle::framework::proto::VarType::INT64) { output_type[name] = paddle_infer::DataType::INT64; } else if (dtype == paddle::framework::proto::VarType::INT32) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 83207a8bfd654..36d2537f20d40 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -31,7 +31,9 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/string/printf.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/dense_tensor.h" #ifdef PADDLE_WITH_TESTING #include @@ -39,6 +41,8 @@ #endif namespace paddle_infer { +using float16 = paddle::platform::float16; +using bfloat16 = phi::dtype::bfloat16; namespace experimental { class InternalUtils; }; diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index 054b4668c4cc6..248c70bdbf603 100644 --- a/paddle/fluid/inference/api/api.cc +++ 
b/paddle/fluid/inference/api/api.cc @@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) { switch (dtype) { case PaddleDType::FLOAT32: return sizeof(float); + case PaddleDType::BFLOAT16: + return sizeof(uint16_t); case PaddleDType::INT64: return sizeof(int64_t); case PaddleDType::INT32: diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 0d5c8f98020a8..b86c490037979 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -223,6 +223,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::INT32) { input_ptr = input.mutable_data(ddim, place_); + } else if (inputs[i].dtype == PaddleDType::BFLOAT16) { + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 52204ff3658f4..d3ecf02a8cd8b 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -31,6 +31,7 @@ namespace paddle_infer { using float16 = paddle::platform::float16; +using bfloat16 = paddle::platform::bfloat16; void Tensor::Reshape(const std::vector &shape) { #ifdef PADDLE_WITH_ONNXRUNTIME @@ -178,6 +179,8 @@ DataType Tensor::type() const { return DataType::FLOAT32; } else if (type == paddle::framework::proto::VarType::FP16) { return DataType::FLOAT16; + } else if (type == paddle::framework::proto::VarType::BF16) { + return DataType::BFLOAT16; } else if (type == paddle::framework::proto::VarType::INT64) { return DataType::INT64; } else if (type == paddle::framework::proto::VarType::INT32) { @@ -289,6 +292,11 @@ struct DataTypeInfo { phi::DataType TYPE = phi::DataType::FLOAT16; }; +template <> +struct DataTypeInfo { + phi::DataType TYPE = phi::DataType::BFLOAT16; +}; + template <> struct DataTypeInfo { phi::DataType TYPE = phi::DataType::INT64; @@ -500,6 +508,7 @@ template PD_INFER_DECL void Tensor::CopyFromCpu(const int32_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const uint8_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const int8_t *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const float16 *data); +template PD_INFER_DECL void Tensor::CopyFromCpu(const bfloat16 *data); template PD_INFER_DECL void Tensor::CopyFromCpu(const bool *data); template PD_INFER_DECL void Tensor::ShareExternalData( @@ -537,6 +546,11 @@ template PD_INFER_DECL void Tensor::ShareExternalData( const std::vector &shape, PlaceType place, DataLayout layout); +template PD_INFER_DECL void Tensor::ShareExternalData( + const bfloat16 *data, + const std::vector &shape, + PlaceType place, + DataLayout layout); template PD_INFER_DECL void Tensor::ShareExternalData( const bool *data, const std::vector &shape, @@ -550,6 +564,7 @@ template PD_INFER_DECL void Tensor::CopyToCpu(int32_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(uint8_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(int8_t *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(float16 *data) const; +template PD_INFER_DECL void Tensor::CopyToCpu(bfloat16 *data) const; template PD_INFER_DECL void Tensor::CopyToCpu(bool *data) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( @@ -568,6 +583,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuImpl( int8_t *data, void 
*exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl( float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; +template PD_INFER_DECL void Tensor::CopyToCpuImpl( + bfloat16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuImpl(bool *data, void *exec_stream, CallbackFunc cb, @@ -587,6 +604,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync( int8_t *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float16 *data, void *exec_stream) const; +template PD_INFER_DECL void Tensor::CopyToCpuAsync( + bfloat16 *data, void *exec_stream) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( bool *data, void *exec_stream) const; @@ -604,6 +623,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync( int8_t *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync( float16 *data, CallbackFunc cb, void *cb_params) const; +template PD_INFER_DECL void Tensor::CopyToCpuAsync( + bfloat16 *data, CallbackFunc cb, void *cb_params) const; template PD_INFER_DECL void Tensor::CopyToCpuAsync(bool *data, CallbackFunc cb, void *cb_params) const; @@ -622,6 +643,8 @@ template PD_INFER_DECL int8_t *Tensor::data(PlaceType *place, int *size) const; template PD_INFER_DECL float16 *Tensor::data(PlaceType *place, int *size) const; +template PD_INFER_DECL bfloat16 *Tensor::data(PlaceType *place, + int *size) const; template PD_INFER_DECL bool *Tensor::data(PlaceType *place, int *size) const; @@ -632,6 +655,8 @@ template PD_INFER_DECL int32_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL uint8_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL int8_t *Tensor::mutable_data(PlaceType place); template PD_INFER_DECL float16 *Tensor::mutable_data(PlaceType place); +template PD_INFER_DECL bfloat16 *Tensor::mutable_data( + PlaceType place); template PD_INFER_DECL bool *Tensor::mutable_data(PlaceType place); Tensor::Tensor(void *scope, const void *device_contexts) @@ -783,6 +808,7 @@ template void Tensor::ORTCopyToCpu(int32_t *data) const; template void Tensor::ORTCopyToCpu(uint8_t *data) const; template void Tensor::ORTCopyToCpu(int8_t *data) const; template void Tensor::ORTCopyToCpu(float16 *data) const; +template void Tensor::ORTCopyToCpu(bfloat16 *data) const; #endif namespace experimental { @@ -921,6 +947,8 @@ template void InternalUtils::CopyFromCpuWithIoStream( paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream); template void InternalUtils::CopyFromCpuWithIoStream( paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream); +template void InternalUtils::CopyFromCpuWithIoStream( + paddle_infer::Tensor *t, const bfloat16 *data, cudaStream_t stream); template void InternalUtils::CopyFromCpuWithIoStream( paddle_infer::Tensor *t, const bool *data, cudaStream_t stream); @@ -938,6 +966,8 @@ template void InternalUtils::CopyToCpuWithIoStream( paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream); template void InternalUtils::CopyToCpuWithIoStream( paddle_infer::Tensor *t, float16 *data, cudaStream_t stream); +template void InternalUtils::CopyToCpuWithIoStream( + paddle_infer::Tensor *t, bfloat16 *data, cudaStream_t stream); template void InternalUtils::CopyToCpuWithIoStream( paddle_infer::Tensor *t, bool *data, cudaStream_t stream); diff --git a/paddle/fluid/inference/api/paddle_infer_contrib.cc b/paddle/fluid/inference/api/paddle_infer_contrib.cc index 
11786b05c3035..d0d7e59e09139 100644 --- a/paddle/fluid/inference/api/paddle_infer_contrib.cc +++ b/paddle/fluid/inference/api/paddle_infer_contrib.cc @@ -108,6 +108,13 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst, cb, cb_params); break; + case PaddleDType::BFLOAT16: + src.CopyToCpuImpl( + dst.mutable_data(PlaceType::kCPU), + exec_stream, + cb, + cb_params); + break; default: PADDLE_THROW(paddle::platform::errors::Unimplemented( "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, FLOAT32 and " @@ -172,6 +179,13 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst, src.data(&src_place, &data_size)); data_len = data_size * 2; break; + case PaddleDType::BFLOAT16: + dst_data = static_cast( + dst.mutable_data(PlaceType::kGPU)); + src_data = static_cast( + src.data(&src_place, &data_size)); + data_len = data_size * 2; + break; default: PADDLE_THROW(paddle::platform::errors::Unimplemented( "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, FLOAT32 and " diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index b9c86a60f55b8..a798bdd38c29c 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -62,6 +62,7 @@ enum DataType { FLOAT16, BOOL, FLOAT64, + BFLOAT16, // TODO(Inference): support more data types if needed. }; diff --git a/paddle/fluid/operators/custom_all_reduce.h b/paddle/fluid/operators/custom_all_reduce.h new file mode 100644 index 0000000000000..ef2ca9c26c456 --- /dev/null +++ b/paddle/fluid/operators/custom_all_reduce.h @@ -0,0 +1,691 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
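The inference-API hunks above thread PaddleDType::BFLOAT16 / DataType::BFLOAT16 through feed conversion, fetch, dtype queries, tensor copy, and the CopyFromCpu/CopyToCpu/ShareExternalData instantiations. A minimal end-to-end usage sketch of the resulting C++ API is given below; the model files and shapes are placeholders, and it assumes the loaded program really feeds and fetches BF16 tensors.

#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/phi/common/bfloat16.h"

int main() {
  // Placeholder model files; any program whose feeds/fetches are BF16 works.
  paddle_infer::Config config("./model.pdmodel", "./model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  auto predictor = paddle_infer::CreatePredictor(config);

  // Feed: CopyFromCpu(const bfloat16*) is one of the instantiations added in
  // zero_copy_tensor.cc; bfloat16 is phi::dtype::bfloat16.
  auto input = predictor->GetInputHandle(predictor->GetInputNames()[0]);
  std::vector<phi::dtype::bfloat16> in(1 * 128, phi::dtype::bfloat16(1.0f));
  input->Reshape({1, 128});
  input->CopyFromCpu(in.data());

  predictor->Run();

  // Fetch: Tensor::type() now reports DataType::BFLOAT16 for BF16 variables.
  auto output = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
  int numel = 1;
  for (int d : output->shape()) numel *= d;
  std::vector<phi::dtype::bfloat16> out(numel);
  output->CopyToCpu(out.data());
  return 0;
}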
+ +#pragma once + +#include +#include +#include +#include + +#include +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +#define CUSTOMAR_ENABLE_BF16 +#endif + +namespace paddle { +namespace operators { + +constexpr int DEFAULT_BLOCK_SIZE = 1024; +constexpr int MAX_ALL_REDUCE_BLOCKS = 24; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void st_flag_release(uint32_t &flag, // NOLINT + volatile uint32_t *flag_addr) { +#if __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + __threadfence_system(); + asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void ld_flag_acquire(uint32_t &flag, // NOLINT + volatile uint32_t *flag_addr) { +#if __CUDA_ARCH__ >= 700 + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.global.volatile.b32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +class SystemCUDAAllocator : public phi::Allocator { + public: + static phi::Allocator *Instance() { + static SystemCUDAAllocator allocator; + return &allocator; + } + + phi::Allocator::AllocationPtr Allocate(size_t size) override { + if (size == 0) { + return nullptr; + } + void *ptr = nullptr; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&ptr, size)); + return phi::Allocator::AllocationPtr(new phi::Allocation(ptr, size, place), + DeleteFunc); + } + + bool IsAllocThreadSafe() const override { return true; } + + private: + static void DeleteFunc(phi::Allocation *allocation) { + cudaFree(allocation->ptr()); + delete allocation; + } + + SystemCUDAAllocator() : place(platform::GetCurrentDeviceId()) {} + + DISABLE_COPY_AND_ASSIGN(SystemCUDAAllocator); + + private: + phi::GPUPlace place; +}; + +template +static __global__ void FillBarrierValue(T *x, T value) { + x[threadIdx.x] = value; +} + +template +static __forceinline__ __device__ void BarrierAllGPUs( + const phi::Array &barriers, T barrier_value, int rank) { + int block_id = blockIdx.x; + int thread_id = threadIdx.x; + + if (thread_id < N) { + if (block_id == 0) { + barriers[thread_id][rank] = barrier_value; + } + while (barriers[rank][thread_id] < barrier_value) { + } + } + + __syncthreads(); +} + +template +static __forceinline__ __device__ void BarrierAllGPUsAllBlock( + const phi::Array &barriers, T barrier_value, int rank) { + int block_id = blockIdx.x; + int thread_id = threadIdx.x; + + if (thread_id < N) { + uint32_t flag_block_offset = N + block_id * N; + st_flag_release(barrier_value, + barriers[thread_id] + flag_block_offset + rank); + uint32_t rank_barrier = 0; + volatile uint32_t *peer_barrier_d = + barriers[rank] + flag_block_offset + thread_id; + do { + ld_flag_acquire(rank_barrier, peer_barrier_d); + } while (rank_barrier != barrier_value); + } + + __syncthreads(); 
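  // Note on the two barriers: BarrierAllGPUs above is the lightweight
  // pre-kernel rendezvous -- only block 0 of each rank publishes the new
  // barrier value, and the first N threads of every block spin on the
  // rank-local, IPC-mapped flag array until all peers have arrived. This
  // all-block variant is used between the reduce and broadcast phases of the
  // two-shot kernel: each block owns its own group of N flags (offset
  // N + block_id * N), writes the flag on every peer with st.release (system
  // scope on sm_70+, otherwise __threadfence_system() followed by a volatile
  // store) and spins with the matching ld.acquire, so the partial sums written
  // before the barrier are guaranteed to be visible to the rank that gathers
  // them afterwards.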
+} + +template +struct AlignedVectorAddHelper { + DEVICE static void Run(const phi::AlignedVector &in, + phi::AlignedVector *out) { +#pragma unroll + for (int i = 0; i < N; ++i) { + (*out)[i] += in[i]; + } + } +}; + +template +struct AlignedVectorAddHelper { + DEVICE static void Run(const phi::AlignedVector &in, + phi::AlignedVector *out) { + const __half2 *in_ptr = + static_cast(static_cast(&in[0])); + __half2 *out_ptr = static_cast<__half2 *>(static_cast(&(*out)[0])); +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + out_ptr[i] = __hadd2(out_ptr[i], in_ptr[i]); + } + if (N % 2 != 0) { + (*out)[N - 1] += in[N - 1]; + } + } +}; +#ifdef CUSTOMAR_ENABLE_BF16 +inline __device__ __nv_bfloat162 float2bf162(const float2 a) { + __nv_bfloat162 a_; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + a_ = __float22bfloat162_rn(a); +#else + a_.x = __float2bfloat16_rn(a.x); + a_.y = __float2bfloat16_rn(a.y); +#endif + return a_; +} + +// Bfloat16 Specialization. +template +struct AlignedVectorAddHelper { + DEVICE static void Run(const phi::AlignedVector &in, + phi::AlignedVector *out) { + const __nv_bfloat162 *in_ptr = + static_cast(static_cast(&in[0])); + __nv_bfloat162 *out_ptr = + static_cast<__nv_bfloat162 *>(static_cast(&(*out)[0])); +#pragma unroll + for (int i = 0; i < N / 2; ++i) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + out_ptr[i] = __hadd2(out_ptr[i], in_ptr[i]); +#else + float2 out{}; + out.x = __bfloat162float(out_ptr[i].x) + __bfloat162float(in_ptr[i].x); + out.y = __bfloat162float(out_ptr[i].y) + __bfloat162float(in_ptr[i].y); + out_ptr[i] = float2bf162(out); +#endif + } + if (N % 2 != 0) { + (*out)[N - 1] += in[N - 1]; + } + } +}; + +#endif //CUSTOMAR_ENABLE_BF16 + +template +static __device__ __forceinline__ void AllReduceFunc( + const phi::Array &ins, + int idx, + int stride, + int n, + int rank, + T *out) { + using AlignedVec = phi::AlignedVector; + while (idx + VecSize <= n) { + AlignedVec in_vecs[N]; + +#pragma unroll + for (int i = 0; i < N; ++i) { + auto cur_rank = (i + rank) % N; + const auto *ptr = ins[cur_rank] + idx; + phi::Load(ptr, &in_vecs[cur_rank]); + } + +#pragma unroll + for (int i = 1; i < N; ++i) { + AlignedVectorAddHelper::Run(in_vecs[i], &in_vecs[0]); + } + phi::Store(in_vecs[0], out + idx); + idx += stride; + } + + while (HasLeftValue && idx < n) { + T sum = ins[0][idx]; +#pragma unroll + for (int i = 1; i < N; ++i) { + sum += ins[i][idx]; + } + out[idx] = sum; + ++idx; + } +} + +template +static __global__ +__launch_bounds__(DEFAULT_BLOCK_SIZE) void OneShotAllReduceKernel( + phi::Array ins, + phi::Array barriers, + BarrierT barrier_value, + int rank, + size_t n, + T *out) { + BarrierAllGPUs(barriers, barrier_value, rank); + + int idx = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize; + int stride = (blockDim.x * gridDim.x) * VecSize; + AllReduceFunc(ins, idx, stride, n, rank, out); +} + +template +static __device__ __forceinline__ void VecStoreGlobalMem(const T *x, T *y) { + using AlignedVec = phi::AlignedVector; + const auto *x_vec = + static_cast(static_cast(x)); + auto *y_vec = static_cast(static_cast(y)); + y_vec[0] = x_vec[0]; +} + +template +static __global__ +__launch_bounds__(DEFAULT_BLOCK_SIZE) void TwoShotAllReduceKernel( + phi::Array ins, + phi::Array barriers, + BarrierT barrier_value, + int rank, + size_t n, + T *out) { + BarrierAllGPUs(barriers, barrier_value, rank); + const size_t n_per_gpu = n / N; + int idx = + (threadIdx.x + blockIdx.x * blockDim.x) * VecSize + rank * n_per_gpu; + int stride = (blockDim.x * 
gridDim.x) * VecSize; + int limit = (rank + 1) * n_per_gpu; + AllReduceFunc(ins, idx, stride, limit, rank, ins[rank]); + + BarrierAllGPUsAllBlock(barriers, barrier_value + 1, rank); + using AlignedVec = phi::AlignedVector; + + int dst_offset[N]; + int dst_rank[N]; +#pragma unroll + for (int i = 0; i < N; ++i) { + int tmp = (i + rank) % N; + dst_rank[i] = tmp; + dst_offset[i] = (tmp - rank) * n_per_gpu; + } + + while (idx + VecSize <= limit) { +#pragma unroll + for (int i = 0; i < N; ++i) { + auto dst_idx = idx + dst_offset[i]; + VecStoreGlobalMem(ins[dst_rank[i]] + dst_idx, out + dst_idx); + } + idx += stride; + } +} + +class CustomNCCLComm { + public: + virtual void SwapInput(phi::DenseTensor *x) = 0; + virtual phi::DenseTensor AllReduce() = 0; + + virtual ~CustomNCCLComm() = default; + + protected: + void EnableP2P(int nranks) { + for (int i = 0; i < nranks; ++i) { + platform::CUDADeviceGuard guard(i); + for (int j = 0; j < nranks; ++j) { + int enabled = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceCanAccessPeer(&enabled, i, j)); + PADDLE_ENFORCE_EQ(enabled, 1); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceEnablePeerAccess(j, 0)); + } + } + } +}; + +template +class CustomNCCLCommImpl : public CustomNCCLComm { + private: + template + struct P2PBuffer { + template + P2PBuffer(CustomNCCLCommImpl *comm, size_t size, InitFunc &&init_func) { + phi::Dim<1> dim; + dim[0] = static_cast(size); + t_.Resize(dim); + void *ptr = t_.AllocateFrom(SystemCUDAAllocator::Instance(), + phi::CppTypeToDataType::Type()); + init_func(*(comm->ctx_), &t_); + comm->ctx_->Wait(); + + comm->Barrier(); + + auto pids = comm->AllGatherPOD(::getpid()); + for (int i = 0; i < N; ++i) { + BroadcastDevicePtr(comm, ptr, i, pids[0]); + } + } + + ~P2PBuffer() { + for (int i = 0; i < N; ++i) { + if (i != rank_) { + cudaIpcCloseMemHandle(ptrs_[i]); + } + ::munmap(mmap_ptrs_[i], sizeof(cudaIpcMemHandle_t)); + ::shm_unlink(shm_names_[i].c_str()); + } + t_.clear(); + } + + const phi::DenseTensor &GetTensor() const { return t_; } + phi::DenseTensor *GetMutableTensor() { return &t_; } + + template + phi::Array GetPtrs() const { + phi::Array results; +#pragma unroll + for (int i = 0; i < N; ++i) { + results[i] = static_cast(ptrs_[i]); + } + return results; + } + + private: + void BroadcastDevicePtr(CustomNCCLCommImpl *comm, + void *ptr, + int cur_rank, + pid_t pid) { + VLOG(10) << "BroadcastDevicePtr starts " << cur_rank << " -> " + << comm->rank_; + std::string name = "/paddle_custom_nccl_" + std::to_string(pid) + "_" + + std::to_string(cur_rank); + cudaIpcMemHandle_t *handle; + bool is_root = (comm->rank_ == cur_rank); + + if (!is_root) { + comm->Barrier(); + } + + int fd = ::shm_open( + name.c_str(), is_root ? (O_RDWR | O_CREAT) : O_RDONLY, 0600); + PADDLE_ENFORCE_GE(fd, 0); + if (is_root) { + PADDLE_ENFORCE_EQ(ftruncate(fd, sizeof(cudaIpcMemHandle_t)), 0); + } + void *mmap_ptr = ::mmap(nullptr, + sizeof(cudaIpcMemHandle_t), + is_root ? 
(PROT_READ | PROT_WRITE) : PROT_READ, + MAP_SHARED, + fd, + 0); + PADDLE_ENFORCE_NOT_NULL(mmap_ptr); + PADDLE_ENFORCE_NE(mmap_ptr, MAP_FAILED); + handle = static_cast(mmap_ptr); + if (is_root) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaIpcGetMemHandle(handle, ptr)); + ptrs_[cur_rank] = ptr; + } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaIpcOpenMemHandle( + &ptrs_[cur_rank], *handle, cudaIpcMemLazyEnablePeerAccess)); + } + if (is_root) { + comm->Barrier(); + } + + comm->Barrier(); + mmap_ptrs_[cur_rank] = mmap_ptr; + shm_names_[cur_rank] = name; + VLOG(10) << "BroadcastDevicePtr ends " << cur_rank << " -> " + << comm->rank_; + } + + private: + phi::Array ptrs_; + phi::DenseTensor t_; + int rank_; + phi::Array mmap_ptrs_; + phi::Array shm_names_; + }; + + public: + using BarrierDType = uint32_t; + using BarrierTensorDType = int32_t; + + static_assert(sizeof(BarrierDType) == sizeof(BarrierTensorDType), + "Size not match"); + + CustomNCCLCommImpl(const phi::GPUContext &ctx, + size_t one_shot_max_size, + size_t two_shot_max_size, + int ring_id) + : ctx_(&ctx), + one_shot_max_size_(one_shot_max_size), + two_shot_max_size_(two_shot_max_size) { + PADDLE_ENFORCE_LT(one_shot_max_size, two_shot_max_size); + auto comm = + platform::NCCLCommContext::Instance().Get(ring_id, ctx.GetPlace()); + comm_ = comm->comm(); + rank_ = comm->rank(); + auto nranks = comm->nranks(); + PADDLE_ENFORCE_EQ( + nranks, + N, + phi::errors::InvalidArgument("Invalid world size, this may be a bug.")); + + barrier_value_ = 0; + VLOG(10) << "CustomNCCLCommImpl::CustomNCCLCommImpl"; + ins_ = std::make_unique>( + this, + two_shot_max_size_, + [](const phi::GPUContext &ctx, phi::DenseTensor *t) {}); + VLOG(10) << "CustomNCCLCommImpl::ins_ inited"; + + barriers_ = std::make_unique>( + this, + N * (MAX_ALL_REDUCE_BLOCKS + 1), + [](const phi::GPUContext &ctx, phi::DenseTensor *t) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(t->data(), + 0, + t->numel() * sizeof(BarrierTensorDType), + ctx.stream())); + }); + VLOG(10) << "CustomNCCLCommImpl::barriers_ inited"; + } + + void SwapInput(phi::DenseTensor *x) override { + out_ = *x; + auto numel = x->numel(); + auto dtype = x->dtype(); + auto algo = ChooseAlgo(numel, dtype); + if (algo <= 2 && !HasReachMaxBarrierValue(algo)) { + ShareTensor(x, ins_->GetMutableTensor()); + } + } + + phi::DenseTensor AllReduce() override { + auto dtype = out_.dtype(); + auto numel = out_.numel(); + auto algo = ChooseAlgo(numel, dtype); + if (algo > 2) { + NCCLAllReduce(out_.data(), numel, dtype); + return std::move(out_); + } + + if (HasReachMaxBarrierValue(algo)) { + NCCLAllReduce(out_.data(), numel, dtype); + ResetBarriers(); + return std::move(out_); + } + +#define PD_CUSTOM_ALLREDUCE(__cpp_dtype, __vec_size) \ + do { \ + if (dtype == ::phi::CppTypeToDataType<__cpp_dtype>::Type()) { \ + if (algo == 1) { \ + return OneShotAllReduceImpl<__cpp_dtype, __vec_size>(numel); \ + } else { \ + return TwoShotAllReduceImpl<__cpp_dtype, __vec_size>(numel); \ + } \ + } \ + } while (0) + PD_CUSTOM_ALLREDUCE(phi::dtype::bfloat16, 8); + PD_CUSTOM_ALLREDUCE(phi::dtype::float16, 8); + PD_CUSTOM_ALLREDUCE(float, 4); + PD_CUSTOM_ALLREDUCE(double, 2); + PADDLE_THROW( + phi::errors::InvalidArgument("Unsupported data type %s", dtype)); + } + + private: + uint32_t ChooseAlgo(size_t numel, phi::DataType dtype) const { + auto sizeof_dtype = phi::SizeOf(dtype); + auto mem_size = numel * sizeof_dtype; + if (mem_size <= one_shot_max_size_) { + return 1; + } else if (mem_size <= two_shot_max_size_ && numel % N == 0 && + (numel / N) % (16 / 
sizeof_dtype) == 0) { + return 2; + } else { + return 3; + } + } + + void NCCLAllReduce(void *ptr, size_t numel, phi::DataType dtype) { + auto nccl_dtype = platform::ToNCCLDataType(dtype); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( + ptr, ptr, numel, nccl_dtype, ncclSum, comm_, ctx_->stream())); + } + + void ResetBarriers() { + LOG(INFO) << "barrier_value_ " << barrier_value_ << " , restart barrier"; + Barrier(); + auto *barrier_tensor = barriers_->GetMutableTensor(); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(barrier_tensor->data(), + 0, + barrier_tensor->numel() * sizeof(BarrierTensorDType), + ctx_->stream())); + Barrier(); + barrier_value_ = 0; + } + + bool HasReachMaxBarrierValue(int algo) const { + return barrier_value_ > std::numeric_limits::max() - algo; + } + + template + phi::DenseTensor OneShotAllReduceImpl(int64_t numel) { + const auto &in_ptrs = ins_->template GetPtrs(); + const auto &barrier_ptrs = + barriers_->template GetPtrs(); + auto *out_data = out_.template data(); + ++barrier_value_; + + int threads = DEFAULT_BLOCK_SIZE; // 1024 + PADDLE_ENFORCE_GE(threads, N); + int64_t blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; + blocks = std::min(blocks, MAX_ALL_REDUCE_BLOCKS /*24*/); + VLOG(10) << "Use OneShotAllReduceKernel for size = " << numel; + + OneShotAllReduceKernel + <<stream()>>>( + in_ptrs, barrier_ptrs, barrier_value_, rank_, numel, out_data); + return std::move(out_); + } + + template + phi::DenseTensor TwoShotAllReduceImpl(int64_t numel) { + PADDLE_ENFORCE_EQ(numel % N, 0); + const auto &in_ptrs = ins_->template GetPtrs(); + const auto &barrier_ptrs = + barriers_->template GetPtrs(); + auto *out_data = out_.template data(); + if (barrier_value_ > 0) { + barrier_value_ += 2; + } else { + barrier_value_ = 1; + } + + // int threads = ctx_->GetMaxThreadsPerBlock(); + int threads = DEFAULT_BLOCK_SIZE; + PADDLE_ENFORCE_GE(threads, N); + int32_t blocks = + ((numel / N + VecSize - 1) / VecSize + threads - 1) / threads; + blocks = std::min(blocks, MAX_ALL_REDUCE_BLOCKS /*24*/); + VLOG(10) << "Use TwoShotAllReduceKernel for size = " << numel; + TwoShotAllReduceKernel + <<stream()>>>( + in_ptrs, barrier_ptrs, barrier_value_, rank_, numel, out_data); + return std::move(out_); + } + + void ShareTensor(phi::DenseTensor *x, phi::DenseTensor *y) { + PADDLE_ENFORCE_LE(x->numel(), two_shot_max_size_); + const void *y_ptr = y->data(); + y->Resize(x->dims()); + auto *new_y_ptr = ctx_->Alloc(y, x->dtype()); + PADDLE_ENFORCE_EQ(y_ptr, new_y_ptr); + x->ShareBufferWith(*y); + } + + void Barrier() { AllGatherPOD(1); } + + template + std::vector AllGatherPOD(const T &value) { + std::vector result(N); + AllGatherBuffer(&value, result.data(), sizeof(T)); + return result; + } + + void AllGatherBuffer(const void *src, void *dst, size_t nbytes) { + phi::DenseTensor tensor; + phi::Dim<1> dim; + dim[0] = N * nbytes; + tensor.Resize(dim); + auto *ptr = ctx_->template Alloc(&tensor); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(ptr + rank_ * nbytes, + src, + nbytes, + cudaMemcpyHostToDevice, + ctx_->stream())); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllGather( + ptr + rank_ * nbytes, ptr, nbytes, ncclInt8, comm_, ctx_->stream())); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync( + dst, ptr, N * nbytes, cudaMemcpyDeviceToHost, ctx_->stream())); + ctx_->Wait(); + } + + private: + std::unique_ptr> ins_; + std::unique_ptr> barriers_; + BarrierDType barrier_value_; + phi::DenseTensor out_; + + const phi::GPUContext *ctx_; + size_t one_shot_max_size_; + size_t 
two_shot_max_size_; + ncclComm_t comm_; + int rank_; +}; + +static std::unique_ptr CreateCustomNCCLComm( + const phi::GPUContext &ctx, + int64_t one_shot_max_size, + int64_t two_shot_max_size, + int ring_id) { + if (one_shot_max_size <= 0 || two_shot_max_size <= 0 || + one_shot_max_size >= two_shot_max_size) { + return nullptr; + } + + auto nranks = platform::NCCLCommContext::Instance() + .Get(ring_id, ctx.GetPlace()) + ->nranks(); +#define PD_CREATE_CUSTOM_NCCL_COMM(__nranks) \ + do { \ + if (nranks == __nranks) { \ + return std::make_unique>( \ + ctx, one_shot_max_size, two_shot_max_size, ring_id); \ + } \ + } while (0) + + PD_CREATE_CUSTOM_NCCL_COMM(8); + PD_CREATE_CUSTOM_NCCL_COMM(4); + PD_CREATE_CUSTOM_NCCL_COMM(2); + return nullptr; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index c61a7f60d4359..ade1e79dd2347 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -1,11 +1,8 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -35,8 +32,8 @@ class AttnMatmulINT8 { AttnMatmulINT8( const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { - auto helper = std::make_shared(m, k, n); - helpers_.emplace_back(helper); + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + helper_ = std::make_unique>(m, k, n, lt_handle); gpu_config_ = std::make_unique( phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, m * n, DequantKernelVecSize)); @@ -54,6 +51,7 @@ class AttnMatmulINT8 { phi::DenseTensor* bias_out, const float quant_in_scale, const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -67,10 +65,12 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), @@ -86,7 +86,7 @@ class AttnMatmulINT8 { std::vector ins = {output, bias}; std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess, platform::errors::Fatal( @@ -103,12 +103,13 @@ class AttnMatmulINT8 { const phi::DenseTensor* bias, phi::DenseTensor* output, phi::DenseTensor* bias_out, - void* workspace = nullptr) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output->data(), - dev_ctx_.stream(), - workspace); + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } // This function is used to execute GEMM, 
with input and output's types are @@ -120,11 +121,14 @@ class AttnMatmulINT8 { phi::DenseTensor* output, phi::DenseTensor* output_tmp, phi::DenseTensor* bias_out, - const phi::DenseTensor* dequant_out_scale) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), @@ -140,7 +144,7 @@ class AttnMatmulINT8 { std::vector ins = {output, bias}; std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess, platform::errors::Fatal( @@ -159,6 +163,7 @@ class AttnMatmulINT8 { const phi::DenseTensor* bias, phi::DenseTensor* output, phi::DenseTensor* bias_out, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -172,10 +177,12 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } private: @@ -186,7 +193,7 @@ class AttnMatmulINT8 { int k_; // k int compute_bias_; - std::vector> helpers_; + std::unique_ptr> helper_; std::unique_ptr gpu_config_; }; diff --git a/paddle/fluid/operators/fused/cublaslt.h b/paddle/fluid/operators/fused/cublaslt.h index e9728c58b55dc..f949755a411e9 100644 --- a/paddle/fluid/operators/fused/cublaslt.h +++ b/paddle/fluid/operators/fused/cublaslt.h @@ -1,11 +1,11 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); +Copyright( + c) 2022 NVIDIA Authors.All Rights Reserved.Licensed under the Apache License + , + Version 2.0(the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,108 +14,753 @@ limitations under the License. 
*/ #pragma once +#include #include #include #include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/dynload/cublasLt.h" +DECLARE_int64(cublaslt_exhaustive_search_times); + namespace dyl = paddle::platform::dynload; namespace paddle { namespace operators { -struct CublasLtAlgoParam { - int algoId; +#define PADDLE_CUBLASLT_STATUS_CHECK(name) \ + PADDLE_ENFORCE_EQ( \ + status, \ + CUBLAS_STATUS_SUCCESS, \ + platform::errors::External( \ + #name \ + "execution error" \ + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " \ + "information")) + +const int split_k_candidates[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; + +struct CublasLtAlgoSelectorParam { + cublasLtMatmulAlgo_t algo; + int m; + int n; + int k; + int algo_id; int swizzle; - int customOption; + int custom_option; int tile; - int splitK_val; - int reductionScheme; + int split_k_val; + int reduction_scheme; int stages; + void* workspace; size_t workspace_size; + float time; }; -const std::map, CublasLtAlgoParam> AlgoParamCache{}; +inline bool compare_algo_time(const CublasLtAlgoSelectorParam& param_a, + const CublasLtAlgoSelectorParam& param_b) { + return (param_a.time < param_b.time); +} +#if CUDA_VERSION >= 11020 +class CublasLtAlgoCache { + public: + static CublasLtAlgoCache& Instance() { + static CublasLtAlgoCache instance(FLAGS_cublaslt_exhaustive_search_times); + return instance; + } + + template + void TestMatmulRun(cublasLtHandle_t handle, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + void* alpha, + void* beta, + const InT* a, + const InT* b, + OutT* c, + CublasLtAlgoSelectorParam& param, // NOLINT + cudaEvent_t& start_event, // NOLINT + cudaEvent_t& stop_event, // NOLINT + cudaStream_t stream) { + + cublasStatus_t status; + cublasLtMatmulHeuristicResult_t heuristic_result; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + ¶m.algo, + &heuristic_result); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCheck); + if (status != CUBLAS_STATUS_SUCCESS || + heuristic_result.workspaceSize > param.workspace_size) { + // VLOG(0) << "param.workspace_size is " << param.workspace_size; + param.time = std::numeric_limits::max(); + return; + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); + int repeats = search_times_; + + for (int loop = 0; loop < repeats; loop++) { + status = dyl::cublasLtMatmul(handle, + matmul_desc, + alpha, + a, + a_desc, + b, + b_desc, + beta, + c, + c_desc, + c, + c_desc, + ¶m.algo, + param.workspace, + param.workspace_size, + stream); + if (status != CUBLAS_STATUS_SUCCESS) { + param.time = std::numeric_limits::max(); + return; + } + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + float time; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventElapsedTime(&time, start_event, stop_event)); + + param.time = time / repeats; + } + + template + cublasLtMatmulAlgo_t* CublasLtAlgoSelect(cublasLtHandle_t handle, + int m, + int n, + int k, + const InT* a, + const InT* b, + OutT* c, + void* alpha, + void* beta, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + cublasComputeType_t compute_type, + cudaDataType_t scale_type, + cudaDataType_t a_type, + cudaDataType_t b_type, + cudaDataType_t c_type, + void* workspace, + size_t workspace_size, + cudaStream_t stream) 
{ +#if CUDA_VERSION >= 11010 + // If we don't have config file and we donot search, here return nullptr + if (!has_config_file_ && search_times_ <= 0) { + return nullptr; + } + + // VLOG(0) << "m n k" << m << " " << n << " " << k; + + int64_t seed = 0; + std::hash hash_fn; + + HashMatmulDesc_(matmul_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(a_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(b_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(c_desc, &seed, hash_fn); + + cublasLtMatmulAlgo_t ret; + { + std::lock_guard lock(cache_mutex_); + auto it = map_.find(seed); + if (it != map_.end()) { + VLOG(3) << "CublasLtAlgoSelect Found in cache"; + return &(it->second); + } else { + // if we have cache but not found algo, and we don't want to search, + // here return nullptr + if (search_times_ <= 0) { + return nullptr; + } + } + } + VLOG(3) << "CublasLtAlgoSelect Not Found in cache"; + + // Get Ids + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + // std::vector algo_ids(requested_algo_count_); + int algo_ids[requested_algo_count_]; // NOLINT + + int num_algo_ids; + status = dyl::cublasLtMatmulAlgoGetIds(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + requested_algo_count_, + algo_ids, + &num_algo_ids); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoGetIds); + + // Traverse all posssible algo combinations + int step = 0; + int limit = 20000; + std::vector params; + + for (int idx = 0; idx < num_algo_ids; idx++) { + cublasLtMatmulAlgo_t algo; + + /* Initialize algo structure with given Algp ID */ + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit + status = dyl::cublasLtMatmulAlgoInit(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + algo_ids[idx], + &algo); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoInit); + + // Query the tiles enums supported by that algo which is used to alloc + // enough space to store it + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute + size_t attr_size = 0; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_TILE_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + int num_tiles = static_cast(attr_size / sizeof(int)); + std::vector tiles(num_tiles == 0 ? 1 : num_tiles); + if (num_tiles == 0) { + tiles[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; + num_tiles = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_TILE_IDS, + tiles.data(), + sizeof(int) * num_tiles, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Query the stages enums supported by that algo (cuda must >= 11.0) + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + int num_stages = static_cast(attr_size / sizeof(int)); + std::vector stages(num_stages == 0 ? 
1 : num_stages); + if (num_stages == 0) { + stages[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; + num_stages = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_STAGES_IDS, + stages.data(), + sizeof(int) * num_stages, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Retrieve Other Algo Capabilities attributes + int splitk_support, red_mask, swizzling_max, custom_option_max; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, + &splitk_support, + sizeof(splitk_support), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, + &red_mask, + sizeof(red_mask), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, + &swizzling_max, + sizeof(swizzling_max), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, + &custom_option_max, + sizeof(custom_option_max), + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + /* Loop over the different tiles */ + for (int tile_id = 0; tile_id < num_tiles && step < limit; tile_id++) { + /* Loop over different stages count */ + for (int stage_id = 0; stage_id < num_stages && step < limit; + stage_id++) { + /* Loop over the different custom option if any */ + for (int custom_option = 0; + custom_option <= custom_option_max && step < limit; + custom_option++) { + /* Loop over the CTAs swizzling support */ + for (int k = 0; k <= swizzling_max && step < limit; k++) { + int splir_k_trial = 0; + if (splitk_support) { + splir_k_trial += + sizeof(split_k_candidates) / sizeof(split_k_candidates[0]); + } + + for (int l = 0; (l < (1 + splir_k_trial)) && (step < limit); + l++) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_TILE_ID, + &tiles[tile_id], + sizeof(tiles[tile_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_STAGES_ID, + &stages[stage_id], + sizeof(stages[stage_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &custom_option, + sizeof(custom_option)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); + int split_k_val = 0; + int reduction_scheme = CUBLASLT_REDUCTION_SCHEME_NONE; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_val, + sizeof(split_k_val)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(int)); + if (l > 0) { // Split-K case + split_k_val = split_k_candidates[l - 1]; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_candidates[l - 1], + sizeof(split_k_candidates[l - 1])); + for (reduction_scheme = 1; + reduction_scheme < + static_cast(CUBLASLT_REDUCTION_SCHEME_MASK) && + (step < limit); + reduction_scheme = reduction_scheme << 1) { + if (reduction_scheme & red_mask) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(reduction_scheme)); + PADDLE_CUBLASLT_STATUS_CHECK( + cublasLtMatmulAlgoConfigSetAttribute); + + cublasLtMatmulHeuristicResult_t heurResult; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + 
a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } // end if + } + } else { + // Prepare algos + cublasLtMatmulHeuristicResult_t heurResult; + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } + } + } + } + } + } + } + cudaEvent_t start_event; + cudaEvent_t stop_event; + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); + + if (step == 0) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "CublasLtAlgoSelect Start testRun " << step << " " + << params.size(); + + for (int i = 0; i < step; i++) { + TestMatmulRun(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + alpha, + beta, + a, + b, + c, + params[i], + start_event, + stop_event, + stream); + } + std::sort(params.begin(), params.end(), compare_algo_time); + + int res_id = 0; + while (params[res_id].time == 0) res_id++; + + if (res_id >= params.size()) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "algo selected"; + + ret = params[res_id].algo; + std::lock_guard lock(cache_mutex_); + auto& algo_in_map = map_[seed]; + algo_in_map = ret; + return &algo_in_map; +#endif // #if CUDA_VERSION >= 11010 + } + + ~CublasLtAlgoCache() { + // Serialize map_ to cache file + if (search_times_ > 0) { + int dev; + cudaGetDevice(&dev); + if (dev == 0) { + std::ofstream outfile; + outfile.open(config_filename_, std::ios::out | std::ios::trunc); + outfile << dyl::cublasLtGetCudartVersion() << std::endl; + + for (const auto p : map_) { + outfile << p.first << " "; + for (int i = 0; i < 8; ++i) { + outfile << p.second.data[i] << " "; + } + outfile << std::endl; + } + outfile.close(); + } + } + } + + private: + explicit CublasLtAlgoCache(int search_times) + : search_times_(search_times), has_config_file_(true) { + // Init map_ from cache file + std::ifstream infile; + infile.open(config_filename_); + if (!infile.is_open()) { + printf("No config files \n"); + has_config_file_ = false; + VLOG(3) << "No 
CublasLtAlgoCache file found"; + return; + } + size_t cublaslt_version, real_cublaslt_version; + int64_t seed = 0; + uint64_t algo_data[8]; + infile >> cublaslt_version; + VLOG(1) << "cublaslt_version " << cublaslt_version; + + if (dyl::cublasLtGetCudartVersion() != cublaslt_version) { + LOG(INFO) << config_filename_ + << " is not compatible with current cublaslt_version " + << real_cublaslt_version; + return; + } + + while (!infile.eof()) { + infile >> seed >> algo_data[0] >> algo_data[1] >> algo_data[2] >> + algo_data[3] >> algo_data[4] >> algo_data[5] >> algo_data[6] >> + algo_data[7]; + + for (int i = 0; i < 8; ++i) { + map_[seed].data[i] = algo_data[i]; + } + } + infile.close(); + } + + std::string config_filename_{"/tmp/paddle_cublaslt_cache"}; + std::unordered_map map_; + int search_times_; + const int requested_algo_count_ = 100; + std::mutex cache_mutex_; + bool has_config_file_; + + inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val) { + n--; + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(min_val, (n + 1)); + } + + void HashMatmulDesc_(cublasLtMatmulDesc_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + int trans_a, trans_b; + uint32_t epilogue; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a, + sizeof(trans_a), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_a)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &trans_b, + sizeof(trans_b), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_b)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(epilogue)); + } + + void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + uint32_t dtype; + int32_t batch; + uint64_t row, col; + int64_t ld, batch_offset; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatrixLayoutGetAttribute(desc, + CUBLASLT_MATRIX_LAYOUT_TYPE, + &dtype, + sizeof(dtype), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(dtype)); + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch, + sizeof(batch), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); + HashValue_(seed, hash_fn, RoundToNextHighPowOfTwo(row, 32)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); + HashValue_(seed, hash_fn, RoundToNextHighPowOfTwo(col, 32)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + HashValue_(seed, hash_fn, RoundToNextHighPowOfTwo(ld, 32)); + + // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + // desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), + // &size_to_write)); + // HashValue_(seed, hash_fn, row); + + // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + // desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), + // &size_to_write)); + // HashValue_(seed, hash_fn, col); + + // 
PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + // desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + // HashValue_(seed, hash_fn, ld); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_offset, + sizeof(batch_offset), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch_offset)); + } + + void HashValue_(int64_t* seed, + const std::hash& hash_fn, + int64_t value) { + *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } +}; +#endif + +template class CublasLtHelper { public: - CublasLtHelper(int m, int k, int n) - : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + CublasLtHelper( + int m, int k, int n, cublasLtHandle_t handle, bool transpose_y = false) + : alpha_(1), + beta_(0), + m_(m), + k_(k), + n_(n), + handle_(handle), + transpose_y_(transpose_y) { cublasStatus_t status; // handle and matmul desc - status = dyl::cublasLtCreate(&handle_); + // status = dyl::cublasLtCreate(&handle_); + // PADDLE_CUBLASLT_STATUS_CHECK(cublasLtCreate); + if (std::is_same::value) { + scale_type_ = CUDA_R_16F; + a_type_ = CUDA_R_16F; + b_type_ = CUDA_R_16F; + c_type_ = CUDA_R_16F; #if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; + compute_type_ = CUDA_R_16F; #else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; + compute_type_ = CUBLAS_COMPUTE_16F; #endif - - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32F; + a_type_ = CUDA_R_32F; + b_type_ = CUDA_R_32F; + c_type_ = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32F; +#else + compute_type_ = CUBLAS_COMPUTE_32F; +#endif + } else if (std::is_same::value) { +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 + scale_type_ = CUDA_R_32F; + a_type_ = CUDA_R_16BF; + b_type_ = CUDA_R_16BF; + c_type_ = CUDA_R_16BF; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32F; +#else + compute_type_ = CUBLAS_COMPUTE_32F; +#endif +#endif + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32I; + a_type_ = CUDA_R_8I; + b_type_ = CUDA_R_8I; + c_type_ = CUDA_R_32I; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32I; +#else + compute_type_ = CUBLAS_COMPUTE_32I; +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "CublasLtHelper just implement for FP16/FP32/INT32.")); + } #if CUBLAS_VER_MAJOR < 11 - status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); + status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, compute_type_); #else status = dyl::cublasLtMatmulDescCreate( - &matmul_desc_, cudaComputeType, CUDA_R_32I); + &matmul_desc_, compute_type_, scale_type_); #endif + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescCreate); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - cublasOperation_t op_transpose = CUBLAS_OP_T; - status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op_transpose, - sizeof(op_transpose)); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescSetAttribute execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html 
to get more " - "information")); + // Node: Just test for int8 // matrix desc - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + if (std::is_same::value && !transpose_y_) { + status = dyl::cublasLtMatrixLayoutCreate(&b_desc_, b_type_, n, k, n); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); + } else { + cublasOperation_t op_transpose = CUBLAS_OP_T; + status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &op_transpose, + sizeof(op_transpose)); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescSetAttribute); + status = dyl::cublasLtMatrixLayoutCreate(&b_desc_, b_type_, k, n, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); + } - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&a_desc_, a_type_, k, m, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); -#if CUDA_VERSION >= 11020 + status = dyl::cublasLtMatrixLayoutCreate(&c_desc_, c_type_, n, m, n); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); +#if CUDA_VERSION >= 11010 int algoId = 21; int swizzle = 0; int customOption = 0; @@ -123,32 +768,33 @@ class CublasLtHelper { int splitK_val = 0; int reductionScheme = 0; int stages = 23; - workspace_size_ = 0; if (m >= 128) { tile = 20; stages = 17; } - std::tuple key(m_, k_, n_); - if (AlgoParamCache.count(key) != 0) { - auto value = AlgoParamCache.at(key); - algoId = value.algoId; - swizzle = value.swizzle; - customOption = value.customOption; - tile = value.tile; - splitK_val = value.splitK_val; - reductionScheme = value.reductionScheme; - stages = value.stages; - workspace_size_ = value.workspace_size; + if (std::is_same::value) { + algoId = 6; + swizzle = 1; + customOption = 0; + if (m <= 128) { + tile = 15; + stages = 18; + } else { + tile = 20; + stages = 11; + } + splitK_val = 0; + reductionScheme = 0; } dyl::cublasLtMatmulAlgoInit(handle_, - cudaComputeType, - CUDA_R_32I, - CUDA_R_8I, - CUDA_R_8I, - CUDA_R_32I, - CUDA_R_32I, + compute_type_, + scale_type_, + b_type_, + a_type_, + c_type_, + c_type_, algoId, &algo_); dyl::cublasLtMatmulAlgoConfigSetAttribute( @@ -172,37 +818,68 @@ class CublasLtHelper { CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); -#if CUDA_VERSION >= 11000 dyl::cublasLtMatmulAlgoConfigSetAttribute( &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif -#endif +#endif // #if CUDA_VERSION >= 11010 + } + ~CublasLtHelper() { + dyl::cublasLtMatmulDescDestroy(matmul_desc_); + dyl::cublasLtMatrixLayoutDestroy(a_desc_); + dyl::cublasLtMatrixLayoutDestroy(b_desc_); + dyl::cublasLtMatrixLayoutDestroy(c_desc_); } - ~CublasLtHelper() {} - void GEMM(int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, + template + void 
GEMM(const InT* a_dev, + const InT* b_dev, + OutT* c_dev, cudaStream_t stream, - void* workspace = nullptr) { + void* workspace = nullptr, + size_t workspace_size = 0) { cublasStatus_t status; +#if CUDA_VERSION >= 11020 + // cublasLtMatmulAlgo_t* algo = + // CublasLtAlgoCache::Instance().CublasLtAlgoSelect(handle_, + // m_, + // n_, + // k_, + // b_dev, + // a_dev, + // c_dev, + // &alpha_, + // &beta_, + // matmul_desc_, + // b_desc_, + // a_desc_, + // c_desc_, + // compute_type_, + // scale_type_, + // b_type_, + // a_type_, + // c_type_, + // workspace, + // workspace_size, + // stream); + +#endif + status = dyl::cublasLtMatmul(handle_, matmul_desc_, &alpha_, - B_dev, - B_desc_, - A_dev, - A_desc_, + b_dev, + b_desc_, + a_dev, + a_desc_, &beta_, - C_dev, - C_desc_, - C_dev, - C_desc_, + c_dev, + c_desc_, + c_dev, + c_desc_, #if CUDA_VERSION >= 11020 &algo_, workspace, - workspace_size_, + workspace_size, #else nullptr, nullptr, @@ -221,20 +898,30 @@ class CublasLtHelper { private: cublasLtHandle_t handle_; cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; + cublasLtMatrixLayout_t a_desc_; + cublasLtMatrixLayout_t b_desc_; + cublasLtMatrixLayout_t c_desc_; cublasLtMatmulAlgo_t algo_; - int32_t alpha_; - int32_t beta_; + cudaDataType_t scale_type_; + cudaDataType_t a_type_; + cudaDataType_t b_type_; + cudaDataType_t c_type_; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t compute_type_; +#else + cublasComputeType_t compute_type_; +#endif + + ScaleT alpha_; + ScaleT beta_; int m_; int k_; int n_; - size_t workspace_size_; + bool transpose_y_; }; } // namespace operators diff --git a/paddle/fluid/operators/fused/datatype_traits.h b/paddle/fluid/operators/fused/datatype_traits.h new file mode 100644 index 0000000000000..0ad677ba025ea --- /dev/null +++ b/paddle/fluid/operators/fused/datatype_traits.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" + +#pragma once + +namespace paddle { +namespace operators { + +namespace plat = paddle::platform; +using float16 = plat::float16; + + + +template +struct PDDataTypeTraits { + using DataType = T; +}; + +template <> +struct PDDataTypeTraits { + // Since LayerNormDirectCUDAFunctor register half type, we need to convert + // phi::float16 to half. + using DataType = half; +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +class PDDataTypeTraits { + public: + using DataType = __nv_bfloat16; +}; +#endif + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 1d83c7a62b1d9..cd99b89c049f7 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" @@ -23,10 +24,34 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h" +#include "paddle/phi/kernels/fusion/fused_multihead_attention_kernel.h" +#include "paddle/phi/kernels/fusion/fused_multihead_attention_variable_kernel.h" +#include "paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" +DECLARE_string(fmha_mode); + +DECLARE_bool(print_matrix); +DECLARE_bool(fuse_softmax); + namespace paddle { namespace operators { +template +class PDTraits; + +template <> +class PDTraits { + public: + typedef float DataType; + typedef float data_t; +}; + +template <> +class PDTraits { + public: + typedef half DataType; + typedef paddle::float16 data_t; +}; class AttnDropoutParam { public: @@ -65,17 +90,19 @@ class AttnDropoutParam { template __global__ void TransposeRemovingPadding(const T* input_data, + const int* seq_lens, T* output_data, const int batch_size, const int num_head, + const int max_len_this_time, const int seq_len, const int head_dim, const int token_num, const int elem_cnt, const int* padding_offset) { // transpose and remove padding - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] + // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, + // num_head, head_dim] int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; const int dim_embed = num_head * head_dim; using LoadT = phi::AlignedVector; @@ -89,11 +116,12 @@ __global__ void TransposeRemovingPadding(const T* input_data, const int ori_token_idx = token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); const int ori_batch_id = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_batch_id] == 0) continue; const int ori_seq_id = ori_token_idx % seq_len; const int ori_head_id = (linear_index % dim_embed) / head_dim; const int ori_head_lane = (linear_index % dim_embed) % head_dim; - const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + - ori_head_id * seq_len * head_dim + + const int ori_idx = ori_batch_id * num_head * max_len_this_time * head_dim + + ori_head_id * max_len_this_time * head_dim + ori_seq_id * head_dim + ori_head_lane; phi::Load(&input_data[ori_idx], &src_vec); phi::Store(src_vec, &output_data[linear_index]); @@ -103,15 +131,17 @@ __global__ void TransposeRemovingPadding(const T* input_data, template void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, const T* input_data, + const int* seq_lens, T* output_data, const int batch_size, const int num_head, + const int max_len_this_time, const int seq_len, const int head_dim, const int token_num, const int* padding_offset) { - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] + // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, + // num_head, head_dim] constexpr int VEC_16B = 16; const int elem_cnt = token_num * num_head * head_dim; constexpr int PackSize = VEC_16B / sizeof(T); @@ -125,9 +155,11 @@ void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, int32_t grid_size = (pack_num + block_size - 1) / block_size; TransposeRemovingPadding <<>>(input_data, + seq_lens, output_data, batch_size, num_head, + max_len_this_time, seq_len, head_dim, token_num, @@ -147,31 +179,230 @@ class FMHARef { : dev_ctx_(dev_ctx), batch_size_(batch_size), seq_len_(seq_len), + ori_seq_len_(seq_len), + num_head_(num_head), + head_dim_(head_dim), + dropout_param_(param) {} + + FMHARef(const phi::GPUContext& dev_ctx, + int64_t batch_size, + int64_t seq_len, + int64_t ori_seq_len, + int64_t num_head, + int64_t head_dim, + AttnDropoutParam param) + : dev_ctx_(dev_ctx), + batch_size_(batch_size), + seq_len_(seq_len), + ori_seq_len_(ori_seq_len), num_head_(num_head), head_dim_(head_dim), dropout_param_(param) {} ~FMHARef() {} - void ComputeForward(const phi::DenseTensor& qkv_input_tensor, - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - phi::DenseTensor* transpose_2_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor) { + void Compute(const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor* padding_offset_tensor, + const phi::DenseTensor* sequence_lengths_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const int token_num, + const bool mask_broadcast_num_heads = true) { + /* + Note(Zhengzekang): + There is some reason to pass some “unused” params for unify the interface. 
+ - sequence_lengths_tensor(optional): It is only used in cutlass Variable + length FMHA. + - softmax_out_tensor(optional): It is not need when use cutlass FMHA, since + cutlass FMHA fused all the operations, it do not need to product + intermediate result for softmax. + - mask_broadcast_num_heads(optional): It is only need in Naive FMHA. + */ + if (FLAGS_fmha_mode == "cutlass") { +#ifdef PADDLE_WITH_CUTLASS + // Here Use cutlass Fused Multihead Attention Kernel. + ComputeForwardWithCutlassFMHA(cache_kv_tensor, + src_mask_tensor, + padding_offset_tensor, + sequence_lengths_tensor, + q_transpose_out_tensor, + kv_transpose_out_tensor, + cache_kv_out_tensor, + qk_out_tensor, + src_mask_out_tensor, + softmax_out_tensor, + dropout_mask_out_tensor, + dropout_out_tensor, + qktv_out_tensor, + fmha_out_tensor, + token_num); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "CUTLASS need CUDA_VERSION >= 11.0")); +#endif + } else { + // Here Use Naive Attention Kernel. + ComputeForwardWithoutTranspose(cache_kv_tensor, + src_mask_tensor, + padding_offset_tensor, + sequence_lengths_tensor, + q_transpose_out_tensor, + kv_transpose_out_tensor, + cache_kv_out_tensor, + qk_out_tensor, + src_mask_out_tensor, + softmax_out_tensor, + dropout_mask_out_tensor, + dropout_out_tensor, + qktv_out_tensor, + fmha_out_tensor, + token_num, + mask_broadcast_num_heads); + } + } + + void ComputeForwardWithCutlassFMHA( + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor* padding_offset_tensor, + const phi::DenseTensor* sequence_lengths_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const int token_num) { +#ifdef PADDLE_WITH_CUTLASS + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 3, 1, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + T* qktv_out_data = qktv_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + int prompt_num = 0; + + auto out_seq_len = seq_len_; + if (cache_kv_tensor) { + prompt_num = cache_kv_tensor->dims()[3]; + // kv [2, bs, num_head, seq_len, head_dim] + phi::funcs::ConcatFunctor concat; + // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] + concat(dev_ctx_, + {*cache_kv_tensor, *kv_transpose_out_tensor}, + 3, + cache_kv_out_tensor); + out_seq_len = cache_kv_out_tensor->dims()[3]; + } + + int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + T* q_ptr = q_transpose_out_tensor->data(); + T* k_ptr = nullptr; + T* v_ptr = nullptr; + + if (cache_kv_tensor) { + int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_ptr = cache_kv_out_tensor->data(); + v_ptr = k_ptr + k_size; + } else { + int64_t k_size = q_size; + k_ptr = kv_transpose_out_tensor->data(); + v_ptr = k_ptr + k_size; + } + + float scale = 1.0f / sqrt(float(head_dim_)); // NOLINT + if (sequence_lengths_tensor) { + phi::fusion::MultiHeadAttentionVariableWrapper( + dev_ctx_, + q_ptr, + k_ptr, + v_ptr, + sequence_lengths_tensor->data(), + prompt_num == 0 ? nullptr : src_mask_tensor, + scale, + prompt_num == 0 ? 
true : false, + batch_size_, + num_head_, + seq_len_, + out_seq_len, + head_dim_, + head_dim_, + prompt_num, + qktv_out_data); + } else { + // Author(zhengzekang): we will check src_mask_tensor == nullptr in + // MultiHeadAttentionForwardWrapper. + phi::fusion::cutlass_internal:: + MultiHeadAttentionForwardWrapper(dev_ctx_, + q_ptr, + k_ptr, + v_ptr, + src_mask_tensor, + scale, + false, /*causal*/ + batch_size_, + num_head_, + seq_len_, + out_seq_len, + head_dim_, + qktv_out_data); + } + + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + if (!padding_offset_tensor) { + std::vector perm_3 = {0, 2, 1, 3}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + } else { + InvokeTransposeRemovePadding(dev_ctx_, + qktv_out_data, + sequence_lengths_tensor->data(), + fmha_out_data, + batch_size_, + num_head_, + seq_len_, + ori_seq_len_, + head_dim_, + token_num, + padding_offset_tensor->data()); + } +#endif + } + + void ComputeForwardWithoutTranspose( + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor* padding_offset_tensor, + const phi::DenseTensor* sequence_lengths_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const int token_num, + const bool mask_broadcast_num_heads = true) { // input shape: [bs, seq_len, 3, num_head, head_dim] // transpose with perm [2, 0, 3, 1, 4], // output_shape: [3, bs, num_head, seq_len, head_dim] - std::vector perm_1 = {2, 0, 3, 1, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor); - T* qkv_data = transpose_2_out_tensor->data(); T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); T* softmax_out_data = softmax_out_tensor->data(); @@ -180,15 +411,17 @@ class FMHARef { auto out_seq_len = seq_len_; if (cache_kv_tensor) { // kv [2, bs, num_head, seq_len, head_dim] - auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); phi::funcs::ConcatFunctor concat; // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); + concat(dev_ctx_, + {*cache_kv_tensor, *kv_transpose_out_tensor}, + 3, + cache_kv_out_tensor); out_seq_len = cache_kv_out_tensor->dims()[3]; } int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = qkv_data; + T* q_ptr = q_transpose_out_tensor->data(); T* k_ptr = nullptr; T* v_ptr = nullptr; @@ -198,7 +431,7 @@ class FMHARef { v_ptr = k_ptr + k_size; } else { int64_t k_size = q_size; - k_ptr = q_ptr + q_size; + k_ptr = kv_transpose_out_tensor->data(); v_ptr = k_ptr + k_size; } @@ -206,10 +439,9 @@ class FMHARef { // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for // float16 calculation, INF may appear in QK^T if we do not scale before. 
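Regarding the NOTE(wangxi) above: softmax(Q K^T / sqrt(Dh)) equals softmax((Q / sqrt(Dh)) K^T), so scaling Q once before the GEMM gives the same result while keeping the fp16 logits smaller and less likely to overflow. A minimal standalone CPU sketch of that pre-scaling (hypothetical helper, not part of this diff):

#include <cmath>
#include <cstddef>

// Scale the query buffer in place by 1/sqrt(head_dim) before computing QK^T.
template <typename T>
void ScaleQueryInPlace(T* q, size_t numel, int head_dim) {
  const float alpha = 1.0f / std::sqrt(static_cast<float>(head_dim));
  for (size_t i = 0; i < numel; ++i) {
    q[i] = static_cast<T>(static_cast<float>(q[i]) * alpha);
  }
}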
float alpha = 1.0 / sqrt(head_dim_); - auto q_tensor = transpose_2_out_tensor->Slice(0, 1); auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; + std::vector ins = {q_transpose_out_tensor}; + std::vector outs = {q_transpose_out_tensor}; phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); } @@ -247,7 +479,10 @@ class FMHARef { batch_size_, num_head_, seq_len_, + mask_broadcast_num_heads, dev_ctx_.stream()); + // phi::fusion::FusedSoftmaxMaskKernel(dev_ctx_, + // *qk_out_tensor, *src_mask_tensor, softmax_out_tensor); } else { std::vector ins; std::vector outs; @@ -279,6 +514,7 @@ class FMHARef { stride_b = gemm_k * gemm_n; if (dropout_param_.dropout_prob_) { + T* dropout_out_data = dropout_out_tensor->data(); phi::funcs::DropoutFwGPUKernelDriver( static_cast(dev_ctx_), dropout_param_.is_test_, @@ -291,7 +527,6 @@ class FMHARef { dropout_mask_out_tensor, dropout_out_tensor, false); - T* dropout_out_data = dropout_out_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -324,59 +559,74 @@ class FMHARef { } // transpose: [0, 2, 1, 3] // output shape: [batch_size, seq_len, num_heads, head_dim] - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + if (!padding_offset_tensor) { + std::vector perm_3 = {0, 2, 1, 3}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + } else { + InvokeTransposeRemovePadding(dev_ctx_, + qktv_out_data, + sequence_lengths_tensor->data(), + fmha_out_data, + batch_size_, + num_head_, + seq_len_, + ori_seq_len_, + head_dim_, + token_num, + padding_offset_tensor->data()); + } } - void ComputeForwardWithoutTranspose( - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor* padding_offset_tensor, - phi::DenseTensor* q_transpose_out_tensor, - phi::DenseTensor* kv_transpose_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor, - const int token_num) { + void ComputeForward(const phi::DenseTensor& qkv_input_tensor, + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + phi::DenseTensor* transpose_2_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const bool mask_broadcast_num_heads = true) { // input shape: [bs, seq_len, 3, num_head, head_dim] // transpose with perm [2, 0, 3, 1, 4], // output_shape: [3, bs, num_head, seq_len, head_dim] + std::vector perm_1 = {2, 0, 3, 1, 4}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor); + T* qkv_data = transpose_2_out_tensor->data(); T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); T* fmha_out_data = fmha_out_tensor->data(); auto out_seq_len = seq_len_; if (cache_kv_tensor) { // kv [2, bs, num_head, seq_len, head_dim] + 
auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); phi::funcs::ConcatFunctor concat; // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, - {*cache_kv_tensor, *kv_transpose_out_tensor}, - 3, - cache_kv_out_tensor); + concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); out_seq_len = cache_kv_out_tensor->dims()[3]; } int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = q_transpose_out_tensor->data(); + T* q_ptr = qkv_data; T* k_ptr = nullptr; T* v_ptr = nullptr; + int64_t k_size; + if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_size = cache_kv_out_tensor->numel() / 2; k_ptr = cache_kv_out_tensor->data(); v_ptr = k_ptr + k_size; } else { int64_t k_size = q_size; - k_ptr = kv_transpose_out_tensor->data(); + k_ptr = q_ptr + q_size; v_ptr = k_ptr + k_size; } @@ -384,9 +634,10 @@ class FMHARef { // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for // float16 calculation, INF may appear in QK^T if we do not scale before. float alpha = 1.0 / sqrt(head_dim_); + auto q_tensor = transpose_2_out_tensor->Slice(0, 1); auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {q_transpose_out_tensor}; - std::vector outs = {q_transpose_out_tensor}; + std::vector ins = {&q_tensor}; + std::vector outs = {&q_tensor}; phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); } @@ -424,6 +675,7 @@ class FMHARef { batch_size_, num_head_, seq_len_, + mask_broadcast_num_heads, dev_ctx_.stream()); } else { std::vector ins; @@ -436,8 +688,8 @@ class FMHARef { dev_ctx_, ins, &outs, - elewise_add_axis, - phi::funcs::AddFunctor()); + phi::funcs::AddFunctor(), + elewise_add_axis); phi::SoftmaxForwardCUDAKernelDriver( dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); @@ -468,6 +720,7 @@ class FMHARef { dropout_mask_out_tensor, dropout_out_tensor, false); + T* dropout_out_data = dropout_out_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -500,21 +753,9 @@ class FMHARef { } // transpose: [0, 2, 1, 3] // output shape: [batch_size, seq_len, num_heads, head_dim] - if (!padding_offset_tensor) { - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } else { - InvokeTransposeRemovePadding(dev_ctx_, - qktv_out_data, - fmha_out_data, - batch_size_, - num_head_, - seq_len_, - head_dim_, - token_num, - padding_offset_tensor->data()); - } + std::vector perm_3 = {0, 2, 1, 3}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); } void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, @@ -742,6 +983,7 @@ class FMHARef { int64_t batch_size_; int64_t seq_len_; + int64_t ori_seq_len_; int64_t num_head_; int64_t head_dim_; diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 8025ba97ac004..4f0bc41ef215d 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -139,10 +139,12 @@ class FusedDropoutHelper { FusedDropoutHelper(const phi::GPUContext& ctx, const int rows, const int cols, - const DropoutParam& dropout_param) { + const DropoutParam& dropout_param, + const float residual_alpha = 1.0) { rows_ = rows; cols_ = cols; dropout_param_ = dropout_param; + residual_alpha_ = residual_alpha; } // out = residual + dropout( src + bias ) @@ -172,7 +174,8 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, 
dequant_out_scale_data, - quant_next_in_scale); + quant_next_in_scale, + residual_alpha_); } void ResidualDropoutBiasGrad(const phi::GPUContext& ctx, @@ -340,6 +343,7 @@ class FusedDropoutHelper { int rows_; int cols_; DropoutParam dropout_param_; + float residual_alpha_; }; template @@ -364,20 +368,23 @@ class FusedDropoutLayerNormHelper FusedDropoutLayerNormHelper() {} FusedDropoutLayerNormHelper(const int rows, const int cols, - const float epsilon) { + const float epsilon, + const float residual_alpha = 1.0) { using U = LayerNormParamType; this->rows_ = rows; this->cols_ = cols; epsilon_ = epsilon; + this->residual_alpha_ = residual_alpha; } FusedDropoutLayerNormHelper(const phi::GPUContext& ctx, const int rows, const int cols, const DropoutParam& dropout_param, - const float epsilon) + const float epsilon, + const float residual_alpha = 1.0) : FusedDropoutHelper( - ctx, rows, cols, dropout_param) { + ctx, rows, cols, dropout_param, residual_alpha) { using U = LayerNormParamType; epsilon_ = epsilon; } @@ -490,7 +497,8 @@ class FusedDropoutLayerNormHelper quant_next_in_scale, quant_round_type, quant_max_bound, - quant_min_bound); + quant_min_bound, + this->residual_alpha_); } template , bool is_same_type = false> diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index a6bd467dc1992..cd819c0b494fa 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -843,7 +843,8 @@ void LaunchLayernormResidualDropoutBias( const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, - const float quant_min_bound = -127.0) { + const float quant_min_bound = -127.0, + const float residual_alpha = 1.0) { // dropout_prob == 1.0f // NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0 if (std::abs(dropout_prob - 1.0f) < 1e-5) { diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h new file mode 100644 index 0000000000000..246431a535120 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h @@ -0,0 +1,918 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/operators/fused/attn_gemm_int8.h" +// #include +// "paddle/fluid/operators/fused/cutlass/cutlass_kernels/gemm_dequant.h" +// #include +// "paddle/fluid/operators/fused/cutlass/cutlass_kernels/intA_intB_interleaved_gemm/intA_intB_gemm_template.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" + +DECLARE_int64(custom_allreduce_one_shot_threshold); +DECLARE_int64(custom_allreduce_two_shot_threshold); + +DECLARE_bool(use_gemm_dequant); + +DECLARE_bool(print_matrix); +/* +Note(Zhengzekang): +This header file is to store General Function Helper which has been used in +FusedMultiTransformer. +*/ + +namespace paddle { +namespace operators { + +template +static void PrintFrontNPerLine(const phi::DenseTensor &a, + int rows, + int cols, + int n) { + if (!FLAGS_print_matrix) return; + std::vector a_h(a.numel()); + + cudaMemcpy( + a_h.data(), a.data(), a.numel() * sizeof(T), cudaMemcpyDeviceToHost); + + for (int line = 0; line < rows; ++line) { + std::cout << "[" << line << "] "; + for (int i = 0; i < n; ++i) { + if (std::is_same::value) { + std::cout << (int)(a_h[line * cols + i]) << " "; // NOLINT + } else { + std::cout << a_h[line * cols + i] << " "; // NOLINT + } + } + std::cout << "\n"; + } +} + +static CustomNCCLComm *GetCustomNCCLComm(const phi::GPUContext &ctx, + int ring_id) { + static auto comm = + CreateCustomNCCLComm(ctx, + FLAGS_custom_allreduce_one_shot_threshold, + FLAGS_custom_allreduce_two_shot_threshold, + ring_id); + return comm.get(); +} +template +struct AddTriFunctor { + inline HOSTDEVICE T operator()(const T a, const T b, const T c) const { + return a + b + c; + } +}; + +template +struct SmoothFunctor { + inline HOSTDEVICE T operator()(const T a, const T b, const T c) const { + return (a + b) * c; + } +}; + +namespace { // NOLINT + +template +class BiasActHelper { + public: + BiasActHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int rows, + int cols) + : dev_ctx_(dev_ctx), act_method_(act_method), rows_(rows), cols_(cols) {} + + // dst = Activation(x + bias(optional)) + void Compute(const phi::DenseTensor *x, + const phi::DenseTensor *bias, + phi::DenseTensor *output) { + const T *bias_data = (bias == nullptr) ? nullptr : bias->data(); + Load load_func(x->data()); + Store store_func(output->data()); + ComputeImpl(bias_data, load_func, store_func); + } + + void Compute(const phi::DenseTensor *x, + const phi::DenseTensor *bias, + const phi::DenseTensor *dequant_scales, + const phi::DenseTensor *shift, + const phi::DenseTensor *smooth, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + phi::DenseTensor *output) { + if (shift != nullptr) { + DispatchComputeImpl(x, + bias, + dequant_scales, + shift, + smooth, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + output); + } else { + DispatchComputeImpl(x, + bias, + dequant_scales, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + output); + } + } + + private: + void DispatchComputeImpl(const phi::DenseTensor *x, + const phi::DenseTensor *bias, + const phi::DenseTensor *dequant_scales, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + phi::DenseTensor *output) { + const T *bias_data = (bias == nullptr) ? 
nullptr : bias->data(); + if (dequant_scales != nullptr && quant_scale > 0) { + DequantLoad load_func( + x->data(), dequant_scales->data(), cols_); + QuantStore store_func(output->data(), + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl, QuantStore, int32_t>( + bias_data, load_func, store_func); + } else if (dequant_scales == nullptr && quant_scale > 0) { + Load load_func(x->data()); + QuantStore store_func(output->data(), + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl(bias_data, load_func, store_func); + } else if (dequant_scales != nullptr && quant_scale <= 0) { + DequantLoad load_func( + x->data(), dequant_scales->data(), cols_); + Store store_func(output->data()); + ComputeImpl, Store, int32_t>( + bias_data, load_func, store_func); + } else { + Load load_func(x->data()); + Store store_func(output->data()); + ComputeImpl(bias_data, load_func, store_func); + } + } + + void DispatchComputeImpl(const phi::DenseTensor *x, + const phi::DenseTensor *bias, + const phi::DenseTensor *dequant_scales, + const phi::DenseTensor *shift, + const phi::DenseTensor *smooth, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + phi::DenseTensor *output) { + bool use_glu = (act_method_ == "geglu" || act_method_ == "swiglu"); + const T *bias_data = (bias == nullptr) ? nullptr : bias->data(); + if (dequant_scales != nullptr && quant_scale > 0) { + DequantLoad load_func( + x->data(), dequant_scales->data(), cols_); + QuantStore store_func(output->data(), + shift->data(), + smooth->data(), + use_glu ? cols_ / 2 : cols_, + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl, QuantStore, int32_t>( + bias_data, load_func, store_func); + } else if (dequant_scales == nullptr && quant_scale > 0) { + Load load_func(x->data()); + QuantStore store_func(output->data(), + shift->data(), + smooth->data(), + use_glu ? cols_ / 2 : cols_, + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl(bias_data, load_func, store_func); + } else if (dequant_scales != nullptr && quant_scale <= 0) { + DequantLoad load_func( + x->data(), dequant_scales->data(), cols_); + Store store_func(output->data(), + shift->data(), + smooth->data(), + use_glu ? cols_ / 2 : cols_); + ComputeImpl, Store, int32_t>( + bias_data, load_func, store_func); + } else { + Load load_func(x->data()); + Store store_func(output->data(), + shift->data(), + smooth->data(), + use_glu ? cols_ / 2 : cols_); + ComputeImpl(bias_data, load_func, store_func); + } + } + + template + void ComputeImpl(const T *bias_data, + LoadFunc load_func, + StoreFunc store_func) { + if (act_method_ == "geglu") { + // Note(Zhengzekang): For GLU structure, we need divide the cols by 2. 
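The cols_ / 2 in the GLU branches reflects that each input row is split into a gate half and a linear half, so the activated output has half the width. A rough CPU reference of the SwiGLU-style gating assumed here (bias omitted, and which half acts as the gate is an assumption, not taken from LaunchActFFNGlu):

#include <cmath>
#include <vector>

// x is row-major [rows, cols]; out is [rows, cols / 2]:
//   out[r][c] = silu(x[r][c]) * x[r][c + cols / 2]
void SwigluReference(const std::vector<float>& x, int rows, int cols,
                     std::vector<float>* out) {
  const int half = cols / 2;
  out->resize(static_cast<size_t>(rows) * half);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < half; ++c) {
      const float a = x[r * cols + c];         // gate half
      const float b = x[r * cols + half + c];  // linear half
      const float silu = a / (1.0f + std::exp(-a));
      (*out)[r * half + c] = silu * b;
    }
  }
}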
+ VLOG(5) << "doing geglu"; + LaunchActFFNGlu, LoadFunc, StoreFunc, LoadT>( + dev_ctx_, bias_data, rows_, cols_ / 2, load_func, store_func); + } else if (act_method_ == "swiglu") { + VLOG(5) << "doing swiglu"; + LaunchActFFNGlu, LoadFunc, StoreFunc, LoadT>( + dev_ctx_, bias_data, rows_, cols_ / 2, load_func, store_func); + } else if (act_method_ == "gelu") { + if (FLAGS_use_fast_math) { + VLOG(5) << "doing Fast GELU"; + LaunchBiasAct, LoadFunc, StoreFunc, LoadT>( + dev_ctx_, bias_data, rows_, cols_, load_func, store_func); + } else { + VLOG(5) << "doing GELU"; + LaunchBiasAct, LoadFunc, StoreFunc, LoadT>( + dev_ctx_, bias_data, rows_, cols_, load_func, store_func); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently Only Support GeGLU, SwiGLU, GeLU")); + } + } + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int rows_; + int cols_; +}; + +template ::DataType> +class GEMMHelper { + public: + GEMMHelper( + const phi::GPUContext &dev_ctx, + int token_num, + int dim_ffn, + int dim_embed, + const std::string gemm_method, + // paddle::operators::CutlassFpAIntBGemmRunner + // *int8_mixed_gemm_runner, + // paddle::operators::CutlassFpAIntBGemmRunner + // *int4_mixed_gemm_runner, + // paddle::operators::CutlassIntAIntBInterleavedGemmRunner + // *int8_int8_interleaved_gemm_runner, + bool transpose_weight = false) + : dev_ctx_(dev_ctx), + token_num_(token_num), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed), + gemm_method_(gemm_method), + // int8_mixed_gemm_runner_(int8_mixed_gemm_runner), + // int4_mixed_gemm_runner_(int4_mixed_gemm_runner), + // int8_int8_interleaved_gemm_runner_(int8_int8_interleaved_gemm_runner), + transpose_weight_(transpose_weight) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *scale, + const phi::DenseTensor *bias, + phi::DenseTensor *workspace, + phi::DenseTensor *output) { + VLOG(5) << "GEMMHelper," + << " token_num_:" << token_num_ << " dim_ffn_:" << dim_ffn_ + << " dim_embed_:" << dim_embed_; + bool compute_bias = true; + if (bias == nullptr) { + compute_bias = false; + } + using NvType = typename PDDataTypeTraits::DataType; + + if (gemm_method_ == "weight-only") { + // VLOG(5) << "do weight-only gemm int8"; + // if (bias) { + // int8_mixed_gemm_runner_->gemm_bias_act( + // reinterpret_cast(input->data()), + // reinterpret_cast(weight->data()), + // reinterpret_cast(scale->data()), + // reinterpret_cast(bias->data()), + // reinterpret_cast(output->data()), + // token_num_, + // dim_ffn_, + // dim_embed_, + // "none", + // reinterpret_cast(workspace->data()), + // workspace->numel(), + // dev_ctx_.stream()); + // } else { + // int8_mixed_gemm_runner_->gemm( + // reinterpret_cast(input->data()), + // reinterpret_cast(weight->data()), + // reinterpret_cast(scale->data()), + // reinterpret_cast(output->data()), + // token_num_, + // dim_ffn_, + // dim_embed_, + // reinterpret_cast(workspace->data()), + // workspace->numel(), + // dev_ctx_.stream()); + // } + // VLOG(5) << "input:" << *input; + // VLOG(5) << "output:" << *output; + } else if (gemm_method_ == "weight-only-int4") { + // VLOG(5) << "do weight-only gemm"; + // if (bias) { + // int4_mixed_gemm_runner_->gemm_bias_act( + // reinterpret_cast(input->data()), + // reinterpret_cast(weight->data()), reinterpret_cast(scale->data()), reinterpret_cast(bias->data()), reinterpret_cast(output->data()), token_num_, dim_ffn_, dim_embed_, "none", + // reinterpret_cast(workspace->data()), + 
// workspace->numel(), + // dev_ctx_.stream()); + // } else { + // int4_mixed_gemm_runner_->gemm( + // reinterpret_cast(input->data()), + // reinterpret_cast(weight->data()), reinterpret_cast(scale->data()), reinterpret_cast(output->data()), token_num_, dim_ffn_, dim_embed_, + // reinterpret_cast(workspace->data()), + // workspace->numel(), + // dev_ctx_.stream()); + // } + // VLOG(5) << "input:" << *input; + // VLOG(5) << "output:" << *output; + } else if (gemm_method_ == "weightonly_gemv") { + // // TODO(zhengzekang): support weightonly gemv int4 + // const T *bias_data = bias ? bias->data() : nullptr; + // phi::GemvWeightonlyInt8Wrapper(dev_ctx_, + // input->data(), + // weight->data(), + // bias_data, + // scale->data(), + // dim_ffn_, + // dim_embed_, + // "None", + // /*act_method*/ + // output->data()); + } else if (gemm_method_ == "LLM.int8") { + // // Note(Zhengzekang): LLM Gemm donot support fused add_bias. + // LLMGemm(dev_ctx_, + // weight, + // input, + // scale, + // int8_int8_interleaved_gemm_runner_, + // FLAGS_custom_llm_int8_threshold, + // output, + // workspace, + // "LLMGemm", + // token_num_, + // dim_embed_, + // dim_ffn_); + } else if (gemm_method_ == "None") { + auto ffn_linear_compute = AttnMatMul(dev_ctx_, + false, + transpose_weight_, + token_num_, + dim_ffn_, + dim_embed_, + compute_bias); + ffn_linear_compute.ComputeForward(weight, input, bias, output, output); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently GemmHelper only support `weight-only`, `LLM.int8`, " + "`None`. ")); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int token_num_; + int dim_ffn_; + int dim_embed_; + std::string gemm_method_; + // paddle::operators::CutlassFpAIntBGemmRunner + // *int8_mixed_gemm_runner_; + // paddle::operators::CutlassFpAIntBGemmRunner + // *int4_mixed_gemm_runner_; + // paddle::operators::CutlassIntAIntBInterleavedGemmRunner + // *int8_int8_interleaved_gemm_runner_; + bool transpose_weight_; // Just For AttnMatmul. 
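The commented-out weight-only branches above all follow the usual W8A16 pattern: activations stay in fp16/fp32, weights are stored as int8 alongside dequantization scales, and the scales are applied during or after the matmul. A rough CPU reference of that idea (row-major layout and per-output-channel scales are assumptions here, not taken from the CUTLASS runners):

#include <cstdint>
#include <vector>

// y[m, n] = (x[m, k] * int8 w[k, n]) with one dequant scale per output column.
void WeightOnlyGemmReference(const std::vector<float>& x,
                             const std::vector<int8_t>& w_int8,
                             const std::vector<float>& col_scale,
                             int m, int k, int n,
                             std::vector<float>* y) {
  y->assign(static_cast<size_t>(m) * n, 0.0f);
  for (int row = 0; row < m; ++row) {
    for (int inner = 0; inner < k; ++inner) {
      const float a = x[row * k + inner];
      for (int col = 0; col < n; ++col) {
        (*y)[row * n + col] += a * static_cast<float>(w_int8[inner * n + col]);
      }
    }
  }
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      (*y)[row * n + col] *= col_scale[col];
    }
  }
}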
+}; + +template +class Int8GEMMHelper { + public: + Int8GEMMHelper(const phi::GPUContext &dev_ctx, + int m, + int k, + int n, + phi::DenseTensor &workspace, // NOLINT + phi::DenseTensor &input_workspace, // NOLINT + phi::DenseTensor &out_workspace, // NOLINT + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + bool use_gemm_dequant = false) + : dev_ctx_(dev_ctx), + m_(m), + k_(k), + n_(n), + use_gemm_dequant_(use_gemm_dequant), + quant_round_type_(quant_round_type), + quant_min_bound_(quant_min_bound), + quant_max_bound_(quant_max_bound), + workspace_(workspace), + input_workspace_(input_workspace), + out_workspace_(out_workspace) { + cublaslt_helper = std::make_unique>( + m, k, n, dev_ctx.cublaslt_handle()); + } + + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, // int8, Need be transposed + const phi::DenseTensor *dequant_out_scales, + const float quant_in_scale, + phi::DenseTensor *output, + bool quant_in = false, + bool dequant_out = false) { + phi::DenseTensor input_tmp, out_tmp; + if (quant_in) { + input_tmp = input_workspace_; + quantize_kernel_launcher(input->data(), + input_tmp.data(), + quant_in_scale, + m_, + k_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_, + dev_ctx_.stream()); + } else { + input_tmp = *input; + } + + if (dequant_out) { + out_tmp = out_workspace_; + } else { + out_tmp = *output; + } + + if (use_gemm_dequant_ && dequant_out) { + // RunGemmDequant(input_tmp.data(), + // weight->data(), + // dequant_out_scales->data(), + // output->data(), + // m_, + // k_, + // n_, + // dev_ctx_.stream()); + } else { + cublaslt_helper->GEMM(input_tmp.data(), + weight->data(), + out_tmp.data(), + dev_ctx_.stream(), + (void *)workspace_.data(), + workspace_.numel()); + } + + if (!use_gemm_dequant_ && dequant_out) { + auto gpu_config = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx_, m_ * n_, DequantKernelVecSize)); + dequantize_kernel_launcher(out_tmp.data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scales->data()); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int m_; + int k_; + int n_; + int quant_round_type_; + float quant_max_bound_; + float quant_min_bound_; + bool use_gemm_dequant_; + phi::DenseTensor &workspace_; // char + phi::DenseTensor &input_workspace_; // int8_t + phi::DenseTensor &out_workspace_; // int32_t + + std::unique_ptr> cublaslt_helper; +}; + +template +class LtGEMMHelper { + public: + LtGEMMHelper( + const phi::GPUContext &dev_ctx, int m, int k, int n, bool transpose_y) + : dev_ctx_(dev_ctx), m_(m), k_(k), n_(n) { + cublaslt_helper = std::make_unique>( + m, k, n, dev_ctx.cublaslt_handle(), transpose_y); + } + + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + phi::DenseTensor *output) { + cublaslt_helper->GEMM(input->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + nullptr, + 0); + } + + private: + const phi::GPUContext &dev_ctx_; + int m_; + int k_; + int n_; + + std::unique_ptr> cublaslt_helper; +}; + +template +class NormHelper { + public: + NormHelper(const phi::GPUContext &dev_ctx, + const std::string &norm_type, + const int rows, + const int cols, + const float epsilon, + const float residual_alpha) + : dev_ctx_(dev_ctx), + norm_type_(norm_type), + rows_(rows), + cols_(cols), + epsilon_(epsilon), + residual_alpha_( + residual_alpha), // TODO(zhengzekang): currently only available for + // Layernorm. Need support rmsnorm. 
+ layernorm_helper_(dev_ctx_, epsilon_, rows_, cols_) { + // VLOG(0) << "NormHelper residual_alpha:" << residual_alpha_; + DropoutParam dropout_param(true, 0, true, true, 0.0, nullptr, 0); + residual_bias_add_layernorm_helper_ = + FusedDropoutLayerNormHelper( + dev_ctx, rows_, cols_, dropout_param, epsilon_, residual_alpha_); + } + + /* + Note(Zhengzekang): + Since input `X` and `Residual` in FusedMT will be swaped by preallocated + buffer, I have no choice but to pass the data pointer instead of + phi::DenseTensor. + */ + + // dst = Norm(x + residual + bias(optional)) + void NormResidualBias(const T *x_data, + const T *residual_data, + const phi::DenseTensor *bias, + const phi::DenseTensor *norm_weight, + const phi::DenseTensor *norm_bias, + phi::DenseTensor *mean, + phi::DenseTensor *var, + phi::DenseTensor *bias_residual_out, + phi::DenseTensor *output) { + using U = LayerNormParamType; + const T *bias_data = bias ? bias->data() : nullptr; + U *mean_data = mean ? mean->data() : nullptr; + U *var_data = var ? var->data() : nullptr; + T *bias_residual_out_data = bias_residual_out->data(); + T *output_data = output->data(); + + if (norm_type_ == "layernorm") { + // For layernorm, it use FP32 type weight and bias. + const U *norm_weight_data = + norm_weight ? norm_weight->data() : nullptr; + const U *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + residual_bias_add_layernorm_helper_.LayernormResidualDropoutBias( + dev_ctx_, + x_data, + residual_data, + bias_data, + norm_weight_data, + norm_bias_data, + bias_residual_out_data, + nullptr, + output_data, + mean_data, + var_data); + // } else if (norm_type_ == "rmsnorm") { + // // For rmsnorm, it use Input's type weight and bias. + // const T *norm_weight_data = + // norm_weight ? norm_weight->data() : nullptr; + // const T *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + // phi::ResidualAddRmsNormWrapper(dev_ctx_, + // x_data, + // residual_data, + // bias_data, + // norm_weight_data, + // norm_bias_data, + // epsilon_, + // rows_, + // cols_, + // bias_residual_out_data, + // output_data); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently NormHelper only support `layernorm`, `rmsnorm`. ")); + } + } + + // dst = Norm(x) + void Norm(const T *x_data, + const phi::DenseTensor *norm_weight, + const phi::DenseTensor *norm_bias, + phi::DenseTensor *mean, + phi::DenseTensor *var, + phi::DenseTensor *output) { + using U = LayerNormParamType; + U *mean_data = mean ? mean->data() : nullptr; + U *var_data = var ? var->data() : nullptr; + T *output_data = output->data(); + + if (norm_type_ == "layernorm") { + // For layernorm, it use FP32 type weight and bias. + const U *norm_weight_data = + norm_weight ? norm_weight->data() : nullptr; + const U *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + layernorm_helper_.ComputeForward(x_data, + norm_weight_data, + norm_bias_data, + output_data, + mean_data, + var_data); + // } else if (norm_type_ == "rmsnorm") { + // // For rmsnorm, it use Input's type weight and bias. + // const T *norm_weight_data = + // norm_weight ? norm_weight->data() : nullptr; + // const T *norm_bias_data = norm_bias ? norm_bias->data() : nullptr; + // phi::RmsNormWrapper(dev_ctx_, + // x_data, + // norm_weight_data, + // norm_bias_data, + // epsilon_, + // rows_, + // cols_, + // output_data); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently NormHelper only support `layernorm`, `rmsnorm`. 
")); + } + } + + private: + const phi::GPUContext &dev_ctx_; + std::string norm_type_; + int rows_; + int cols_; + float epsilon_; + float residual_alpha_; + FusedDropoutLayerNormHelper residual_bias_add_layernorm_helper_; + AttnLayerNorm layernorm_helper_; +}; + +template ::DataType> +class FFNHelper { + public: + FFNHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int token_num, + int dim_ffn, + int dim_embed, + const std::string gemm_method) + : dev_ctx_(dev_ctx), + act_method_(act_method), + token_num_(token_num), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed), + gemm_method_(gemm_method) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *scale, + const phi::DenseTensor *bias, + phi::DenseTensor *workspace, + phi::DenseTensor *bias_out, + phi::DenseTensor *output) { + /* + input's shape [token_num, dim_embed] + weight's shape [dim_embed, dim_ffn] + bias' shape [dim_ffn] + output's shape [token_num, dim_ffn]. + */ + // for debug + VLOG(5) << "FFNHelper," + << " token_num_:" << token_num_ << " dim_ffn_:" << dim_ffn_ + << " dim_embed_:" << dim_embed_; + GEMMHelper gemm_helper( + dev_ctx_, token_num_, dim_ffn_, dim_embed_, gemm_method_); + BiasActHelper bias_act_helper( + dev_ctx_, act_method_, token_num_, dim_ffn_); + + gemm_helper.Compute(input, weight, scale, bias, workspace, bias_out); + if (gemm_method_ == "LLm.int8") { + bias_act_helper.Compute(bias_out, bias, output); + } else { + // Note(Zhengzekang): Other Gemm method can fuse bias add. + bias_act_helper.Compute(bias_out, nullptr, output); + } + } + + private: + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int token_num_; + int dim_ffn_; + int dim_embed_; + std::string gemm_method_; +}; + + +template +class WriteCacheKVHelper { + public: + WriteCacheKVHelper(const phi::GPUContext &dev_ctx, + int quant_round_type, + float quant_max_bound, + float quant_min_bound) + : dev_ctx_(dev_ctx), + quant_round_type_(quant_round_type), + quant_min_bound_(quant_min_bound), + quant_max_bound_(quant_max_bound) {} + + void Compute(const phi::DenseTensor *pre_cache_kv_out, + phi::DenseTensor *cache_kv_out, + const phi::DenseTensor *kv_transpose_out, + const int *sequence_lengths_data, + const int cache_bsz, + const int bsz, + const int num_head, + const int seq_len, + const int dim_head, + const int cache_offset, + const float cache_k_scale, + const float cache_v_scale) { + if (cache_k_scale > 0) { + WriteInt8CacheKV(dev_ctx_, + pre_cache_kv_out, + cache_kv_out, + kv_transpose_out, + sequence_lengths_data, + cache_bsz, + bsz, + num_head, + seq_len, + dim_head, + cache_offset, + quant_round_type_, + quant_max_bound_, + quant_min_bound_, + cache_k_scale, + cache_v_scale); + } else { + WriteCacheKV(dev_ctx_, + pre_cache_kv_out, + cache_kv_out, + kv_transpose_out, + sequence_lengths_data, + cache_bsz, + bsz, + num_head, + seq_len, + dim_head, + cache_offset); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int quant_round_type_; + float quant_max_bound_; + float quant_min_bound_; +}; + +template +class AttnOutHelper { + public: + AttnOutHelper(const phi::GPUContext &dev_ctx, + phi::DenseTensor &workspace, // NOLINT + phi::DenseTensor &tmp_quant_space, // NOLINT + phi::DenseTensor &tmp_dequant_space, // NOLINT + int token_num, // m + int hidden_size, // k + int dim_embed, // n + int ring_id, + CustomNCCLComm *nccl_comm, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + bool 
is_decoder) + : dev_ctx_(dev_ctx), + token_num_(token_num), + hidden_size_(hidden_size), + dim_embed_(dim_embed), + ring_id_(ring_id), + nccl_comm_(nccl_comm), + quant_round_type_(quant_round_type), + quant_min_bound_(quant_min_bound), + quant_max_bound_(quant_max_bound), + workspace_(workspace), + tmp_quant_space_(tmp_quant_space), + tmp_dequant_space_(tmp_dequant_space), + is_decoder_(is_decoder) { + int8_gemm_helper_ = std::make_unique>( + dev_ctx_, + token_num, + hidden_size, + dim_embed, + workspace, + tmp_quant_space, + tmp_dequant_space, + quant_round_type, + quant_max_bound, + quant_min_bound, + is_decoder && FLAGS_use_gemm_dequant /*use_gemm_dequant*/); + gemm_helper_ = std::make_unique>( + dev_ctx_, token_num, hidden_size, dim_embed, false); + } + + void Compute(const phi::DenseTensor &input, + const phi::DenseTensor &weight, + const phi::DenseTensor &dequant_out_scales, + const float in_scale, + phi::DenseTensor *output) { + if (nccl_comm_) { + nccl_comm_->SwapInput(output); + } + if (in_scale > 0) { + int8_gemm_helper_->Compute(&input, // T + &weight, // int8, Need be transposed + &dequant_out_scales, + in_scale, + output, // T + !is_decoder_, // quant in mmha in decoder + true); // need to dequant cause allreduce + } else { + gemm_helper_->Compute(&input, &weight, output); + } + if (nccl_comm_) { + *output = nccl_comm_->AllReduce(); + } else { + AllReduce(*output, ring_id_, output->numel(), dev_ctx_); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int token_num_; // m + int hidden_size_; // k + int dim_embed_; // n + int ring_id_; + int quant_round_type_; + float quant_max_bound_; + float quant_min_bound_; + bool is_decoder_; + CustomNCCLComm *nccl_comm_; + phi::DenseTensor &workspace_; // char + phi::DenseTensor &tmp_quant_space_; // int8_t + phi::DenseTensor &tmp_dequant_space_; // int32_t + std::unique_ptr> int8_gemm_helper_; + std::unique_ptr> gemm_helper_; +}; + + +} // namespace + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 4db07640f8359..4647b48d50304 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -22,638 +22,655 @@ template class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - auto &dev_ctx = ctx.cuda_device_context(); - - auto *time_step = ctx.Input("TimeStep"); - // 0. 
input - auto *input_x = ctx.Input("X"); - const auto input_x_dims = input_x->dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - - // quant input scales, vector, size = num_layers - auto qkv_in_scale = ctx.Attr>("qkv_in_scale"); - auto out_linear_in_scale = - ctx.Attr>("out_linear_in_scale"); - auto ffn1_in_scale = ctx.Attr>("ffn1_in_scale"); - auto ffn2_in_scale = ctx.Attr>("ffn2_in_scale"); - - // quant round type and bound - auto quant_round_type = ctx.Attr("quant_round_type"); - auto quant_max_bound = ctx.Attr("quant_max_bound"); - auto quant_min_bound = ctx.Attr("quant_min_bound"); - - // dequant output scales, tensor, size = [num_layers, n], n is gemm output - // size - auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); - auto out_linear_out_scales = - ctx.MultiInput("OutLinearOutScale"); - auto ffn1_out_scales = ctx.MultiInput("FFN1OutScale"); - auto ffn2_out_scales = ctx.MultiInput("FFN2OutScale"); - - // 1. layer norm - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); - - auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({{bsz_seq}}); - auto *ln_mean_data = - dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{bsz_seq}}); - auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); - const bool trans_qkvw = ctx.Attr("trans_qkvw"); - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - AttnMatmulINT8 qkv_compute( - dev_ctx, bsz_seq, output_size, input_size, compute_bias); - phi::DenseTensor qkv_out; - qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); - auto *qkv_out_data = - dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 3. 
fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); - - auto out_seq_len = seq_len; - if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - platform::errors::PreconditionNotMet( - "The value of time_step must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - platform::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } - - phi::DenseTensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); - - qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - - phi::DenseTensor softmax_out; - phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *softmax_out_data = - dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *qktv_out_data = - dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); - auto *fmha_out_data = - dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); - int ring_id = ctx.Attr("ring_id"); - // (transA, transB, compute_bias) = (false, false, false) - AttnMatmulINT8 out_linear_compute( - dev_ctx, bsz_seq, dim_embed, hidden_size, false); - - // 5. 
ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper - fused_dropout_layernorm_helper( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - FusedDropoutLayerNormHelper - fused_dropout_layernorm_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; - if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); - bias_dropout_residual_out_data = - dev_ctx.Alloc(&bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[0]; - AttnMatmulINT8 ffn1_linear_compute( - dev_ctx, bsz_seq, dim_ffn, dim_embed, false); - phi::DenseTensor ffn1_out; - ffn1_out.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_out_data = - dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn act + bias - DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - FusedDropoutHelper fused_act_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_out_data = dev_ctx.Alloc( - &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); - - // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - AttnMatmulINT8 ffn2_linear_compute( - dev_ctx, bsz_seq, dim_embed, dim_ffn, false); - - // 9. ffn2 residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper - ffn2_fused_dropout_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); - FusedDropoutLayerNormHelper - ffn2_fused_dropout_dequant_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); - FusedDropoutLayerNormHelper - ffn2_fused_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); - - // []. 
init workspace for cublasLt transform - phi::DenseTensor input_workspace, output_workspace, cublaslt_workspace; - // for input and output transform data is CUBLASLT_ORDER_COL32 format, - int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), - n_max = std::max({output_size, dim_embed, dim_ffn}); - - input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); - dev_ctx.Alloc(&input_workspace, - input_workspace.numel() * sizeof(int8_t)); - - output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); - dev_ctx.Alloc(&output_workspace, - output_workspace.numel() * sizeof(int32_t)); - - cublaslt_workspace.Resize({{3000000}}); - dev_ctx.Alloc(&cublaslt_workspace, - cublaslt_workspace.numel() * sizeof(int8_t)); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out; - tmp_out.Resize({{bsz, seq_len, dim_embed}}); - auto *tmp_out_data = - dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - auto *x_data = input_x->data(); - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (pre_layer_norm) { - buf1 = out; - } else { - buf0 = &tmp_out; - buf1 = out; - } - - for (int i = 0; i < layers; ++i) { - // step1. layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - nullptr, - 0, - qkv_in_scale[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif - - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - qkv_compute.ComputeForward(qkv_weights[i], - input_x, - &input_workspace, - bias, - &qkv_out, - &output_workspace, - &qkv_out, - qkv_in_scale[i], - qkv_out_scales[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } else if (!pre_layer_norm) { - qkv_compute.ComputeForward(qkv_weights[i], - buf1, - &input_workspace, - bias, - &qkv_out, - &output_workspace, - &qkv_out, - qkv_in_scale[i], - qkv_out_scales[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } else { - qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], - qkv_in_scale[i], - &input_workspace, - bias, - &qkv_out, - &output_workspace, - &qkv_out, - qkv_out_scales[i]); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif - - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; - - if (time_step) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step->data()[0], - 1. 
/ sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - nullptr, - src_mask, - &transpose_out_2, - nullptr, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; - - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif - - if (pre_layer_norm) { - out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], - out_linear_in_scale[i], - &fmha_out, - &input_workspace, - nullptr, - &output_workspace, - nullptr, - quant_round_type, - quant_max_bound, - quant_min_bound); - AllReduce(output_workspace, - ring_id, - bsz * seq_len * num_head * dim_head, - dev_ctx); - } else { - out_linear_compute.ComputeForward(out_linear_weights[i], - &fmha_out, - &input_workspace, - nullptr, - buf0, - &output_workspace, - nullptr, - out_linear_in_scale[i], - out_linear_out_scales[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; -#endif - - // step5. ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - - // inplace - // non-inplace: buf1 -> input_workspace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - out_linear_in_scale[i], - out_linear_out_scales[i]->data(), - ffn1_in_scale[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper_for_post_layernorm - .LayernormResidualDropoutBias(dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; -#endif - - // step6. 
ffn matmul1 - - if (pre_layer_norm) { - ffn1_linear_compute.ComputeForwardINT8ToINT8( - ffn1_weights[i], - &input_workspace, - nullptr, - &output_workspace, - nullptr, - cublaslt_workspace.data()); - } else { - ffn1_linear_compute.ComputeForward(ffn1_weights[i], - buf1, - &input_workspace, - nullptr, - &ffn1_out, - &output_workspace, - nullptr, - ffn1_in_scale[i], - ffn1_out_scales[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; -#endif - - // step7. act bias - // TODO(wangxi): remove dropout mask in inference - if (pre_layer_norm) { - fused_act_dropout_helper.DropoutActBias( - dev_ctx, - output_workspace.data(), - ffn1_biases[i]->data(), - "gelu", - input_workspace.data(), - ffn1_dropout_mask_data, - ffn1_in_scale[i], - ffn1_out_scales[i]->data(), - ffn2_in_scale[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } else { - fused_act_dropout_helper_for_post_layernorm.DropoutActBias( - dev_ctx, - ffn1_out_data, - ffn1_biases[i]->data(), - "gelu", - ffn1_dropout_out_data, - ffn1_dropout_mask_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif - - // step8. ffn matmul2 - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForwardINT8ToINT8( - ffn2_weights[i], - &input_workspace, - nullptr, - &output_workspace, - nullptr, - cublaslt_workspace.data()); - } else { - ffn2_linear_compute.ComputeForward(ffn2_weights[i], - &ffn1_dropout_out, - &input_workspace, - nullptr, - buf0, - &output_workspace, - nullptr, - ffn2_in_scale[i], - ffn2_out_scales[i], - quant_round_type, - quant_max_bound, - quant_min_bound); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.0"; -#endif - - if (pre_layer_norm) { - AllReduce(output_workspace, - ring_id, - bsz * seq_len * num_head * dim_head, - dev_ctx); - } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.1"; -#endif - - // step9. residual bias - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - ffn2_in_scale[i], - ffn2_out_scales[i]->data(), - qkv_in_scale[i + 1], - quant_round_type, - quant_max_bound, - quant_min_bound); - } else { - ffn2_fused_dropout_dequant_helper.ResidualDropoutBias( - dev_ctx, - output_workspace.data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - buf1->data(), - dropout_mask_out_data, - ffn2_in_scale[i], - ffn2_out_scales[i]->data(), - 1.0); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - ffn2_fused_dropout_helper_for_post_layernorm - .LayernormResidualDropoutBias(dev_ctx, - buf0->data(), - buf1->data(), - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step9"; -#endif - if (pre_layer_norm) { - x_data = buf1->data(); - } - } + // using U = LayerNormParamType; + // auto &dev_ctx = ctx.cuda_device_context(); + + // auto *time_step = ctx.Input("TimeStep"); + // // 0. 
input + // auto *input_x = ctx.Input("X"); + // const auto input_x_dims = input_x->dims(); + // int bsz = input_x_dims[0]; + // int seq_len = input_x_dims[1]; + // int dim_embed = input_x_dims[2]; + // int bsz_seq = bsz * seq_len; + + // // quant input scales, vector, size = num_layers + // auto qkv_in_scale = ctx.Attr>("qkv_in_scale"); + // auto out_linear_in_scale = + // ctx.Attr>("out_linear_in_scale"); + // auto ffn1_in_scale = ctx.Attr>("ffn1_in_scale"); + // auto ffn2_in_scale = ctx.Attr>("ffn2_in_scale"); + + // // quant round type and bound + // auto quant_round_type = ctx.Attr("quant_round_type"); + // auto quant_max_bound = ctx.Attr("quant_max_bound"); + // auto quant_min_bound = ctx.Attr("quant_min_bound"); + + // // dequant output scales, tensor, size = [num_layers, n], n is gemm + // output + // // size + // auto qkv_out_scales = + // ctx.MultiInput("QKVOutScale"); auto + // out_linear_out_scales = + // ctx.MultiInput("OutLinearOutScale"); + // auto ffn1_out_scales = + // ctx.MultiInput("FFN1OutScale"); auto + // ffn2_out_scales = ctx.MultiInput("FFN2OutScale"); + + // // 1. layer norm + // const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + // const float epsilon = ctx.Attr("epsilon"); + // auto ln_scales = ctx.MultiInput("LnScale"); + // auto ln_biases = ctx.MultiInput("LnBias"); + + // auto ln_compute = + // AttnLayerNorm(dev_ctx, epsilon, bsz_seq, + // dim_embed); + // phi::DenseTensor ln_mean, ln_var; + // ln_mean.Resize({{bsz_seq}}); + // auto *ln_mean_data = + // dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + // ln_var.Resize({{bsz_seq}}); + // auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * + // sizeof(U)); + + // // 2. qkv + // // x: qkv's input [batch_size, seq_len, dim_embed] + // // y: qkv's weight: [3, num_head, dim_head, dim_embed] + // auto qkv_weights = ctx.MultiInput("QKVW"); + // auto qkv_biases = ctx.MultiInput("QKVBias"); + // const bool trans_qkvw = ctx.Attr("trans_qkvw"); + // const auto qkv_w_dims = qkv_weights[0]->dims(); + // int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + // int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + // int hidden_size = num_head * dim_head; + // int output_size = 3 * hidden_size; + // int input_size = dim_embed; + + // bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // AttnMatmulINT8 qkv_compute( + // dev_ctx, bsz_seq, output_size, input_size, compute_bias); + // phi::DenseTensor qkv_out; + // qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + // auto *qkv_out_data = + // dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // // 3. 
fmha + // AttnDropoutParam attn_param( + // true, "upscale_in_train", 0.0, true, true, 0, nullptr); + // auto fmha_compute = + // FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, + // attn_param); + // auto *src_mask = ctx.Input("SrcMask"); + // auto cache_kvs = ctx.MultiInput("CacheKV"); + // auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + // // auto *time_step = ctx.Input("TimeStep"); + + // auto out_seq_len = seq_len; + // if (time_step) { + // PADDLE_ENFORCE_EQ(time_step->place(), + // platform::CPUPlace(), + // platform::errors::PreconditionNotMet( + // "The place of input(TimeStep) must be + // CPUPlace.")); + // // cache_seq_len + // int time_step_value = time_step->data()[0]; + // PADDLE_ENFORCE_GT(time_step_value, + // 0, + // platform::errors::PreconditionNotMet( + // "The value of time_step must > 0, but now is + // %d", time_step_value)); + // PADDLE_ENFORCE_EQ( + // seq_len, + // 1, + // platform::errors::PreconditionNotMet( + // "In decode stage, the seq_len of input must be 1, but now + // is %d", seq_len)); + // out_seq_len += time_step_value; + // } + + // phi::DenseTensor transpose_out_2, qk_out; + // transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + // auto *transpose_out_2_data = + // dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * + // sizeof(T)); + + // qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + // auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * + // sizeof(T)); + + // phi::DenseTensor softmax_out; + // phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + // phi::DenseTensor qktv_out, fmha_out; + // softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + // auto *softmax_out_data = + // dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + // attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, + // out_seq_len}}); auto *attn_dropout_mask_out_data = dev_ctx.Alloc( + // &attn_dropout_mask_out, attn_dropout_mask_out.numel() * + // sizeof(T)); + // attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + // auto *attn_dropout_data_data = dev_ctx.Alloc( + // &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + + // qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + // auto *qktv_out_data = + // dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + // fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + // auto *fmha_out_data = + // dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // // 4. out_linear + // auto out_linear_weights = + // ctx.MultiInput("OutLinearW"); auto + // out_linear_biases = + // ctx.MultiInput("OutLinearBias"); int ring_id = + // ctx.Attr("ring_id"); + // // (transA, transB, compute_bias) = (false, false, false) + // AttnMatmulINT8 out_linear_compute( + // dev_ctx, bsz_seq, dim_embed, hidden_size, false); + + // // 5. 
ln(residual + bias) + // DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + // FusedDropoutLayerNormHelper + // fused_dropout_layernorm_helper( + // dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + // FusedDropoutLayerNormHelper + // fused_dropout_layernorm_helper_for_post_layernorm( + // dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + // auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + // auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + // phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; + // T *bias_dropout_residual_out_data = nullptr; + // if (pre_layer_norm) { + // bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + // bias_dropout_residual_out_data = + // dev_ctx.Alloc(&bias_dropout_residual_out, + // bias_dropout_residual_out.numel() * + // sizeof(T)); + // } + // dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); + // auto *dropout_mask_out_data = dev_ctx.Alloc( + // &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + + // // 6. ffn matmul1 + // auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + // auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + // auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + // int dim_ffn = ffn1_weight_dim[0]; + // AttnMatmulINT8 ffn1_linear_compute( + // dev_ctx, bsz_seq, dim_ffn, dim_embed, false); + // phi::DenseTensor ffn1_out; + // ffn1_out.Resize({{bsz_seq, dim_ffn}}); + // auto *ffn1_out_data = + // dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // // 7. ffn act + bias + // DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, + // 0); FusedDropoutHelper + // fused_act_dropout_helper( + // dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + // FusedDropoutHelper + // fused_act_dropout_helper_for_post_layernorm( + // dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + // phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + // ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + // auto *ffn1_dropout_out_data = dev_ctx.Alloc( + // &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); + // ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); + // auto *ffn1_dropout_mask_data = dev_ctx.Alloc( + // &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); + + // // 8. ffn2 matmul + // auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + // auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + // AttnMatmulINT8 ffn2_linear_compute( + // dev_ctx, bsz_seq, dim_embed, dim_ffn, false); + + // // 9. ffn2 residual bias + // DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, + // 0); FusedDropoutLayerNormHelper + // ffn2_fused_dropout_helper( + // dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + // FusedDropoutLayerNormHelper + // ffn2_fused_dropout_dequant_helper( + // dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + // FusedDropoutLayerNormHelper + // ffn2_fused_dropout_helper_for_post_layernorm( + // dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + + // // []. 
init workspace for cublasLt transform + // phi::DenseTensor input_workspace, output_workspace, + // cublaslt_workspace; + // // for input and output transform data is CUBLASLT_ORDER_COL32 + // format, int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), + // n_max = std::max({output_size, dim_embed, dim_ffn}); + + // input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); + // dev_ctx.Alloc(&input_workspace, + // input_workspace.numel() * sizeof(int8_t)); + + // output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); + // dev_ctx.Alloc(&output_workspace, + // output_workspace.numel() * sizeof(int32_t)); + + // cublaslt_workspace.Resize({{3000000}}); + // dev_ctx.Alloc(&cublaslt_workspace, + // cublaslt_workspace.numel() * sizeof(int8_t)); + + // // calc + // auto *out = ctx.Output("Out"); + // auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + // phi::DenseTensor *from_tensor = out; + // phi::DenseTensor tmp_out; + // tmp_out.Resize({{bsz, seq_len, dim_embed}}); + // auto *tmp_out_data = + // dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + + // auto *x_data = input_x->data(); + // phi::DenseTensor *buf0 = nullptr; + // phi::DenseTensor *buf1 = nullptr; + + // // step0: x --> buf1 + // // step1: buf1 --> buf0 + // // step2: buf0 --> buf1 + // int layers = qkv_weights.size(); + // if (pre_layer_norm) { + // buf1 = out; + // } else { + // buf0 = &tmp_out; + // buf1 = out; + // } + + // for (int i = 0; i < layers; ++i) { + // // step1. layer_norm + // if (i == 0 && pre_layer_norm) { + // auto *ln_scale_data = ln_scales[i]->data(); + // auto *ln_bias_data = ln_biases[i]->data(); + // // TODO(wangxi): can remove mean var in inference + // ln_compute.ComputeForward(x_data, + // ln_scale_data, + // ln_bias_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // nullptr, + // 0, + // qkv_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step1"; + // #endif + + // // step2. qkv + // const phi::DenseTensor *qkv_bias = + // qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // // NOTE: in decoder stage, bias is fused in fmha + // const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; + // if (!pre_layer_norm && i == 0) { + // qkv_compute.ComputeForward(qkv_weights[i], + // input_x, + // &input_workspace, + // bias, + // &qkv_out, + // &output_workspace, + // &qkv_out, + // qkv_in_scale[i], + // qkv_out_scales[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } else if (!pre_layer_norm) { + // qkv_compute.ComputeForward(qkv_weights[i], + // buf1, + // &input_workspace, + // bias, + // &qkv_out, + // &output_workspace, + // &qkv_out, + // qkv_in_scale[i], + // qkv_out_scales[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } else { + // qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], + // qkv_in_scale[i], + // &input_workspace, + // bias, + // &qkv_out, + // &output_workspace, + // &qkv_out, + // qkv_out_scales[i]); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step2"; + // #endif + + // // step3. fmha + // const phi::DenseTensor *cache_kv = + // cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + // phi::DenseTensor *cache_kv_out = cache_kv ? 
cache_kv_outs[i] : + // nullptr; + + // if (time_step) { // generation decoder stage + // // [2, batch_size, num_head, max_seq_len, head_size] + // int max_seq_len = cache_kv->dims()[3]; + // fmha(dev_ctx, + // qkv_out, + // *qkv_bias, + // *src_mask, + // cache_kv_out, + // &fmha_out, + // bsz, + // max_seq_len, + // num_head, + // dim_head, + // time_step->data()[0], + // 1. / sqrt(dim_head)); + // } else if (cache_kv_out) { // generation context stage + // // TODO(wangxi): can remove dropout in inference + // fmha_compute.ComputeForward(qkv_out, + // nullptr, + // src_mask, + // &transpose_out_2, + // nullptr, + // &qk_out, + // nullptr, + // &softmax_out, + // &attn_dropout_mask_out, + // &attn_dropout_out, + // &qktv_out, + // &fmha_out); + // // [3, bsz, num_head, seq_len, head_dim] + // T *qkv_data = transpose_out_2_data; + // int64_t q_size = bsz * seq_len * num_head * dim_head; + // int64_t k_size = q_size; + // const T *q_ptr = qkv_data; + // const T *k_ptr = q_ptr + q_size; + // const T *v_ptr = k_ptr + k_size; + + // // [2, bsz, num_head, max_seq_len, head_dim] + // int max_seq_len = cache_kv_out->dims()[3]; + // T *cache_kv_data = cache_kv_out->data(); + // int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + // T *cache_k_ptr = cache_kv_data; + // T *cache_v_ptr = cache_kv_data + cache_k_size; + + // write_cache_kv(dev_ctx, + // cache_k_ptr, + // cache_v_ptr, + // k_ptr, + // v_ptr, + // bsz, + // num_head, + // seq_len, + // max_seq_len, + // dim_head); + // } else { // not generation + // // TODO(wangxi): can remove dropout in inference + // fmha_compute.ComputeForward(qkv_out, + // cache_kv, + // src_mask, + // &transpose_out_2, + // cache_kv_out, + // &qk_out, + // nullptr, + // &softmax_out, + // &attn_dropout_mask_out, + // &attn_dropout_out, + // &qktv_out, + // &fmha_out); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step3"; + // #endif + + // if (pre_layer_norm) { + // out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], + // out_linear_in_scale[i], + // &fmha_out, + // &input_workspace, + // nullptr, + // &output_workspace, + // nullptr, + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // AllReduce(output_workspace, + // ring_id, + // bsz * seq_len * num_head * dim_head, + // dev_ctx); + // } else { + // out_linear_compute.ComputeForward(out_linear_weights[i], + // &fmha_out, + // &input_workspace, + // nullptr, + // buf0, + // &output_workspace, + // nullptr, + // out_linear_in_scale[i], + // out_linear_out_scales[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step4"; + // #endif + + // // step5. 
ln(residual + dropout(input + bias)) + // if (pre_layer_norm) { + // auto *ln_scale_data = ffn_ln_scales[i]->data(); + // auto *ln_bias_data = ffn_ln_biases[i]->data(); + // auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // // inplace + // // non-inplace: buf1 -> input_workspace + // fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // x_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // bias_dropout_residual_out_data, + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // out_linear_in_scale[i], + // out_linear_out_scales[i]->data(), + // ffn1_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } else { + // auto *ln_scale_data = ln_scales[i]->data(); + // auto *ln_bias_data = ln_biases[i]->data(); + // auto *out_linear_bias_data = out_linear_biases[i]->data(); + // auto *residual_data = (i == 0 ? x_data : buf1->data()); + // fused_dropout_layernorm_helper_for_post_layernorm + // .LayernormResidualDropoutBias(dev_ctx, + // buf0->data(), + // residual_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // buf0->data(), + // dropout_mask_out_data, + // buf1->data(), + // ln_mean_data, + // ln_var_data); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step5"; + // #endif + + // // step6. ffn matmul1 + + // if (pre_layer_norm) { + // ffn1_linear_compute.ComputeForwardINT8ToINT8( + // ffn1_weights[i], + // &input_workspace, + // nullptr, + // &output_workspace, + // nullptr, + // cublaslt_workspace.data()); + // } else { + // ffn1_linear_compute.ComputeForward(ffn1_weights[i], + // buf1, + // &input_workspace, + // nullptr, + // &ffn1_out, + // &output_workspace, + // nullptr, + // ffn1_in_scale[i], + // ffn1_out_scales[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step6"; + // #endif + + // // step7. act bias + // // TODO(wangxi): remove dropout mask in inference + // if (pre_layer_norm) { + // fused_act_dropout_helper.DropoutActBias( + // dev_ctx, + // output_workspace.data(), + // ffn1_biases[i]->data(), + // "gelu", + // input_workspace.data(), + // ffn1_dropout_mask_data, + // ffn1_in_scale[i], + // ffn1_out_scales[i]->data(), + // ffn2_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } else { + // fused_act_dropout_helper_for_post_layernorm.DropoutActBias( + // dev_ctx, + // ffn1_out_data, + // ffn1_biases[i]->data(), + // "gelu", + // ffn1_dropout_out_data, + // ffn1_dropout_mask_data); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step7"; + // #endif + + // // step8. 
ffn matmul2 + // if (pre_layer_norm) { + // ffn2_linear_compute.ComputeForwardINT8ToINT8( + // ffn2_weights[i], + // &input_workspace, + // nullptr, + // &output_workspace, + // nullptr, + // cublaslt_workspace.data()); + // } else { + // ffn2_linear_compute.ComputeForward(ffn2_weights[i], + // &ffn1_dropout_out, + // &input_workspace, + // nullptr, + // buf0, + // &output_workspace, + // nullptr, + // ffn2_in_scale[i], + // ffn2_out_scales[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step8.0"; + // #endif + + // if (pre_layer_norm) { + // AllReduce(output_workspace, + // ring_id, + // bsz * seq_len * num_head * dim_head, + // dev_ctx); + // } else { + // AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step8.1"; + // #endif + + // // step9. residual bias + // if (pre_layer_norm) { + // // TODO(wangxi): remove dropout mask in inference + // if (i < layers - 1) { + // auto *ln_scale_data = ln_scales[i + 1]->data(); + // auto *ln_bias_data = ln_biases[i + 1]->data(); + + // ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf1->data(), + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // qkv_in_scale[i + 1], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // } else { + // ffn2_fused_dropout_dequant_helper.ResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // buf1->data(), + // dropout_mask_out_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // 1.0); + // } + // } else { + // auto *ln_scale_data = ffn_ln_scales[i]->data(); + // auto *ln_bias_data = ffn_ln_biases[i]->data(); + // ffn2_fused_dropout_helper_for_post_layernorm + // .LayernormResidualDropoutBias(dev_ctx, + // buf0->data(), + // buf1->data(), + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf0->data(), + // dropout_mask_out_data, + // buf1->data(), + // ln_mean_data, + // ln_var_data); + // } + // #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + // VLOG(0) << "step9"; + // #endif + // if (pre_layer_norm) { + // x_data = buf1->data(); + // } + // } } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index b2c280e38b71d..78d7cd973c941 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -106,13 +106,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { paddle::platform::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ(c_dim[1], - x_dim[0], - paddle::platform::errors::InvalidArgument( - "The second dim of CacheKV must be equal with " - "batch size %d, but got %d", - x_dim[0], - c_dim[1])); // batch_size PADDLE_ENFORCE_EQ(c_dim[2], trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( @@ -166,6 +159,7 @@ class FusedMultiTransformerOpOpMaker AddInput("LnBias", "Bias is a 1-dimensional tensor of size " "H. 
Here, H represents the last dimension of its input tensor.") + .AsDispensable() .AsDuplicable(); AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); @@ -179,6 +173,9 @@ class FusedMultiTransformerOpOpMaker AddInput("RotaryPosEmb", "(optional) The RoPE embeddings for generation inference.") .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); @@ -190,10 +187,10 @@ class FusedMultiTransformerOpOpMaker AddInput("OutLinearBias", "The out_linear bias tensor.") .AsDispensable() .AsDuplicable(); - AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") .AsDuplicable(); AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDispensable() .AsDuplicable(); AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") .AsDuplicable(); @@ -205,12 +202,10 @@ class FusedMultiTransformerOpOpMaker AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") .AsDispensable() .AsDuplicable(); - AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() .AsDuplicable(); AddOutput("Out", "Result after multi ."); - AddAttr("pre_layer_norm", "if true, the attention op uses pre_layer_norm architecure, " "else, uses post_layer_norm architecuture. " @@ -249,6 +244,10 @@ class FusedMultiTransformerOpOpMaker "'dropout_rate' must be between 0.0 and 1.0.")); }); + AddAttr("residual_alpha", + "Constant for residual_alpha [default 1.0].") + .SetDefault(1.0f); + AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") @@ -269,12 +268,14 @@ class FusedMultiTransformerOpOpMaker AddAttr("act_method", "act_method") .SetDefault("gelu") .AddCustomChecker([](const std::string &act_type) { - PADDLE_ENFORCE_EQ( - act_type == "gelu" || act_type == "relu" || act_type == "none", - true, - platform::errors::InvalidArgument( - "Only support `gelu`, `relu`, `none` activation in " - "FusedMultiTransformer. ")); + PADDLE_ENFORCE_EQ(act_type == "gelu" || act_type == "geglu" || + act_type == "swiglu" || act_type == "relu" || + act_type == "none", + true, + platform::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `swiglu`, " + "`relu`, `none` activation in " + "FusedMultiTransformer. ")); }); AddAttr( @@ -290,6 +291,68 @@ class FusedMultiTransformerOpOpMaker "ring id for tensor model parallel. distributed training and inference") .SetDefault(-1); + AddAttr("norm_type", "norm_type") + .SetDefault("layernorm") + .AddCustomChecker([](const std::string &norm_type) { + PADDLE_ENFORCE_EQ( + norm_type == "layernorm" || norm_type == "rmsnorm", + true, + platform::errors::InvalidArgument( + "Only support `layernorm`, `rmsnorm` method for in" + "FusedMultiTransformerDyquant. ")); + }); + + AddAttr("use_neox_rotary_style", + "Whether use neox rotary embedding. ") + .SetDefault(false); + AddAttr>( + "cache_k_scale", + "cache_k_scale is used to quantize cache kv tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of cache_k_scale should be num_layers, which is equal to " + "len(CacheKV)") + .SetDefault({}); + AddAttr>( + "cache_v_scale", + "cache_v_scale is used to quantize cache kv tensor." 
+ "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of cache_v_scale should be num_layers, which is equal to " + "len(CacheKV)") + .SetDefault({}); + AddAttr>( + "cache_k_out_scale", + "cache_k_out_scale is used to dequantize cache kv tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of cache_k_out_scale should be num_layers, which is equal to " + "len(CacheKV)") + .SetDefault({}); + AddAttr>( + "cache_v_out_scale", + "cache_v_out_scale is used to dequantize cache kv tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of cache_v_out_scale should be num_layers, which is equal to " + "len(CacheKV)") + .SetDefault({}); + AddAttr( + "quant_round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(-2.5)=-3") + .SetDefault(0); + AddAttr( + "quant_max_bound", + "(float, default 127.0) the max bound of float type to int type") + .SetDefault(127.0); + AddAttr( + "quant_min_bound", + "(float, default -127.0) the min bound of float type to int type") + .SetDefault(-127.0); + AddComment(R"DOC(fused multi transformer layers op)DOC"); } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 5ca5994dc3193..41f262a604f26 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -12,678 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" +#include "paddle/fluid/operators/custom_all_reduce.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h" +#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h" +#include "paddle/phi/kernels/flash_attn_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" -namespace paddle { -namespace operators { - -#if CUDA_VERSION >= 11060 // Use cublasLt to fuse FFN operation. - -template -class FusedMultiTransformerOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - auto &dev_ctx = ctx.cuda_device_context(); - - auto *time_step = ctx.Input("TimeStep"); - // 0. 
input - auto *input_x = ctx.Input("X"); - const auto input_x_dims = input_x->dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - const std::string act_method = ctx.Attr("act_method"); - bool remove_padding = false; - auto *sequence_lengths = ctx.Input("SeqLengths"); - if (sequence_lengths) { - remove_padding = true; - } - phi::DenseTensor d_token_tensor; - phi::DenseTensor padding_offset_tensor; - phi::DenseTensor x_remove_padding; - bool encoder_remove_padding = (remove_padding && !time_step); - int token_num = 0; - - // remove padding in encoder - if (encoder_remove_padding) { - // just for encoder - d_token_tensor.Resize({{1}}); - auto *d_token_num = dev_ctx.Alloc( - &d_token_tensor, d_token_tensor.numel() * sizeof(int)); - // alloc the max size of padding_offset_tensor - padding_offset_tensor.Resize({{bsz_seq}}); - dev_ctx.Alloc(&padding_offset_tensor, - padding_offset_tensor.numel() * sizeof(int)); - InvokeGetPaddingOffset(dev_ctx, - &token_num, - d_token_num, - padding_offset_tensor.data(), - sequence_lengths->data(), - bsz, - seq_len); - padding_offset_tensor.Resize({{token_num}}); - x_remove_padding.Resize({{token_num, dim_embed}}); - dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); - InvokeRemovePadding(dev_ctx, - x_remove_padding.data(), - input_x->data(), - padding_offset_tensor.data(), - token_num, - dim_embed); - } else { - token_num = bsz_seq; - } - auto *padding_offset_data = - encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - - // 1. layer norm - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); - - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({{token_num}}); - auto *ln_mean_data = - dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{token_num}}); - auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); - const bool trans_qkvw = ctx.Attr("trans_qkvw"); - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set - // compute_bias as false. - auto qkv_compute = AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); - - phi::DenseTensor qkv_out; - qkv_out.Resize({{token_num, 3, num_head, dim_head}}); - auto *qkv_out_data = - dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 2.1 rotary - auto *rotary_tensor = ctx.Input("RotaryPosEmb"); - const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); - - // 3. 
fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); - auto pre_caches = ctx.MultiInput("PreCaches"); - int cache_offset = 0; - if (pre_caches.size() > 0) { - cache_offset = pre_caches[0]->dims()[3]; - } - - auto out_seq_len = seq_len; - if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - platform::errors::PreconditionNotMet( - "The value of time_step must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - platform::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } else { - out_seq_len += cache_offset; - } - - phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; - q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *q_transpose_out_data = - dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); - - kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); - auto *kv_transpose_out_data = dev_ctx.Alloc( - &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); - - qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - - phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *src_mask_out_data = - dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); - } - - // [2, bs, num_head, cache_seq_len + seq_len, head_dim] - phi::DenseTensor pre_cache_kv_out; - if (cache_offset > 0) { - pre_cache_kv_out.Resize( - {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); - auto *pre_cache_kv_out_data = dev_ctx.Alloc( - &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); - } - - phi::DenseTensor softmax_out; - phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *softmax_out_data = - dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *qktv_out_data = - dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); - auto *fmha_out_data = - dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // 4. 
out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); - int ring_id = ctx.Attr("ring_id"); - // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); - - // 5. ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; - if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{token_num, dim_embed}}); - bias_dropout_residual_out_data = - dev_ctx.Alloc(&bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({{token_num, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn1 matmul + act + bias - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[1]; - - auto ffn1_cublas_linear = CublasFusedMLP(dev_ctx); - const phi::DDim ffn1_input_shape({token_num, dim_embed}); - ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false); - - phi::DenseTensor ffn1_out; - ffn1_out.Resize({{token_num, dim_ffn}}); - auto *ffn1_out_data = - dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn2 matmul + bias + residual. - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - - auto ffn2_linear_compute = AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); - - // 8. ffn2 Layernorm residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out, tmp_out_rm_padding; - tmp_out.Resize({{token_num, dim_embed}}); - if (encoder_remove_padding) { - tmp_out_rm_padding.Resize({{token_num, dim_embed}}); - auto *tmp_out_rm_padding_data = dev_ctx.Alloc( - &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); - } - auto *tmp_out_data = - dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - const T *x_data; - if (encoder_remove_padding) { - x_data = x_remove_padding.data(); - } else { - x_data = input_x->data(); - } - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (encoder_remove_padding) { - // In the case of variable lengths, the padding needs to be rebuilt - // eventually. So buf0 and buf1 do not need to be changed according to the - // pre_layer_norm and the number of layers. 
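The (layers & 1) selection just below exists because, in the pre_layer_norm path, every iteration ends with std::swap(buf0, buf1) and the layer's result is left in buf1; with `layers` iterations there are layers - 1 swaps before the last write, so buf1 must start out pointing at Out when the layer count is odd and at the temporary when it is even (the variable-length case above sidesteps this by writing into temporaries and rebuilding padding at the end). A standalone sketch of that parity argument, simplified to track only where each layer's final result lands; the variable names are illustrative, not part of the operator:

#include <cassert>
#include <utility>

int main() {
  int out = 0, tmp = 0;  // stand-ins for the Out tensor and the scratch tensor
  for (int layers = 1; layers <= 8; ++layers) {
    // Mirror of the kernel's choice: odd layer count -> buf1 starts as out,
    // even layer count -> buf0 starts as out.
    int *buf0 = (layers & 1) ? &tmp : &out;
    int *buf1 = (layers & 1) ? &out : &tmp;
    int *last_write = nullptr;
    for (int i = 0; i < layers; ++i) {
      last_write = buf1;      // layer i leaves its result in *buf1
      std::swap(buf0, buf1);  // roles flip before the next layer
    }
    assert(last_write == &out);  // the final layer always wrote into out
  }
  return 0;
}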
- buf0 = &tmp_out; - buf1 = &tmp_out_rm_padding; - } else { - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out - buf0 = &tmp_out; - buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; - } - } else { - buf0 = &tmp_out; - buf1 = out; - } - } - - for (int i = 0; i < layers; ++i) { - // step1. layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif - - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - const phi::DenseTensor *tmp_input_x = - (encoder_remove_padding) ? &x_remove_padding : input_x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); - } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif - - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; - - if (time_step) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask, - sequence_lengths, - rotary_tensor, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step->data()[0], - rotary_emb_dims, - 1. / sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - const phi::DenseTensor *pre_cache_kv_tensor = - pre_caches.size() > 0 ? pre_caches[i] : nullptr; - phi::DenseTensor *pre_cache_kv_out_tmp = - cache_offset > 0 ? &pre_cache_kv_out : nullptr; - phi::DenseTensor *src_mask_tmp = - cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? 
&padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; - } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; - } - - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif - - if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; -#endif - - // step5. 
ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; -#endif - - // step6. ffn matmul1 - ffn1_cublas_linear.ComputeForward(buf1, - ffn1_weights[i], - ffn1_biases[i], - nullptr, - &ffn1_out, - act_method); - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; -#endif - - // step7. ffn2 matmul - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr); - } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr); - } - -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif - - if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7.1"; -#endif +#include - // step8. 
layer norm + bias_add + residual - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); - } else { - ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - buf1->data(), - dropout_mask_out_data); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - buf1->data(), - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } +// #define _DEBUG_FUSED_MULTI_TRANSFORMER +// #define _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8"; -#endif - if (pre_layer_norm) { - x_data = buf1->data(); - std::swap(buf0, buf1); - } - } - if (encoder_remove_padding) { - if (pre_layer_norm) { - InvokeRebuildPadding(dev_ctx, - from_data, - buf0->data(), - padding_offset_data, - token_num, - dim_embed); - } else { - InvokeRebuildPadding(dev_ctx, - from_data, - buf1->data(), - padding_offset_data, - token_num, - dim_embed); - } - } - } -}; +namespace paddle { +namespace operators { -#else +static phi::DenseTensor CustomAllReduce(const phi::DenseTensor &t) { + auto *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(t.place())); + auto comm = GetCustomNCCLComm(*ctx, 0); + PADDLE_ENFORCE_NOT_NULL(comm); + phi::DenseTensor ret; + ret.Resize(t.dims()); + ctx->Alloc(&ret, t.dtype()); + comm->SwapInput(&ret); + phi::Copy(*ctx, t, t.place(), false, &ret); + return comm->AllReduce(); +} template class FusedMultiTransformerOpKernel : public framework::OpKernel { @@ -701,17 +55,38 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { int dim_embed = input_x_dims[2]; int bsz_seq = bsz * seq_len; const std::string act_method = ctx.Attr("act_method"); + bool use_glu = (act_method == "geglu" || act_method == "swiglu"); + const std::string norm_type = ctx.Attr("norm_type"); + const bool use_neox_rotary_style = ctx.Attr("use_neox_rotary_style"); bool remove_padding = false; auto *sequence_lengths = ctx.Input("SeqLengths"); if (sequence_lengths) { remove_padding = true; } + + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + phi::DenseTensor d_token_tensor; phi::DenseTensor padding_offset_tensor; phi::DenseTensor x_remove_padding; + + // cumulative seqlens [batch_size+1] + phi::DenseTensor cu_seqlens_q, cu_seqlens_k; bool encoder_remove_padding = (remove_padding && !time_step); int token_num = 0; + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // Init out + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + // remove padding in encoder if (encoder_remove_padding) { // just for encoder @@ -722,13 +97,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { padding_offset_tensor.Resize({{bsz_seq}}); 
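// A minimal host-side sketch of the bookkeeping the padding-removal path relies
// on (token_num, cu_seqlens_q, padding_offset). Assumptions: the helper name
// below is hypothetical, and padding_offset follows the usual
// FasterTransformer-style convention (padded_index = valid_index +
// padding_offset[valid_index]); the real values are produced on the GPU by
// InvokeGetPaddingOffset.
#include <vector>

static int SketchGetPaddingOffset(const std::vector<int> &seq_lens,
                                  int max_seq_len,
                                  std::vector<int> *cu_seqlens,
                                  std::vector<int> *padding_offset) {
  const int bsz = static_cast<int>(seq_lens.size());
  cu_seqlens->assign(bsz + 1, 0);  // [bsz + 1] prefix sums of valid lengths
  for (int b = 0; b < bsz; ++b) {
    (*cu_seqlens)[b + 1] = (*cu_seqlens)[b] + seq_lens[b];
  }
  padding_offset->clear();
  for (int b = 0; b < bsz; ++b) {
    // padded slots preceding batch b's tokens in the [bsz, max_seq_len] grid
    const int skipped = b * max_seq_len - (*cu_seqlens)[b];
    padding_offset->insert(padding_offset->end(), seq_lens[b], skipped);
  }
  return (*cu_seqlens)[bsz];  // token_num: number of rows in x_remove_padding
}
// e.g. seq_lens = {3, 1, 2}, max_seq_len = 4 gives token_num = 6,
// cu_seqlens = {0, 3, 4, 6}, padding_offset = {0, 0, 0, 1, 4, 4}.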
dev_ctx.Alloc(&padding_offset_tensor, padding_offset_tensor.numel() * sizeof(int)); + cu_seqlens_q.Resize({{bsz + 1}}); + dev_ctx.Alloc(&cu_seqlens_q, + cu_seqlens_q.numel() * sizeof(int32_t)); + InvokeGetPaddingOffset(dev_ctx, &token_num, d_token_num, padding_offset_tensor.data(), + cu_seqlens_q.data(), sequence_lengths->data(), bsz, seq_len); + if (token_num == 0) return; padding_offset_tensor.Resize({{token_num}}); x_remove_padding.Resize({{token_num, dim_embed}}); dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); @@ -740,17 +121,20 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { dim_embed); } else { token_num = bsz_seq; + if (token_num == 0) return; } + auto *padding_offset_data = encoder_remove_padding ? padding_offset_tensor.data() : nullptr; // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); + const float residual_alpha = ctx.Attr("residual_alpha"); auto ln_scales = ctx.MultiInput("LnScale"); auto ln_biases = ctx.MultiInput("LnBias"); - - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + NormHelper norm_helper( + dev_ctx, norm_type, token_num, dim_embed, epsilon, residual_alpha); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({{token_num}}); auto *ln_mean_data = @@ -771,17 +155,29 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { int output_size = 3 * hidden_size; int input_size = dim_embed; - bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + auto cache_k_scale = ctx.Attr>("cache_k_scale"); + auto cache_v_scale = ctx.Attr>("cache_v_scale"); + auto cache_k_out_scale = ctx.Attr>("cache_k_out_scale"); + auto cache_v_out_scale = ctx.Attr>("cache_v_out_scale"); + bool do_cachekv_quant = (cache_k_scale.size() != 0); + + auto quant_round_type = ctx.Attr("quant_round_type"); + auto quant_max_bound = ctx.Attr("quant_max_bound"); + auto quant_min_bound = ctx.Attr("quant_min_bound"); + + // Set a flag whether need to add Matmul / Layernorm bias. + bool compute_bias = qkv_biases.size() > 0; + bool compute_ln_bias = ln_biases.size() > 0; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we // set compute_bias as false. - auto qkv_compute = AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); + + // auto mixed_gemm_runner = paddle::operators::CutlassFpAIntBGemmRunner< + // typename PDDataTypeTraits::DataType, + // uint8_t>(); + auto qkv_compute = GEMMHelper( + dev_ctx, token_num, output_size, input_size, "None", trans_qkvw); phi::DenseTensor qkv_out; qkv_out.Resize({{token_num, 3, num_head, dim_head}}); @@ -831,6 +227,29 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { out_seq_len += cache_offset; } + // whether to broadcast 2nd dimension for src_mask, default true + // if mask_broadcast_num_heads if False, which means src_mask shape + // will be: + // 1. [batch_size, num_head, seq_len, seq_len] for encoder + // 2. 
[batch_size, num_heads, 1, time_step+1] for decoder + // and do not need to broadcast num_heads dimension when calculating + // attn_mask offset in MHA + bool mask_broadcast_num_heads = true; + if (src_mask) { + if (src_mask->dims()[1] == 1) { + mask_broadcast_num_heads = true; + } else if (src_mask->dims()[1] == num_head) { + mask_broadcast_num_heads = false; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unknow dimension for attn_mask, the num_head(2nd) " + "dimension is invalid, it should be 1 or num_head(%d), " + "but got %d", + num_head, + src_mask->dims()[1])); + } + } + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *q_transpose_out_data = @@ -840,14 +259,29 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto *kv_transpose_out_data = dev_ctx.Alloc( &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); - qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + if (encoder_remove_padding) { + InitValue(dev_ctx, + q_transpose_out_data, + q_transpose_out.numel(), + static_cast(0.)); + InitValue(dev_ctx, + kv_transpose_out_data, + kv_transpose_out.numel(), + static_cast(0.)); + } + + if (FLAGS_fmha_mode == "naive") { + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + } phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *src_mask_out_data = - dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); + if (FLAGS_fmha_mode == "naive") { + if (cache_offset > 0) { + src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *src_mask_out_data = + dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); + } } // [2, bs, num_head, cache_seq_len + seq_len, head_dim] @@ -862,36 +296,58 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { phi::DenseTensor softmax_out; phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *softmax_out_data = - dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + if (FLAGS_fmha_mode == "naive") { + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + } - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + // unpadding_q/unpadding_k/unpadding_v: [token_num, num_head, dim_head] + phi::DenseTensor unpadding_q, unpadding_k, unpadding_v; + phi::DenseTensor softmax_lse, seed_offset; + if (FLAGS_fmha_mode == "flash_attention_v2" && encoder_remove_padding) { + unpadding_q.Resize({{token_num, num_head, dim_head}}); + unpadding_k.Resize({{token_num, num_head, dim_head}}); + unpadding_v.Resize({{token_num, num_head, dim_head}}); + cu_seqlens_k.Resize(cu_seqlens_q.dims()); + + dev_ctx.Alloc(&unpadding_q, unpadding_q.numel() * sizeof(T)); + dev_ctx.Alloc(&unpadding_k, unpadding_k.numel() * sizeof(T)); + dev_ctx.Alloc(&unpadding_v, unpadding_v.numel() * 
sizeof(T)); + dev_ctx.Alloc(&cu_seqlens_k, + cu_seqlens_k.numel() * sizeof(int32_t)); + } + + T *attn_dropout_mask_out_data = nullptr; + T *attn_dropout_data_data = nullptr; qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + if (remove_padding) { + fmha_out.Resize({{token_num, num_head, dim_head}}); + } else { + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + } auto *fmha_out_data = dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + if (FLAGS_fmha_mode != "flash_attention_v2") { + if (remove_padding && time_step) { + InitValue(dev_ctx, fmha_out_data, fmha_out.numel(), static_cast(0.)); + } + } // 4. out_linear auto out_linear_weights = ctx.MultiInput("OutLinearW"); auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); + auto *custom_comm = GetCustomNCCLComm(dev_ctx, ring_id); // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); + + auto out_linear_compute = GEMMHelper( + dev_ctx, token_num, dim_embed, hidden_size, "None", false); // 5. ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; @@ -902,50 +358,55 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); } - dropout_mask_out.Resize({{token_num, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + uint8_t *dropout_mask_out_data = nullptr; // 6. ffn matmul1 auto ffn1_weights = ctx.MultiInput("FFN1Weight"); auto ffn1_biases = ctx.MultiInput("FFN1Bias"); auto ffn1_weight_dim = ffn1_weights[0]->dims(); - + // if quant weight, + // matmul weight is transposed int dim_ffn = ffn1_weight_dim[1]; - auto ffn1_linear_compute = AttnMatMul( - dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); + FFNHelper ffn1_helper( + dev_ctx, act_method, token_num, dim_ffn, dim_embed, "None"); + phi::DenseTensor ffn1_out; ffn1_out.Resize({{token_num, dim_ffn}}); auto *ffn1_out_data = dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + // Note(Zhengzekang): It is no need when using FP16 matmul. + phi::DenseTensor mixgemm_workspace; + char *mixgemm_workspace_data = nullptr; + // 7. ffn act + bias DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( + FusedDropoutHelper fused_act_dropout_helper( dev_ctx, token_num, dim_ffn, ffn1_dropout_param); phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{token_num, dim_ffn}}); + int tmp_dim_ffn = dim_ffn; + if (use_glu) tmp_dim_ffn /= 2; + int8_t *ffn1_dropout_mask_data = nullptr; + ffn1_dropout_out.Resize({{token_num, tmp_dim_ffn}}); auto *ffn1_dropout_out_data = dev_ctx.Alloc( &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{token_num, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); // 8. 
ffn2 matmul auto ffn2_weights = ctx.MultiInput("FFN2Weight"); auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - auto ffn2_linear_compute = AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); + auto ffn2_linear_compute = GEMMHelper( + dev_ctx, token_num, dim_embed, tmp_dim_ffn, "None", false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, + token_num, + dim_embed, + ffn2_dropout_param, + epsilon, + residual_alpha); - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; phi::DenseTensor tmp_out, tmp_out_rm_padding; tmp_out.Resize({{token_num, dim_embed}}); if (encoder_remove_padding) { @@ -991,46 +452,80 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { buf1 = out; } } + auto start = std::chrono::system_clock::now(); + + VLOG(1) << "input_x->" << input_x->dims(); + + VLOG(1) << "input_x " << *input_x; for (int i = 0; i < layers; ++i) { // step1. layer_norm if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); + norm_helper.Norm(x_data, + ln_scales[i], + compute_ln_bias ? ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf1); } + #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; + VLOG(2) << "step1"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "ln1_out:" << *buf1; +#endif #endif // step2. qkv + // NOTE: In decoder stage, bias is fused in fmha. In encoder stage, bias + // is fused in QKVBiasAddTransposeSplit const phi::DenseTensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; if (!pre_layer_norm && i == 0) { const phi::DenseTensor *tmp_input_x = (encoder_remove_padding) ? &x_remove_padding : input_x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); + VLOG(5) << "Doing !pre_layer_norm&&i==0, qkv gemm, mnk:" << token_num + << ", " << output_size << ", " << input_size; + qkv_compute.Compute(tmp_input_x, + qkv_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + &qkv_out); } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "qkv_weights:" << *(qkv_weights[i]); +#endif +#endif + VLOG(5) << "Doing qkv gemm, mnk:" << token_num << ", " << output_size + << ", " << input_size; + qkv_compute.Compute(buf1, + qkv_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + &qkv_out); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; + VLOG(2) << "step2"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "qkv_out:" << qkv_out; +#endif #endif + // 2. cache kv + auto write_cache_kv_helper = WriteCacheKVHelper( + dev_ctx, quant_round_type, quant_max_bound, quant_min_bound); + // step3. fmha const phi::DenseTensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; phi::DenseTensor *cache_kv_out = cache_kv ? 
cache_kv_outs[i] : nullptr; + int cache_bsz = 0; + if (cache_kv) { + cache_bsz = cache_kv->dims()[1]; + } if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] @@ -1038,18 +533,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { fmha(dev_ctx, qkv_out, *qkv_bias, - *src_mask, + src_mask, + nullptr, sequence_lengths, rotary_tensor, + beam_cache_offset, cache_kv_out, &fmha_out, bsz, + cache_bsz, + seq_len, max_seq_len, num_head, dim_head, - time_step->data()[0], + src_mask->dims()[3] - 1, rotary_emb_dims, - 1. / sqrt(dim_head)); + 1. / sqrt(dim_head), + mask_broadcast_num_heads, + compute_bias, + use_neox_rotary_style, + nullptr, // qkv_out_scale + nullptr, // out_linear_shift + nullptr, // out_smooth_shift + (do_cachekv_quant) ? cache_k_scale[i] : -1.0, + (do_cachekv_quant) ? cache_v_scale[i] : -1.0, + (do_cachekv_quant) ? cache_k_out_scale[i] : -1.0, + (do_cachekv_quant) ? cache_v_out_scale[i] : -1.0); } else if (cache_kv_out) { // generation context stage const phi::DenseTensor *pre_cache_kv_tensor = pre_caches.size() > 0 ? pre_caches[i] : nullptr; @@ -1057,23 +566,26 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { cache_offset > 0 ? &pre_cache_kv_out : nullptr; phi::DenseTensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split( + dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias ? qkv_bias->data() : nullptr, + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); // q_transpose_out_data [bs, head_num, seq_len, dim_head] // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); + auto *rotary_emb_data = rotary_tensor->data(); const int *sequence_lengths_data = encoder_remove_padding ? sequence_lengths->data() : nullptr; rotary_qk(dev_ctx, @@ -1084,83 +596,111 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { rotary_emb_data, sequence_lengths_data, rotary_emb_dims, + rotary_tensor->dims()[1], bsz, num_head, seq_len, - dim_head); + dim_head, + use_neox_rotary_style); } phi::DenseTensor *tmp_padding_offset_tensor = encoder_remove_padding ? 
&padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; + + if (FLAGS_fmha_mode == "flash_attention_v2" && encoder_remove_padding) { + TransposeSplit(dev_ctx, + unpadding_q.data(), + unpadding_k.data(), + unpadding_v.data(), + q_transpose_out.data(), + kv_transpose_out.data(), + padding_offset_data, + sequence_lengths->data(), + token_num, + bsz, + num_head, + seq_len, + dim_head); + phi::Copy(dev_ctx, + cu_seqlens_q, + cu_seqlens_k.place(), + false, + &cu_seqlens_k); + + // fmha_out[token_num, num_head, dim_head] + phi::FlashAttnUnpaddedKernel(dev_ctx, + unpadding_q, + unpadding_k, + unpadding_v, + cu_seqlens_q, + cu_seqlens_k, + none /*fixed_seed_offset*/, + none /*attn_mask*/, + seq_len, + seq_len, + 1.0f / sqrt(float(dim_head)), + 0.0, + true /*causal*/, + false, + true /* is_test*/, + "" /*rng_name*/, + &fmha_out, + &softmax_out, + &softmax_lse, + &seed_offset); } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; + fmha_compute.Compute(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + sequence_lengths, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num, + mask_broadcast_num_heads); } - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); + write_cache_kv_helper.Compute( + &pre_cache_kv_out, + cache_kv_out, // int8_t + &kv_transpose_out, // T + sequence_lengths_data, + cache_bsz, + bsz, + num_head, + seq_len, + dim_head, + cache_offset, + (do_cachekv_quant) ? cache_k_scale[i] : -1.0, + (do_cachekv_quant) ? cache_v_scale[i] : -1.0); + } else { // not generation // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); + qkv_bias_add_transpose_split( + dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias ? 
qkv_bias->data() : nullptr, + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); // q_transpose_out_data [bs, head_num, seq_len, dim_head] // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); + auto *rotary_emb_data = rotary_tensor->data(); const int *sequence_lengths_data = encoder_remove_padding ? sequence_lengths->data() : nullptr; rotary_qk(dev_ctx, @@ -1171,178 +711,300 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { rotary_emb_data, sequence_lengths_data, rotary_emb_dims, + rotary_tensor->dims()[1], bsz, num_head, seq_len, - dim_head); + dim_head, + use_neox_rotary_style); } - phi::DenseTensor *tmp_padding_offset_tensor = encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); + + if (FLAGS_fmha_mode == "flash_attention_v2" && encoder_remove_padding) { + TransposeSplit(dev_ctx, + unpadding_q.data(), + unpadding_k.data(), + unpadding_v.data(), + q_transpose_out.data(), + kv_transpose_out.data(), + padding_offset_data, + sequence_lengths->data(), + token_num, + bsz, + num_head, + seq_len, + dim_head); + phi::Copy(dev_ctx, + cu_seqlens_q, + cu_seqlens_k.place(), + false, + &cu_seqlens_k); + + // fmha_out[token_num, num_head, dim_head] + phi::FlashAttnUnpaddedKernel(dev_ctx, + unpadding_q, + unpadding_k, + unpadding_v, + cu_seqlens_q, + cu_seqlens_k, + none /*fixed_seed_offset*/, + none /*attn_mask*/, + seq_len, + seq_len, + 1.0f / sqrt(float(dim_head)), + 0.0, + true /*causal*/, + false, + true /* is_test*/, + "" /*rng_name*/, + &fmha_out, + &softmax_out, + &softmax_lse, + &seed_offset); + } else { + fmha_compute.Compute(cache_kv, + src_mask, + tmp_padding_offset_tensor, + sequence_lengths, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num, + mask_broadcast_num_heads); + } } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; + VLOG(2) << "step3"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "fmha_out:" << fmha_out; #endif - +#endif + VLOG(5) << "Doing out_linear gemm, mnk:" << token_num << ", " << dim_embed + << ", " << hidden_size; if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + if (custom_comm) { + custom_comm->SwapInput(buf1); + } + + out_linear_compute.Compute(&fmha_out, + out_linear_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf1); + + if (custom_comm) { + *buf1 = custom_comm->AllReduce(); + } else { + VLOG(1) << "ALLREDUCE " << buf1->numel(); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + if (custom_comm) { + custom_comm->SwapInput(buf0); + } + out_linear_compute.Compute(&fmha_out, + out_linear_weights[i], + /*weight_scale*/ nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf0); + + if (custom_comm) { + *buf0 = custom_comm->AllReduce(); + } else { + 
AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; + VLOG(2) << "step4"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "out_linear_out:" << *buf1; +#endif #endif // step5. ln(residual + dropout(input + bias)) if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, + norm_helper.NormResidualBias( buf1->data(), x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + compute_bias ? out_linear_biases[i] : nullptr, /*skip_bias*/ + ffn_ln_scales[i], + compute_ln_bias ? ffn_ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + &bias_dropout_residual_out, + buf1); } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, + norm_helper.NormResidualBias( buf0->data(), residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + compute_bias ? out_linear_biases[i] : nullptr, /*skip_bias*/ + ln_scales[i], + compute_ln_bias ? ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf0, + buf1); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; + VLOG(2) << "step5"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "ffn1_input:" << *buf1; +#endif #endif - // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward( - ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + ffn1_helper.Compute(buf1, + ffn1_weights[i], + /*weight_scale*/ nullptr, + compute_bias ? ffn1_biases[i] : nullptr, + &mixgemm_workspace, + &ffn1_out, + &ffn1_dropout_out); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; + VLOG(2) << "step6"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "ffn1_output:" << ffn1_out; #endif - - // step7. act bias - // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias(dev_ctx, - ffn1_out_data, - ffn1_biases[i]->data(), - act_method, - ffn1_dropout_out_data, - ffn1_dropout_mask_data); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; #endif - // step8. ffn matmul2 + // step7. 
ffn2 matmul if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); + if (custom_comm) { + custom_comm->SwapInput(buf1); + } + ffn2_linear_compute.Compute(&ffn1_dropout_out, + ffn2_weights[i], + nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf1); } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); + if (custom_comm) { + custom_comm->SwapInput(buf0); + } + ffn2_linear_compute.Compute(&ffn1_dropout_out, + ffn2_weights[i], + nullptr, + /*bias*/ nullptr, + &mixgemm_workspace, + buf0); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.0"; + VLOG(2) << "step8.0"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + if (pre_layer_norm) { + VLOG(2) << "ffn2_out, buf1:" << *buf1; + } else { + VLOG(2) << "ffn2_out, buf0:" << *buf0; + } +#endif #endif if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + VLOG(4) << "MPAllReduce 4: " << buf1->numel(); + if (custom_comm) { + *buf1 = custom_comm->AllReduce(); + } else { + VLOG(1) << "ALLREDUCE ffn" << buf1->numel(); + + auto ar_start = std::chrono::system_clock::now(); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + auto ar_end = std::chrono::system_clock::now(); + auto ar_duration = + std::chrono::duration_cast(ar_end - + ar_start); + VLOG(3) << "reduce elapse " + << double(ar_duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den + << " SEC"; + } } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + VLOG(4) << "MPAllReduce 4: " << buf0->numel(); + if (custom_comm) { + *buf0 = custom_comm->AllReduce(); + } else { + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } } + #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.1"; + VLOG(2) << "step8.1"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + if (pre_layer_norm) { + VLOG(2) << "ffn2_out_rd:" << *buf1; + } else { + VLOG(2) << "ffn2_out_rd:" << *buf0; + } +#endif #endif - // step9. residual bias + // step8. residual bias + // TODO(wangxi): remove dropout mask in inference if (pre_layer_norm) { // TODO(wangxi): remove dropout mask in inference if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, + norm_helper.NormResidualBias( buf1->data(), bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); + compute_bias ? ffn2_biases[i] : nullptr, /*skip_bias*/ + ln_scales[i + 1], + compute_ln_bias ? ln_biases[i + 1] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf1, + buf0); } else { ffn2_fused_dropout_helper.ResidualDropoutBias( dev_ctx, buf1->data(), bias_dropout_residual_out_data, - ffn2_biases[i]->data(), + compute_bias ? ffn2_biases[i]->data() : nullptr, buf1->data(), dropout_mask_out_data); } } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, + norm_helper.NormResidualBias( buf0->data(), buf1->data(), - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); + compute_bias ? ffn2_biases[i] : nullptr, /*skip_bias*/ + ffn_ln_scales[i], + compute_ln_bias ? 
ffn_ln_biases[i] : nullptr, /*norm_bias*/ + &ln_mean, /*mean*/ + &ln_var, /*var*/ + buf0, + buf1); } + #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step9"; + VLOG(2) << "step9"; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER_PRINT_TENSOR + VLOG(2) << "residual_out:" << *buf1; +#endif #endif if (pre_layer_norm) { x_data = buf1->data(); std::swap(buf0, buf1); } } + auto end = std::chrono::system_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + VLOG(3) << "ELAPSE " + << double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den + << " SEC"; if (encoder_remove_padding) { if (pre_layer_norm) { InvokeRebuildPadding(dev_ctx, @@ -1363,16 +1025,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { } }; -#endif // CUDA_VERSION >= 11060 - } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; + +#if CUDA_VERSION >= 11000 +PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, + GPU, + ALL_LAYOUT, + ops::FusedMultiTransformerOpKernel, + float, + plat::float16, + plat::bfloat16) {} +#else PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, GPU, ALL_LAYOUT, ops::FusedMultiTransformerOpKernel, float, plat::float16) {} +#endif diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 4769433317f0f..efe029f70eae6 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -19,35 +19,527 @@ limitations under the License. */ #pragma once -#include -#include - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attention_layer_norm.h" -#include "paddle/fluid/operators/fused/attn_gemm.h" -#include "paddle/fluid/operators/fused/fmha_ref.h" -#include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/dynload/cublasLt.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" -#include "paddle/phi/kernels/funcs/math_function.h" +#include +#include +#include "paddle/phi/kernels/flash_attn_kernel.h" -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/process_group.h" -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif +#include "paddle/fluid/operators/fused/mmha_util.cu.h" +#include "paddle/phi/kernels/gpu/flash_attn_utils.h" -DECLARE_bool(gemm_use_half_precision_compute_type); +DECLARE_string(fmha_mode); +DECLARE_int64(custom_allreduce_one_shot_threshold); +DECLARE_int64(custom_allreduce_two_shot_threshold); namespace paddle { namespace operators { +inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +#endif +} + +inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return 
_castf32_u32(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +#endif +} + +static float CPUHalfConvert2Float(const uint16_t h) { + const uint32_t w = (uint32_t)h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +template +static void PrintMatrix(const T *mat_d, int num, std::string name) { + // if (FLAGS_cublaslt_exhaustive_search_times != 114514) return; + + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name + ".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + if (std::is_same::value) { + ss << static_cast(tmp[i]) << std::endl; + } else { + ss << std::setprecision(8) << (float)(tmp[i]) << std::endl; // NOLINT + } + } + outfile << ss.str(); + outfile.close(); +} + +static void PrintHalfMatrix(const void *mat_d_ptr, int num, std::string name) { + VLOG(0) << "PRINT HALF MATRIX Num is: " << num; + const uint16_t *mat_d = reinterpret_cast(mat_d_ptr); + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(uint16_t) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name + ".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + ss << std::setprecision(8) << CPUHalfConvert2Float(tmp[i]) << std::endl; + } + outfile << ss.str(); + outfile.close(); +} + +template +struct Load { + explicit Load(const T *src) : src_(src) {} + + template + __device__ void load(phi::AlignedVector *dst, int idx) { + phi::Load(src_ + idx, dst); + } + + const T *src_; +}; + +template +struct Store { + explicit Store(T *dst) : dst_(dst) {} + + template + __device__ void store(phi::AlignedVector &src, int idx) { + phi::Store(src, dst_ + idx); + } + + T *dst_; +}; + +template +struct Store { + Store(T *dst, const T *shift, const T *smooth, const int cols) + : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {} + + template + __device__ void store(phi::AlignedVector &src, int idx) { + using Vec = phi::AlignedVector; + Vec shift_vec; + Vec smooth_vec; + + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src[i] = (src[i] + shift_vec[i]) * smooth_vec[i]; + } + phi::Store(src, dst_ + idx); + } + + T *dst_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +template +struct DequantLoad { + DequantLoad(const int32_t *src, const float *dequant_scales, const int cols) + : src_(src), dequant_scales_(dequant_scales), cols_(cols) {} + + template + __device__ void load(phi::AlignedVector *dst, int idx) { + using SrcVec = phi::AlignedVector; + using DstVec = 
phi::AlignedVector; + using ScaleVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + ScaleVec scale_vec; + + phi::Load(src_ + idx, &src_vec); + phi::Load(dequant_scales_ + idx % cols_, &scale_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + static_cast(static_cast(src_vec[i]) * scale_vec[i]); + } + *dst = dst_vec; + } + + const int32_t *src_; + const float *dequant_scales_; + const int cols_; +}; + +template +__device__ __inline__ T ClipFunc(const T v, const T min, const T max) { + if (v > max) return max; + if (v < min) return min; + return v; +} + +template +__forceinline__ __device__ OutType QuantHelperFunc(const InType input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * static_cast(input); + + if (round_type == 0) { + quant_value = static_cast(rint(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + return static_cast( + ClipFunc(quant_value, min_bound, max_bound)); +} + +template +struct QuantStore { + QuantStore(int8_t *dst, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + __device__ void store(phi::AlignedVector &src, // NOLINT + int idx) { // NOLINT + using DstVec = phi::AlignedVector; + + DstVec dst_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = QuantHelperFunc(static_cast(src[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +struct QuantStore { + QuantStore(int8_t *dst, + const T *shift, + const T *smooth, + const int cols, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + shift_(shift), + smooth_(smooth), + cols_(cols), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + __device__ void store(phi::AlignedVector &src, // NOLINT + int idx) { // NOLINT + using DstVec = phi::AlignedVector; + using Vec = phi::AlignedVector; + + DstVec dst_vec; + Vec shift_vec; + Vec smooth_vec; + + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src[i] = (src[i] + shift_vec[i]) * smooth_vec[i]; + dst_vec[i] = QuantHelperFunc(static_cast(src[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +template +struct MMHALoad { + explicit MMHALoad(const LoadT *src) : src_(src) {} + + template + __device__ void load(Vec &dst, int idx) { + dst = *reinterpret_cast(src_ + idx); + } + + const LoadT *src_; +}; + +template +struct MMHAStore { + explicit MMHAStore(StoreT *dst) : dst_(dst) {} + + template + __device__ void store(Vec &src, int idx) { + *reinterpret_cast(dst_ + idx) = src; + } + 
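// A host-side mirror of QuantHelperFunc above, to make the scale/round/clip
// order concrete before the int8 MMHA load/store functors reuse it. The +-127
// bounds are only an example; the kernels take quant_max_bound /
// quant_min_bound as parameters, and the helper name here is hypothetical.
#include <algorithm>
#include <cmath>
#include <cstdint>

static int8_t SketchQuant(float x,
                          float scale,
                          int round_type,
                          float max_bound = 127.f,
                          float min_bound = -127.f) {
  float q = max_bound * scale * x;                       // scale first
  q = (round_type == 0) ? std::rint(q) : std::round(q);  // 0: round-half-to-even
  q = std::min(std::max(q, min_bound), max_bound);       // then clip
  return static_cast<int8_t>(q);
}
// e.g. SketchQuant(0.5f, 1.f / 64.f, /*round_type=*/0) == 1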
+ StoreT *dst_; +}; + +template +struct MMHAStore { + MMHAStore(T *dst, const T *shift, const T *smooth, const int cols) + : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {} + + template + __device__ void store(Vec &src, int idx) { + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using TVec = phi::AlignedVector; + TVec src_vec; + TVec shift_vec; + TVec smooth_vec; + + *reinterpret_cast(&src_vec) = src; + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; + } + + phi::Store(src_vec, dst_ + idx); + } + + T *dst_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +template +struct MMHALoad { + MMHALoad(const int32_t *src, const float *dequant_scales, const int cols) + : src_(src), dequant_scales_(dequant_scales), cols_(cols) {} + + template + __device__ void load(Vec &dst, int idx) { + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + using ScaleVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + ScaleVec scale_vec; + + phi::Load(src_ + idx, &src_vec); + phi::Load(dequant_scales_ + idx % cols_, &scale_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + static_cast(static_cast(src_vec[i]) * scale_vec[i]); + } + dst = *reinterpret_cast(&dst_vec); + } + + const int32_t *src_; + const float *dequant_scales_; + const int cols_; +}; + +template +struct MMHAStore { + MMHAStore(int8_t *dst, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + __device__ void store(Vec &src, int idx) { // NOLINT + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + + SrcVec src_vec; + *reinterpret_cast(&src_vec) = src; + + DstVec dst_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + QuantHelperFunc(static_cast(src_vec[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +struct MMHAStore { + MMHAStore(int8_t *dst, + const T *shift, + const T *smooth, + const int cols, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound), + shift_(shift), + smooth_(smooth), + cols_(cols) {} + + template + __device__ void store(Vec &src, int idx) { // NOLINT + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + SrcVec shift_vec; + SrcVec smooth_vec; + + *reinterpret_cast(&src_vec) = src; + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; + dst_vec[i] = + QuantHelperFunc(static_cast(src_vec[i]), + 
quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const T *shift_; + const T *smooth_; + const int cols_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +struct BaseActivationFunctor { + using ELEMENT_TYPE = T; + + using AttrPair = std::vector>; + + AttrPair GetAttrs() { return AttrPair(); } +}; + +template +struct CudaSwishFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta = 1.0; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + // swish(x) = x / (1 + exp(-beta * x)) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + return static_cast(x / (one + exp(-b * x))); + } +}; + // for debug // #define _DEBUG_FUSED_MULTI_TRANSFORMER @@ -62,13 +554,13 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT if (map->has(ring_id)) { paddle::distributed::ProcessGroup *pg = map->get(ring_id); - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(tensor); + // std::vector in_tensor; + // std::vector out_tensor; + // in_tensor.push_back(tensor); + // out_tensor.push_back(tensor); paddle::distributed::AllreduceOptions opts; opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); + auto task = pg->AllReduce(&tensor, tensor, opts, false, true); task->Wait(); } else { auto dtype = platform::ToNCCLDataType( @@ -91,29 +583,11 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT namespace { // NOLINT -namespace plat = paddle::platform; -using float16 = plat::float16; - #define MMHA_USE_FP32_ACUM_FOR_LOGITS #define MMHA_USE_FP32_ACUM_FOR_OUT #define MMHA_USE_FP32_ACUM_FOR_FMA // #define MMHA_USE_HMMA_FOR_REDUCTION -template -class PDDataTypeTraits; - -template <> -class PDDataTypeTraits { - public: - typedef float DataType; -}; - -template <> -class PDDataTypeTraits { - public: - typedef half DataType; -}; - template struct Masked_multihead_attention_params { // output buffer, [B, 1(seq_len), num_head * dim_head] @@ -121,618 +595,813 @@ struct Masked_multihead_attention_params { // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] const T *qkv; // bias, [3, num_head, dim_head] - const T *qkv_bias; + T *qkv_bias; + // [bsz, seq_len] + const int *cum_offsets; // TODO(wangxi): optimize with input_lengths and max_input_len? 
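// Indexing sketch for attn_mask under mask_broadcast_num_heads (assumption:
// this mirrors the broadcast semantics this patch documents for the mask, with
// shapes [bsz, 1, 1, mask_length] when broadcasting and
// [bsz, num_head, 1, mask_length] otherwise; the helper name is hypothetical
// and the real offset math lives in the MMHA kernel body). With broadcasting,
// the per-head stride collapses to zero, so one mask row serves every head.
static inline int SketchAttnMaskOffset(int bi,
                                       int hi,
                                       int ti,
                                       int num_head,
                                       int mask_length,
                                       bool mask_broadcast_num_heads) {
  const int head_stride = mask_broadcast_num_heads ? 0 : mask_length;
  const int batch_stride =
      (mask_broadcast_num_heads ? 1 : num_head) * mask_length;
  return bi * batch_stride + hi * head_stride + ti;
}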
// [bsz, 1, 1, time_step(cache_seq_length)+1] const T *attn_mask; + int mask_length; + // whether to broadcast num_heads(2nd) dimension for attn_mask + // in MMHA, if false, attn_mask shape should be + // [bsz, num_heads, 1, time_step(cache_seq_length)+1] + bool mask_broadcast_num_heads; // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first // v [B, num_head, max_seq_len, dim_head] - T *cache_kv; + T *cache_kv = nullptr; + // [B, max_seq_len] + const int *beam_cache_offset = nullptr; const int *sequence_lengths{nullptr}; - // The RoPE embedding, [B, 1, 1, dim_head] + // The RoPE embedding, [2, B, rotary_seq_len, 1, dim_head] // rotary_emb_dims = 1 if pos_ids_extra is null else 2 - const T *rotary_emb; + const float *rotary_emb; + int rotary_bsz; int rotary_emb_dims; + int rotary_seq_len = 1; - int batch_size; + int batch_size; // batch * beam + int beam_width; + int cache_batch_size; int num_head; int timestep; // cache_seq_length + int seq_len; int max_seq_length; // 1.f / sqrt(Dh) float inv_sqrt_dh; -}; -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; + bool add_qkv_bias; + bool neox_rotary_style; }; -// clang-format off - -template struct Qk_vec_ {}; -template <> struct Qk_vec_ { using Type = float; }; -template <> struct Qk_vec_ { using Type = float2; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = float4; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint32_t; }; -template <> struct Qk_vec_ { using Type = uint2; }; -template <> struct Qk_vec_ { using Type = uint4; }; - -template struct K_vec_ {}; -template <> struct K_vec_ { using Type = float; }; -template <> struct K_vec_ { using Type = float2; }; -template <> struct K_vec_ { using Type = float4; }; -template <> struct K_vec_ { using Type = uint32_t; }; -template <> struct K_vec_ { using Type = uint2; }; -template <> struct K_vec_ { using Type = uint4; }; - -template struct V_vec_ {}; -template <> struct V_vec_ { using Type = float; }; -template <> struct V_vec_ { using Type = float2; }; -template <> struct V_vec_ { using Type = float4; }; -template <> struct V_vec_ { using Type = uint32_t; }; -template <> struct V_vec_ { using Type = uint2; }; -template <> struct V_vec_ { using Type = uint4; }; - #ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct K_vec_acum_fp32_ { -}; +template +struct K_vec_acum_fp32_ {}; -template<> +template <> struct K_vec_acum_fp32_ { - using Type = float2; + using Type = float2; }; #endif #ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template struct V_vec_acum_fp32_ {}; +template +struct V_vec_acum_fp32_ {}; // template <> struct V_vec_acum_fp32_ { using Type = float; }; // template <> struct V_vec_acum_fp32_ { using Type = float2; }; -template <> struct V_vec_acum_fp32_ { using Type = float4; }; +template <> +struct V_vec_acum_fp32_ { + using Type = float4; +}; // template <> struct V_vec_acum_fp32_ { using Type = float2; }; // template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; -template <> struct V_vec_acum_fp32_ { using Type = Float8_; }; -#endif - -// clang-format on - -inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} +template <> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; -inline __device__ float2 half2_to_float2(uint32_t v) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), 
"=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} +#ifdef ENABLE_BF16 +template <> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template <> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 -inline __device__ uint32_t float2_to_half2(float2 f) { - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" - : "=r"(tmp.u32) - : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif - return tmp.u32; -} -inline __device__ float add(float a, float b) { return a + b; } - -inline __device__ float2 add(float2 a, float2 b) { - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} +// clang-format on -inline __device__ float4 add(float4 a, float4 b) { - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} +//////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ uint16_t add(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} +template +inline __device__ float qk_dot_(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + K_vec inv_q = mul(q[0], inv_sqrt_dh); + K_vec qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } -inline __device__ uint32_t add(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; } -inline __device__ uint2 add(uint2 a, uint2 b) { - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} +inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { + float4 c; + float zero = 0.f; + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" -inline __device__ uint4 add(uint4 a, uint4 b) { - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); return c; } -inline __device__ float2 add(uint32_t a, float2 fb) { - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -inline __device__ Float8_ add(uint4 a, Float8_ fb) { - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -template -inline __device__ Acc mul(A a, B b); - -template <> -inline __device__ float mul(float a, float b) { - return a * b; -} - -template <> -inline __device__ float2 mul(float2 a, float2 b) { - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename 
K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); + K_vec_acum qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif } -template <> -inline __device__ float4 mul(float4 a, float4 b) { - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + return qk_dot_(q, k, inv_sqrt_dh); + } +}; template <> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) { - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + return qk_hmma_dot_(q, k, inv_sqrt_dh); +#else + return qk_dot_<4>(q, k, inv_sqrt_dh); +#endif + } +}; -template <> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} +template +inline __device__ float block_sum(float *red_smem, float sum) { + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; -template <> -inline __device__ uint2 mul(uint2 a, uint2 b) { - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } -template <> -inline __device__ uint4 mul(uint4 a, uint4 b) { - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} + if (lane == 0) { + red_smem[warp] = sum; + } + __syncthreads(); -template <> -inline __device__ uint32_t mul(uint32_t a, float b) { - float2 tmp = half2_to_float2(a); - float2 tmp_res; - tmp_res.x = tmp.x * b; - tmp_res.y = tmp.y * b; - uint32_t res = float2_to_half2(tmp_res); - return res; -} + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } -template <> -inline __device__ float2 mul(uint32_t a, float b) { - float2 tmp = half2_to_float2(a); - float2 res; - res.x = tmp.x * b; - res.y = tmp.y * b; - return res; -} +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } -template <> -inline __device__ uint2 mul(uint2 a, float b) { - uint2 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - return res; + return __shfl_sync(uint32_t(-1), sum, 0); } -template <> -inline __device__ uint4 mul(uint4 a, float b) { - uint4 res; - res.x = mul(a.x, b); - res.y = mul(a.y, b); - res.z = mul(a.z, b); - res.w = mul(a.w, b); - return res; +inline __device__ void convert_from_float(float &dst, float src) { // NOLINT + dst = src; } -template <> -inline __device__ float2 mul(float2 a, float b) { - float2 res; - res.x = a.x * b; - res.y = a.y * b; - return res; +inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT + dst = src; } -template <> -inline __device__ 
float2 mul(float2 a, uint32_t b) { - float2 tmp_b = half2_to_float2(b); - float2 res; - res.x = a.x * tmp_b.x; - res.y = a.y * tmp_b.y; - return res; +inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT + float src) { + dst = static_cast(src); } -template <> -inline __device__ float4 mul(float4 a, float b) { - float4 res; - res.x = a.x * b; - res.y = a.y * b; - res.z = a.z * b; - res.w = a.w * b; - return res; +inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); } -template -inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, - Qk_vec input_right, - Qk_vec cos_emb, - Qk_vec sin_emb, - float alpha) { - Qk_vec res1 = mul(input_left, cos_emb); - Qk_vec res2 = mul(input_right, sin_emb); - res2 = mul(res2, alpha); - return add(res1, res2); +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16 &dst, // NOLINT + float src) { // NOLINT + dst = __float2bfloat16(src); } -inline __device__ float sum(float v) { return v; } -inline __device__ float sum(float2 v) { return v.x + v.y; } -inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -inline __device__ float sum(uint16_t v) { return half_to_float(v); } -inline __device__ float sum(uint32_t v) { - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; +inline __device__ void convert_from_float(__nv_bfloat162 &dst, // NOLINT + float2 src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif } -inline __device__ float sum(uint2 v) { - uint32_t c = add(v.x, v.y); - return sum(c); +inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT + Float4_ src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif } -inline __device__ float sum(uint4 v) { - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); - return sum(c); -} +//////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); +inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT + float4 src) { // NOLINT + convert_from_float( + dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); } -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); +inline __device__ void convert_from_float(bf16_8_t &dst, // NOLINT + Float8_ src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif } +#endif // ENABLE_BF16 -inline __device__ constexpr uint32_t shfl_mask(int threads) { - return threads == 32 ? 
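// Note: the ENABLE_BF16 convert_from_float overloads above round fp32
// accumulators back to bfloat16 (__float2bfloat16 and __float22bfloat162_rn
// round to nearest even, with a per-lane fallback before SM80). bfloat16
// keeps float's 8-bit exponent but only a 7-bit mantissa, so the store is
// lossy. Minimal round-trip sketch -- illustrative only, not patch code:
#include <cuda_bf16.h>

__device__ float bf16_round_trip(float x) {
  __nv_bfloat16 b = __float2bfloat16(x);  // e.g. 1.0f + 2^-10 rounds to 1.0f
  return __bfloat162float(b);
}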
uint32_t(-1) : (1u << threads) - 1u; -} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT template -inline __device__ __host__ T div_up(T m, T n) { - return (m + n - 1) / n; +inline __device__ void zero(T &dst) { // NOLINT + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; } -inline __device__ float fma(float a, float b, float c) { return a * b + c; } +template +__global__ +__launch_bounds__(THREADS_PER_BLOCK) void masked_multihead_attention_kernel_int8( + Masked_multihead_attention_params params, + LoadFunc load_func, + StoreFunc store_func, + uint8_t *cache_kv_I, + float cache_k_quant_scale, + float cache_v_quant_scale, + float cache_k_dequant_scale, + float cache_v_dequant_scale) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const int bi = blockIdx.y; + if (params.sequence_lengths && params.sequence_lengths[bi] == 0) { + return; + } -inline __device__ float2 fma(float2 a, float2 b, float2 c) { - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} + typedef PDDataTypeTraits traits_; + typedef typename traits_::DataType DataType_; -inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { - float2 tmp_b = half2_to_float2(b); - float2 d = fma(a, tmp_b, c); - return d; -} + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); -inline __device__ float4 fma(float4 a, float4 b, float4 c) { - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} + constexpr int WARP_SIZE = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(d) - : "r"(a), "r"(b), "r"(c)); - return d; -} + extern __shared__ char smem_[]; -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} + float *qk_smem = reinterpret_cast(smem_); -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} + char *logits_smem_ = smem_; + // fp32 accum for logits + float *logits_smem = reinterpret_cast(logits_smem_); -inline __device__ float2 fma(float a, float2 b, float2 c) { - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} + T *out_smem = reinterpret_cast(smem_); -inline __device__ float4 fma(float a, float4 b, float4 c) { - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + using Qk_vec = typename Qk_vec_::Type; + using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; + using QK_Packed_Int8_t = + typename packed_type::value>::type; + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / 
params.beam_width; + const int hi = blockIdx.x; + const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; + const int tid = threadIdx.x; -inline __device__ uint32_t h0_h0(uint16_t a) { - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} + const int bi_seq_len_offset = bi * params.max_seq_length; -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { - return fma(h0_h0(a), b, c); -} + float qk_max = -FLT_MAX; + float qk = 0; -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} + int act_time_step = params.sequence_lengths == nullptr + ? params.timestep + : params.sequence_lengths[bi]; -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} + __shared__ float k_q_scale; + __shared__ float k_dq_scale; + __shared__ float v_q_scale; + __shared__ float v_dq_scale; -inline __device__ float cast_to_float(float u) { return u; } + k_q_scale = cache_k_quant_scale; + k_dq_scale = cache_k_dequant_scale; + v_q_scale = cache_v_quant_scale; + v_dq_scale = cache_v_dequant_scale; -inline __device__ float2 cast_to_float(float2 u) { return u; } + // qkv [B, S=1, 3, num_head, head_dim] + int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; -inline __device__ float4 cast_to_float(float4 u) { return u; } + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // Use block reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; -inline __device__ Float8_ cast_to_float(uint4 u) { - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} + // cache_k, [B, num_head, head_dim / x, max_seq_len, x] + // x == 4/8 for FP32/FP16, 128bit, 16Byte + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); -template -inline __device__ float qk_dot_(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - K_vec inv_q = mul(q[0], inv_sqrt_dh); - K_vec qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); + // const T *q_base = params.qkv; + // const T *k_base = params.qkv + params.num_head * Dh; + T *q_bias_base = nullptr; + T *k_bias_base = nullptr; + + if (params.add_qkv_bias) { + q_bias_base = params.qkv_bias; + k_bias_base = params.qkv_bias + params.num_head * Dh; } - float qk = sum(qk_vec); + if (tid < QK_VECS_PER_WARP) { + int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; + int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + + Qk_vec q; + zero(q); + // q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_offset]) + // : q; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(q, qk_offset); + } + + Qk_vec k; + zero(k); + // k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? 
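// Note: the fused QKV activation is laid out as
// [batch, seq = 1, qkv = 3, num_head, head_dim], which is why the kernel
// reads q at qkv_base_offset, k another num_head * Dh further on, and v
// another 2 * num_head * Dh further on. Equivalent index helper,
// illustrative only (the kernel inlines this arithmetic through load_func):
__host__ __device__ inline int fused_qkv_offset(
    int bi, int hi, int which /* 0 = q, 1 = k, 2 = v */,
    int num_head, int dim_head) {
  return bi * 3 * num_head * dim_head   // batch row of the fused buffer
         + which * num_head * dim_head  // q, k or v slice
         + hi * dim_head;               // first element of this head
}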
*reinterpret_cast(&k_base[qk_offset]) + // : k; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(k, params.num_head * Dh + qk_offset); + } + + if (params.add_qkv_bias) { + Qk_vec q_bias; + zero(q_bias); + Qk_vec k_bias; + zero(k_bias); + + q_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) + : q_bias; + k_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) + : k_bias; + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + } + + if (!params.neox_rotary_style) { + if (params.rotary_emb_dims != 0) { + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.batch_size * Dh; + Qk_vec_RoPE cos_emb, sin_emb; + zero(cos_emb); + zero(sin_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + apply_rotary_embedding(q, k, cos_emb, sin_emb); + } + } else { + /* old rotary pos emb */ + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.batch_size * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + // q_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_right_offset]) + // : q_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load(q_right, qk_right_offset); + } + Qk_vec k_right; + zero(k_right); + // k_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&k_base[qk_right_offset]) + // : k_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load( + k_right, params.num_head * Dh + qk_right_offset); + } + + if (params.add_qkv_bias) { + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + } + + Qk_vec_RoPE cos_emb; + zero(cos_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + + Qk_vec_RoPE sin_emb; + zero(sin_emb); + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride + ? 
static_cast(-1) + : static_cast(1); + q = apply_rotary_emb( + q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb( + k, k_right, cos_emb, sin_emb, alpha); + } + } + + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; + + int co = tid / QK_VECS_IN_16B; + int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; + int offset = bhi * params.max_seq_length * Dh + + co * params.max_seq_length * QK_ELTS_IN_16B + + act_time_step * QK_ELTS_IN_16B + ci; + // quant k and store the int8 value into cache kv + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + QK_Packed_Int8_t k_tmp = round_tmp( + mul(k_q_scale, k)); + *reinterpret_cast(&cache_kv_I[offset]) = k_tmp; + } + + qk = dot(q, k); + + if (QK_VECS_PER_WARP <= WARP_SIZE) { #pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } } - return qk; -} + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = + (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + if (tid == 0) { + // NOTE(wangxi): mask must be 0.0 + // T mask = params.attn_mask[ + // bi * (params.timestep + 1) + params.timestep]; + // qk += static_cast(mask); + qk *= params.inv_sqrt_dh; + qk_max = qk; + qk_smem[act_time_step] = qk; + } + __syncthreads(); -inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { - float4 c; - float zero = 0.f; - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" + using K_vec = typename K_vec_::Type; + using K_vec_I = typename K_vec_I_::Type; + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} + int ko = tid / THREADS_PER_KEY; + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); - K_vec_acum qk_vec = mul(inv_q, k[0]); + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); + + K_vec q[K_VECS_PER_THREAD]; #pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); + for (int i = 0; i < K_VECS_PER_THREAD; ++i) { + q[i] = *reinterpret_cast( + &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - return qk_dot_(q, k, inv_sqrt_dh); - } -}; + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + constexpr int 
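// Note: the key cache uses the blocked layout
// [B, num_head, dim_head / x, max_seq_len, x] with x = 16 / sizeof(T)
// (QK_ELTS_IN_16B), so appending one timestep, as done just above, stays a
// single 16-byte store per chunk. The offset expression above is equivalent
// to the helper below -- illustrative only, not patch code:
template <typename T>
__host__ __device__ inline int k_cache_offset(
    int bhi, int chunk, int timestep, int lane,
    int dim_head, int max_seq_len) {
  constexpr int x = 16 / sizeof(T);       // elements per 16-byte chunk
  return bhi * max_seq_len * dim_head     // (batch, head) plane
         + chunk * max_seq_len * x        // which 16-byte chunk of the head
         + timestep * x                   // timestep inside that chunk
         + lane;                          // element inside the chunk
}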
K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + uint8_t *k_cache_I = &cache_kv_I[bhi * params.max_seq_length * Dh + ki]; + uint8_t *k_cache_batch_I = + &cache_kv_I[bbhi * params.max_seq_length * Dh + ki]; + + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; + + const int *beam_offsets = params.beam_cache_offset + ? ¶ms.beam_cache_offset[bi_seq_len_offset] + : nullptr; + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * + params.max_seq_length * Dh + : 0; + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_seq_length + ti; + // get k from the cache_kv, and dequant k for qk operation + if (ti < act_time_step) { + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) { + mul_pointer_v2( + &k[ii], + k_dq_scale, + // (beam_offset) ? reinterpret_cast(k_cache_batch_I + + // beam_offset + jj * QK_ELTS_IN_16B) : reinterpret_cast(k_cache_I + jj * QK_ELTS_IN_16B)); + reinterpret_cast(k_cache_I + jj * QK_ELTS_IN_16B)); + } else { + k[ii] = k_vec_zero; + } + } + } + + // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) + // may overflow with FP16 in large model. + float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); + + // bool is_mask = false; + if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { + // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi; + // T mask = params.attn_mask[mask_bhi * (params.timestep + 1) + ti]; + T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; + qk += static_cast(mask); + qk_max = fmaxf(qk_max, qk); -template <> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 - return qk_hmma_dot_(q, k, inv_sqrt_dh); -#else - return qk_dot_<4>(q, k, inv_sqrt_dh); -#endif + qk_smem[ti] = qk; + } } -}; - -template -inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } + const int warp = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + if (lane == 0) { - red_smem[warp] = sum; + red_smem[warp] = qk_max; } - __syncthreads(); - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } + __syncthreads(); + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } - return __shfl_sync(uint32_t(-1), sum, 0); -} + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); -inline __device__ void convert_from_float(float &dst, float src) { // NOLINT - dst = src; -} + float sum = 0.f; + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { + // bool is_mask = false; + // float logit = is_mask ? 
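// Note: the qk_max reduction above is the usual two-level pattern: an
// XOR-butterfly shuffle folds the maximum across key groups inside each warp
// (the loop stops at THREADS_PER_KEY because only the lane with
// tid % THREADS_PER_KEY == 0 accumulated qk_max for its group), then one
// representative per warp combines through red_smem. The warp-level
// butterfly on its own -- an illustrative helper, not patch code:
template <int WIDTH>  // power of two, <= 32
__device__ float warp_max_butterfly(float v) {
#pragma unroll
  for (int mask = WIDTH / 2; mask >= 1; mask /= 2) {
    v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
  }
  return v;  // every lane of the WIDTH-wide group now holds the group max
}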
0.f : __expf(qk_smem[ti] - qk_max); + float logit = __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } -inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT - dst = src; -} + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); -inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT - float src) { - dst = static_cast(src); -} + // FIXME(wangxi): need add 1.e-6f? + float inv_sum = __fdividef(1.f, sum + 1.e-6f); -inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + __syncthreads(); -inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + using V_vec = typename V_vec_::Type; + using V_Packed_Int8_t = + typename packed_type::value>::type; -template -inline __device__ void zero(T &dst) { // NOLINT - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; + int vo = tid / THREADS_PER_VALUE; + int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; + + uint8_t *v_cache_I = &cache_kv_I[params.cache_batch_size * params.num_head * + params.max_seq_length * Dh + + bhi * params.max_seq_length * Dh + vi]; + uint8_t *v_cache_batch_I = + &cache_kv_I[params.batch_size * params.num_head * params.max_seq_length * + Dh + + bbhi * params.max_seq_length * Dh + vi]; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + + V_vec_acum out; + zero(out); + + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { + const int beam_offset = + beam_offsets + ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh + : 0; + V_vec v; + mul_pointer_v2( + &v, + v_dq_scale, + // (beam_offset) ? reinterpret_cast(v_cache_batch_I + beam_offset + ti * Dh) : + // reinterpret_cast(v_cache_I + ti * Dh)); + reinterpret_cast(v_cache_I + ti * Dh)); +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + DataType_ logit = static_cast(logits_smem[ti]); + // Update the partial sums. 
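// Note: the block above is a numerically stable softmax over the qk row:
// subtract the block-wide maximum before exponentiating, reduce the sum,
// then normalize with a small epsilon so a fully masked row cannot divide by
// zero. Single-threaded reference of the same math (the kernel strides the
// loops by THREADS_PER_BLOCK and keeps the result in logits_smem) --
// illustrative only:
#include <math.h>

inline void softmax_row_reference(float *logits, int len, float row_max) {
  float sum = 0.f;
  for (int i = 0; i < len; ++i) {
    logits[i] = expf(logits[i] - row_max);
    sum += logits[i];
  }
  const float inv_sum = 1.f / (sum + 1.e-6f);
  for (int i = 0; i < len; ++i) {
    logits[i] *= inv_sum;
  }
}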
+ out = fma(logit, v, out); +#endif + } + } + + V_vec v_bias; + zero(v_bias); + if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { + // V_vec v = *reinterpret_cast( + // ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); + V_vec v; + load_func.template load( + v, 2 * params.num_head * Dh + qkv_base_offset + vi); + if (params.add_qkv_bias) { + v_bias = *reinterpret_cast( + ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + v = add(v, v_bias); + } + + V_Packed_Int8_t v_tmp = round_tmp( + mul(v_q_scale, v)); + *reinterpret_cast(&v_cache_I[act_time_step * Dh]) = + v_tmp; + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + out = fma(logits_smem[act_time_step], cast_to_float(v), out); +#else + out = fma(logits_smem[act_time_step], v, out); +#endif + } + + __syncthreads(); + + if (Dh == Dh_MAX || vi < Dh) { #pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + int midpoint = active_groups / 2; + + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float( + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), + out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = + add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } } - dst = tmp.raw; + + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + // convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + + // vi]), + // out); + V_vec tmp_out; + convert_from_float(tmp_out, out); + store_func.template store(tmp_out, bhi * Dh + vi); +#else + // *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; + store_func.template store(out, bhi * Dh + vi); +#endif + } + +#else + assert(false); +#endif } template + int THREADS_PER_BLOCK, + typename LoadFunc, + typename StoreFunc> __global__ void masked_multihead_attention_kernel( - Masked_multihead_attention_params params) { + Masked_multihead_attention_params params, + LoadFunc load_func, + StoreFunc store_func) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const int bi = blockIdx.y; + if (params.sequence_lengths && params.sequence_lengths[bi] == 0) { + return; + } + typedef PDDataTypeTraits traits_; typedef typename traits_::DataType DataType_; @@ -765,13 +1443,23 @@ __global__ void masked_multihead_attention_kernel( __shared__ float red_smem[WARPS_PER_BLOCK * 2]; using Qk_vec = typename Qk_vec_::Type; + using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - const int bi = blockIdx.y; + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / params.beam_width; const int hi = blockIdx.x; const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; + const int ti = + params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; + const int thi = params.cum_offsets ? 
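// Note: the active_groups loop above halves the number of live value groups
// each round: the upper half of the groups publishes its partial output
// vector into out_smem, the lower half adds it in, until group 0 owns the
// final accumulation. The same tree reduction in scalar form (one value per
// group, GROUPS a power of two) -- illustrative only, not patch code:
template <int GROUPS>
__device__ float group_tree_reduce(float partial, int group_id,
                                   float *scratch /* GROUPS / 2 floats */) {
  for (int active = GROUPS; active >= 2; active /= 2) {
    const int midpoint = active / 2;
    if (group_id >= midpoint && group_id < active) {
      scratch[group_id - midpoint] = partial;  // upper half publishes
    }
    __syncthreads();
    if (group_id < midpoint) {
      partial += scratch[group_id];            // lower half accumulates
    }
    __syncthreads();
  }
  return partial;  // valid for group_id == 0
}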
ti * params.num_head + hi : -1; const int tid = threadIdx.x; + const int bi_seq_len_offset = bi * params.max_seq_length; + float qk_max = -FLT_MAX; float qk = 0; @@ -793,10 +1481,15 @@ __global__ void masked_multihead_attention_kernel( constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - const T *q_base = params.qkv; - const T *k_base = params.qkv + params.num_head * Dh; - const T *q_bias_base = params.qkv_bias; - const T *k_bias_base = params.qkv_bias + params.num_head * Dh; + // const T *q_base = params.qkv; + // const T *k_base = params.qkv + params.num_head * Dh; + T *q_bias_base = nullptr; + T *k_bias_base = nullptr; + + if (params.add_qkv_bias) { + q_bias_base = params.qkv_bias; + k_bias_base = params.qkv_bias + params.num_head * Dh; + } if (tid < QK_VECS_PER_WARP) { int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; @@ -804,92 +1497,134 @@ __global__ void masked_multihead_attention_kernel( Qk_vec q; zero(q); - q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_base[qk_offset]) - : q; + // q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_offset]) + // : q; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(q, qk_offset); + } + Qk_vec k; zero(k); - k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_base[qk_offset]) - : k; - - Qk_vec q_bias; - zero(q_bias); - q_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) - : q_bias; - Qk_vec k_bias; - zero(k_bias); - k_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) - : k_bias; - - q = add(q, q_bias); - // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 - // we may not require k_bias. - k = add(k, k_bias); - - // rotary pos emb - if (params.rotary_emb_dims != 0) { - int last_dim = Dh / params.rotary_emb_dims; - int half_lastdim = last_dim / 2; - int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; - const T *cos_base = params.rotary_emb; - const T *sin_base = params.rotary_emb + params.batch_size * Dh; - int stride = half_lastdim / QK_VEC_SIZE; - int stride_all_lastdim = 2 * stride; - int right_id = tid / stride_all_lastdim * stride_all_lastdim + - (tid + stride) % (stride_all_lastdim); - int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; - int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; - Qk_vec q_right; - zero(q_right); - q_right = - (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_base[qk_right_offset]) - : q_right; - Qk_vec k_right; - zero(k_right); - k_right = - (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_base[qk_right_offset]) - : k_right; - - Qk_vec q_right_bias; - zero(q_right_bias); - q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &q_bias_base[qk_right_bias_offset]) - : q_right_bias; - Qk_vec k_right_bias; - zero(k_right_bias); - k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &k_bias_base[qk_right_bias_offset]) - : k_right_bias; - - q_right = add(q_right, q_right_bias); - k_right = add(k_right, k_right_bias); - - Qk_vec cos_emb; - zero(cos_emb); - cos_emb = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&cos_base[rotary_offset]) - : cos_emb; + // k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? 
*reinterpret_cast(&k_base[qk_offset]) + // : k; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(k, params.num_head * Dh + qk_offset); + } - Qk_vec sin_emb; - zero(sin_emb); - sin_emb = + if (params.add_qkv_bias) { + Qk_vec q_bias; + zero(q_bias); + Qk_vec k_bias; + zero(k_bias); + + q_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) + : q_bias; + k_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&sin_base[rotary_offset]) - : sin_emb; - float alpha = (tid % stride_all_lastdim) < stride ? static_cast(-1) - : static_cast(1); - q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha); - k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha); + ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) + : k_bias; + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + } + + if (!params.neox_rotary_style) { + if (params.rotary_emb_dims != 0) { + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.rotary_bsz * Dh; + Qk_vec_RoPE cos_emb, sin_emb; + zero(cos_emb); + zero(sin_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + apply_rotary_embedding(q, k, cos_emb, sin_emb); + } + } else { + /* old rotary pos emb */ + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.rotary_bsz * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + // q_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_right_offset]) + // : q_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load(q_right, qk_right_offset); + } + Qk_vec k_right; + zero(k_right); + // k_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&k_base[qk_right_offset]) + // : k_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load( + k_right, params.num_head * Dh + qk_right_offset); + } + + if (params.add_qkv_bias) { + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + } + + Qk_vec_RoPE cos_emb; + zero(cos_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + + Qk_vec_RoPE sin_emb; + zero(sin_emb); + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? 
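// Note: both rotary branches above apply the same 2-D rotation per element
// pair; the neox-style path pairs element i with the element half the
// rotated width away (right_id) and folds the sign into alpha = -1 / +1,
// which is what the apply_rotary_emb helper evaluates as
// left * cos + alpha * right * sin. Scalar sketch of one rotated pair --
// illustrative only, not the vectorized kernel code:
inline void rope_rotate_pair(float &x0, float &x1, float cos_t, float sin_t) {
  const float r0 = x0 * cos_t - x1 * sin_t;  // alpha = -1 for the first half
  const float r1 = x1 * cos_t + x0 * sin_t;  // alpha = +1 for the second half
  x0 = r0;
  x1 = r1;
}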
*reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride + ? static_cast(-1) + : static_cast(1); + q = apply_rotary_emb( + q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb( + k, k_right, cos_emb, sin_emb, alpha); + } } *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; @@ -899,6 +1634,7 @@ __global__ void masked_multihead_attention_kernel( int offset = bhi * params.max_seq_length * Dh + co * params.max_seq_length * QK_ELTS_IN_16B + act_time_step * QK_ELTS_IN_16B + ci; + // quant k and store the int8 value into cache kv if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { *reinterpret_cast(¶ms.cache_kv[offset]) = k; } @@ -918,10 +1654,6 @@ __global__ void masked_multihead_attention_kernel( qk = block_sum(&red_smem[WARPS_PER_RED], qk); } if (tid == 0) { - // NOTE(wangxi): mask must be 0.0 - // T mask = params.attn_mask[ - // bi * (params.timestep + 1) + params.timestep]; - // qk += static_cast(mask); qk *= params.inv_sqrt_dh; qk_max = qk; qk_smem[act_time_step] = qk; @@ -929,12 +1661,12 @@ __global__ void masked_multihead_attention_kernel( __syncthreads(); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("=======q_out=======\n"); - for (int i = 0; i < Dh; ++i) printf("%f ", static_cast(q_smem[i])); - printf("\n"); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("=======q_out=======\n"); + // for (int i = 0; i < Dh; ++i) printf("%f ", + // static_cast(q_smem[i])); printf("\n"); + // } + // __syncthreads(); #endif using K_vec = typename K_vec_::Type; @@ -959,21 +1691,38 @@ __global__ void masked_multihead_attention_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; + const int *beam_offsets = params.beam_cache_offset + ? ¶ms.beam_cache_offset[bi_seq_len_offset] + : nullptr; for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * + params.max_seq_length * Dh + : 0; K_vec k[K_VECS_PER_THREAD]; K_vec k_vec_zero; zero(k_vec_zero); #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * params.max_seq_length + ti; + // get k from the cache_kv, and dequant k for qk operation if (ti < act_time_step) { - k[ii] = - (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) - ? *reinterpret_cast( - &k_cache[jj * QK_ELTS_IN_16B]) - : k_vec_zero; + if (beam_offset) { + k[ii] = + (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + ? *reinterpret_cast( + &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) + : k_vec_zero; + } else { + k[ii] = + (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + ? *reinterpret_cast( + &k_cache[jj * QK_ELTS_IN_16B]) + : k_vec_zero; + } } } @@ -984,8 +1733,11 @@ __global__ void masked_multihead_attention_kernel( // bool is_mask = false; if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; - qk += static_cast(mask); + auto mask_bhi = params.mask_broadcast_num_heads ? 
bi : bhi; + if (params.attn_mask) { + T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; + qk += static_cast(mask); + } qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; @@ -1015,12 +1767,12 @@ __global__ void masked_multihead_attention_kernel( qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("=======qk_out=======\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); - printf("qk_max=%f\n", qk_max); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("=======qk_out=======\n"); + // for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); + // printf("qk_max=%f\n", qk_max); + // } + // __syncthreads(); #endif float sum = 0.f; @@ -1036,6 +1788,7 @@ __global__ void masked_multihead_attention_kernel( // FIXME(wangxi): need add 1.e-6f? float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); } @@ -1047,9 +1800,12 @@ __global__ void masked_multihead_attention_kernel( int vo = tid / THREADS_PER_VALUE; int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; - T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * + T *v_cache = ¶ms.cache_kv[params.cache_batch_size * params.num_head * params.max_seq_length * Dh + bhi * params.max_seq_length * Dh + vi]; + T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bbhi * params.max_seq_length * Dh + vi]; #ifdef MMHA_USE_FP32_ACUM_FOR_OUT using V_vec_acum = typename V_vec_acum_fp32_::Type; @@ -1063,7 +1819,17 @@ __global__ void masked_multihead_attention_kernel( constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; if (Dh == Dh_MAX || vi < Dh) { for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { - V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); + const int beam_offset = + beam_offsets + ? 
beam_offsets[ti] * params.num_head * params.max_seq_length * Dh + : 0; + V_vec v; + if (beam_offset) { + v = *reinterpret_cast( + &v_cache_batch[beam_offset + ti * Dh]); + } else { + v = *reinterpret_cast(&v_cache[ti * Dh]); + } #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); @@ -1076,22 +1842,28 @@ __global__ void masked_multihead_attention_kernel( } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - if (bi == 0 && hi == 0 && tid == 0) { - printf("======logits_out=====\n"); - for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); - printf("\n"); - } - __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("======logits_out=====\n"); + // for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); + // printf("\n"); + // } + // __syncthreads(); #endif V_vec v_bias; zero(v_bias); if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { - V_vec v = *reinterpret_cast( - ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); - v_bias = *reinterpret_cast( - ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); - v = add(v, v_bias); + // V_vec v = *reinterpret_cast( + // ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); + V_vec v; + load_func.template load( + v, 2 * params.num_head * Dh + qkv_base_offset + vi); + if (params.add_qkv_bias) { + v_bias = *reinterpret_cast( + ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + v = add(v, v_bias); + } + *reinterpret_cast(&v_cache[act_time_step * Dh]) = v; #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) @@ -1129,21 +1901,28 @@ __global__ void masked_multihead_attention_kernel( if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { #ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), - out); + // convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + + // vi]), + // out); + V_vec tmp_out; + convert_from_float(tmp_out, out); + store_func.template store(tmp_out, + thi != -1 ? thi * Dh + vi : bhi * Dh + vi); #else - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; + // *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; + store_func.template store(out, + thi != -1 ? 
thi * Dh + vi : bhi * Dh + vi); #endif } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - __syncthreads(); - if (bi == 0 && hi == 0 && tid == 0) { - printf("======fmha_out=====\n"); - for (int i = 0; i < Dh; ++i) - printf("%f ", static_cast(params.out[i])); - printf("\n"); - } + // __syncthreads(); + // if (bi == 0 && hi == 0 && tid == 0) { + // printf("======fmha_out=====\n"); + // for (int i = 0; i < Dh; ++i) + // printf("%f ", static_cast(params.out[i])); + // printf("\n"); + // } #endif #else assert(false); @@ -1172,34 +1951,479 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } -#define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, - const cudaStream_t &stream) { +#define MMHA_LAUNCH_KERNEL_INT8(T, \ + Dh, \ + Dh_MAX, \ + THDS_PER_KEY, \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + stream, \ + load_func, \ + store_func, \ + cache_kv_I, \ + cache_k_quant_scale, \ + cache_v_quant_scale, \ + cache_k_dequant_scale, \ + cache_v_dequant_scale) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_head, params.batch_size); \ + constexpr auto kernel_fn = \ + masked_multihead_attention_kernel_int8; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + kernel_fn<<>>(params, \ + load_func, \ + store_func, \ + cache_kv_I, \ + cache_k_quant_scale, \ + cache_v_quant_scale, \ + cache_k_dequant_scale, \ + cache_v_dequant_scale); + +#define MMHA_LAUNCH_KERNEL(T, \ + Dh, \ + Dh_MAX, \ + THDS_PER_KEY, \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + stream, \ + load_func, \ + store_func) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_head, params.batch_size); \ + constexpr auto kernel_fn = \ + masked_multihead_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + kernel_fn<<>>( \ + params, load_func, store_func); + +template +void fmha_launch_kernel_impl_int8( + const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream, + LoadFunc load_func, + StoreFunc store_func, + uint8_t *cache_kv_I, + float cache_k_quant_scale, + float cache_v_quant_scale, + float cache_k_dequant_scale, + float cache_v_dequant_scale) { + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + if (params.timestep < 32) { + MMHA_LAUNCH_KERNEL_INT8(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else if (params.timestep < 2048) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + MMHA_LAUNCH_KERNEL_INT8(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + BlockSizeMiddle, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); +#else + MMHA_LAUNCH_KERNEL_INT8(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + BlockSizeMiddle, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + 
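// Note: 0xc000 bytes (48 KB) is the default cap on dynamic shared memory per
// launch; asking for more fails unless the kernel opts in first, which is
// why both MMHA_LAUNCH_KERNEL macros call cudaFuncSetAttribute before the
// <<<...>>> launch. Host-side shape of that opt-in, with my_kernel and
// smem_bytes as placeholder names (not from this patch):
#include <cuda_runtime.h>

__global__ void my_kernel(float *out) {
  extern __shared__ char smem_[];
  out[threadIdx.x] = static_cast<float>(smem_[0]);
}

void launch_with_large_smem(float *out, dim3 grid, dim3 block,
                            size_t smem_bytes, cudaStream_t stream) {
  if (smem_bytes > 0xc000) {  // above 48 KB: per-kernel opt-in required
    cudaFuncSetAttribute(my_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(smem_bytes));
  }
  my_kernel<<<grid, block, smem_bytes, stream>>>(out);
}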
cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); +#endif + } else { + MMHA_LAUNCH_KERNEL_INT8(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + BlockSizeMax, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } +} + +template +void fmha_launch_kernel_impl(const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream, + LoadFunc load_func, + StoreFunc store_func) { constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; if (params.timestep < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); + MMHA_LAUNCH_KERNEL( + T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream, load_func, store_func); } else if (params.timestep < 2048) { #if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ __CUDA_ARCH__ >= 750 - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func); #else - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 2, + THREADS_PER_VALUE, + 128, + stream, + load_func, + store_func); #endif } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 1, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func); + } +} + +template +void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream, + LoadFunc load_func, + StoreFunc store_func, + uint8_t *cache_kv_I, + float cache_k_quant_scale, + float cache_v_quant_scale, + float cache_k_dequant_scale, + float cache_v_dequant_scale) { + if (WITH_INT8) { + int dev = 0; + int sm_count = 0; + cudaGetDevice(&dev); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (params.num_head * params.batch_size <= sm_count) { + fmha_launch_kernel_impl_int8( + params, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else if (params.batch_size) { + fmha_launch_kernel_impl_int8( + params, + stream, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } + } else { + fmha_launch_kernel_impl( + params, stream, load_func, store_func); + } +} + +template +void fmha_impl(const phi::GPUContext &dev_ctx, + const Masked_multihead_attention_params ¶ms, + int dim_head, + LoadFunc load_func, + StoreFunc store_func, + uint8_t *cache_kv_I, + float cache_k_quant_scale, + float cache_v_quant_scale, + float cache_k_dequant_scale, + float cache_v_dequant_scale) { + switch (dim_head) { + case 10: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 26: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 32: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 64: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + 
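// Note: every launch path above fixes
//   THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16,
// so the per-thread value vector V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE
// collapses to 16 / sizeof(T): each value thread moves exactly one 16-byte
// (128-bit) vector of V per timestep, matching the 16-byte-chunked cache
// layout. Compile-time restatement of that identity -- illustrative only:
template <typename T, int Dh_MAX>
struct ValueVecSizeCheck {
  static constexpr int kThreadsPerValue =
      static_cast<int>(Dh_MAX * sizeof(T) / 16);
  static constexpr int kVecSize = Dh_MAX / kThreadsPerValue;
  static_assert(kVecSize == static_cast<int>(16 / sizeof(T)),
                "one 16-byte V vector per value thread per timestep");
};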
cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 96: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 128: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + case 192: + fmha_launch_kernel( + params, + dev_ctx.stream(), + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Dim_head = %d is unsupport!", dim_head)); + } +} + +template +void DispatchFMHA(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const Masked_multihead_attention_params ¶ms, + int num_head, + int dim_head, + phi::DenseTensor *out_tensor, + uint8_t *cache_kv_I, + float cache_k_quant_scale, + float cache_v_quant_scale, + float cache_k_dequant_scale, + float cache_v_dequant_scale) { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data()); + fmha_impl( + dev_ctx, + params, + dim_head, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); +} + +template +void DispatchFMHA(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &shift, + const phi::DenseTensor &smooth, + const Masked_multihead_attention_params ¶ms, + int num_head, + int dim_head, + phi::DenseTensor *out_tensor, + const phi::DenseTensor *dequant_qkv_scales = nullptr, + const float quant_fmha_out_scale = -1, + const int quant_round_type = 1, + const float quant_max_bound = 127.0f, + const float quant_min_bound = -127.0f, + uint8_t *cache_kv_I = nullptr, + float cache_k_quant_scale = -1.0f, + float cache_v_quant_scale = -1.0f, + float cache_k_dequant_scale = -1.0f, + float cache_v_dequant_scale = -1.0f) { + if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head, + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl( + dev_ctx, + params, + dim_head, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head, + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl( + dev_ctx, + params, + dim_head, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head); + fmha_impl( + dev_ctx, + params, + dim_head, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + 
cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head); + fmha_impl( + dev_ctx, + params, + dim_head, + load_func, + store_func, + cache_kv_I, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); } } @@ -1207,117 +2431,188 @@ template void fmha(const phi::GPUContext &dev_ctx, const phi::DenseTensor &qkv_tensor, const phi::DenseTensor &qkv_bias_tensor, - const phi::DenseTensor &src_mask_tensor, + const phi::DenseTensor *src_mask_tensor, + const phi::DenseTensor *cum_offsets_tensor, const phi::DenseTensor *sequence_lengths_tensor, const phi::DenseTensor *rotary_tensor, + const phi::DenseTensor *beam_cache_offset_tensor, phi::DenseTensor *cache_kv_tensor, phi::DenseTensor *out_tensor, int batch_size, + int cache_batch_size, + int seq_len, int max_seq_length, int num_head, int dim_head, int timestep, int rotary_emb_dims, - float inv_sqrt_dh) { + float inv_sqrt_dh, + const bool mask_broadcast_num_heads = true, + const bool add_qkv_bias = true, + const bool neox_rotary_style = false, + const phi::DenseTensor *dequant_qkv_scales = nullptr, + const phi::DenseTensor *shift = nullptr, + const phi::DenseTensor *smooth = nullptr, + const float cache_k_quant_scale = -1.0, + const float cache_v_quant_scale = -1.0, + const float cache_k_dequant_scale = -1.0, + const float cache_v_dequant_scale = -1.0, + const float quant_fmha_out_scale = -1, + const int quant_round_type = 1, + const float quant_max_bound = 127.0f, + const float quant_min_bound = -127.0f) { Masked_multihead_attention_params params; - params.out = out_tensor->data(); - params.qkv = qkv_tensor.data(); - params.qkv_bias = qkv_bias_tensor.data(); - params.attn_mask = src_mask_tensor.data(); + // params.out = out_tensor->data(); + // params.qkv = qkv_tensor.data(); + + if (add_qkv_bias) { + // Because we may not add qkv_bias, so here we cast to T*. + // Author(zhengzekang). 
+ params.qkv_bias = const_cast(qkv_bias_tensor.data()); + } + params.mask_broadcast_num_heads = mask_broadcast_num_heads; params.cache_kv = cache_kv_tensor->data(); + params.neox_rotary_style = neox_rotary_style; + if (src_mask_tensor) { + params.attn_mask = src_mask_tensor->data(); + params.mask_length = src_mask_tensor->dims()[3]; + } else { + params.attn_mask = nullptr; + params.mask_length = -1; + } + if (sequence_lengths_tensor) { params.sequence_lengths = sequence_lengths_tensor->data(); } + if (cum_offsets_tensor) { + params.cum_offsets = cum_offsets_tensor->data(); + } else { + params.cum_offsets = nullptr; + } + params.seq_len = seq_len; + if (rotary_emb_dims > 0) { - params.rotary_emb = rotary_tensor->data(); + params.rotary_emb = rotary_tensor->data(); + params.rotary_bsz = rotary_tensor->dims()[1]; } else { params.rotary_emb = nullptr; + params.rotary_bsz = 0; + } + + if (beam_cache_offset_tensor) { + if (cache_k_quant_scale > 0) { + PADDLE_THROW(phi::errors::Unimplemented( + "MMHA with int8 cache kv does not support beam search yet")); + } + params.beam_cache_offset = beam_cache_offset_tensor->data(); + params.beam_width = beam_cache_offset_tensor->dims()[1]; } + params.add_qkv_bias = add_qkv_bias; params.batch_size = batch_size; + params.cache_batch_size = cache_batch_size; params.num_head = num_head; params.timestep = timestep; params.max_seq_length = max_seq_length; params.inv_sqrt_dh = inv_sqrt_dh; params.rotary_emb_dims = rotary_emb_dims; - switch (dim_head) { - case 10: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 26: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 32: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 64: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 96: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 128: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - case 192: - fmha_launch_kernel(params, dev_ctx.stream()); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Dim_head = %d is unsupport!", dim_head)); + if (shift != nullptr) { + if (cache_k_quant_scale > 0) { + DispatchFMHA(dev_ctx, + qkv_tensor, + *shift, + *smooth, + params, + num_head, + dim_head, + out_tensor, + dequant_qkv_scales, + quant_fmha_out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + cache_kv_tensor->data(), + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else { + DispatchFMHA(dev_ctx, + qkv_tensor, + *shift, + *smooth, + params, + num_head, + dim_head, + out_tensor, + dequant_qkv_scales, + quant_fmha_out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + nullptr, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } + } else { + if (cache_k_quant_scale > 0) { + DispatchFMHA(dev_ctx, + qkv_tensor, + params, + num_head, + dim_head, + out_tensor, + cache_kv_tensor->data(), + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } else { + DispatchFMHA(dev_ctx, + qkv_tensor, + params, + num_head, + dim_head, + out_tensor, + nullptr, + cache_k_quant_scale, + cache_v_quant_scale, + cache_k_dequant_scale, + cache_v_dequant_scale); + } } } -template -void fmha(const phi::GPUContext &dev_ctx, - const phi::DenseTensor &qkv_tensor, - const phi::DenseTensor &qkv_bias_tensor, - const phi::DenseTensor &src_mask_tensor, - phi::DenseTensor *cache_kv_tensor, - phi::DenseTensor *out_tensor, - int batch_size, - int 
max_seq_length, - int num_head, - int dim_head, - int timestep, - float inv_sqrt_dh) { - fmha(dev_ctx, - qkv_tensor, - qkv_bias_tensor, - src_mask_tensor, - nullptr, - nullptr, - cache_kv_tensor, - out_tensor, - batch_size, - max_seq_length, - num_head, - dim_head, - timestep, - 0, - inv_sqrt_dh); -} - // NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 constexpr int VEC_16B = 16; template __global__ void write_cache_k_kernel(T *cache_k, const T *k, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, + const int prompt_num, const int max_seq_len) { const int bi = blockIdx.y; + const int seq_len_now = seq_len + prompt_num; + const int len = seq_lens ? seq_lens[bi] + prompt_num : seq_len_now; + if (len == 0) { + return; + } + const int hi = blockIdx.z; constexpr int X_ELEMS = VEC_16B / sizeof(T); // [bsz, num_head, seq_len, dim_head/x, x] auto k_src = reinterpret_cast( - k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + k + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head); // [bsz, num_head, dim_head/x, max_seq_len, x] auto k_dst = reinterpret_cast( cache_k + bi * num_head * max_seq_len * dim_head + @@ -1337,7 +2632,7 @@ __global__ void write_cache_k_kernel(T *cache_k, idx = idx / max_seq_len; const int k_vec_id = idx % dim_head_div_x; - if (k_seq_len_id < seq_len) { + if (k_seq_len_id < len) { k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id]; } } @@ -1345,16 +2640,24 @@ __global__ void write_cache_k_kernel(T *cache_k, template __global__ void write_cache_v_kernel(T *cache_v, const T *v, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, + const int prompt_num, const int max_seq_len) { const int bi = blockIdx.y; + const int seq_len_now = seq_len + prompt_num; + const int len = seq_lens ? seq_lens[bi] + prompt_num : seq_len_now; + if (len == 0) { + return; + } + const int hi = blockIdx.z; // [bsz, num_head, seq_len, dim_head/x, x] auto v_src = reinterpret_cast( - v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + v + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head); // [bsz, num_head, max_seq_len, dim_head/x, x] auto v_dst = reinterpret_cast( cache_v + bi * num_head * max_seq_len * dim_head + @@ -1364,20 +2667,321 @@ __global__ void write_cache_v_kernel(T *cache_v, constexpr int X_ELEMS = VEC_16B / sizeof(T); const int dim_head_div_x = dim_head / X_ELEMS; - if (idx >= dim_head_div_x * seq_len) return; + if (idx >= dim_head_div_x * len) return; v_dst[idx] = v_src[idx]; } +template +__forceinline__ __device__ void VectorizedQuant(const T *in, + const float scale, + const int round_type, + const float max_bound, + const float min_bound, + uint8_t *quant_out) { + phi::AlignedVector in_vec{}; + phi::AlignedVector quant_out_vec{}; + phi::Load(&in[0], &in_vec); + +#pragma unroll + for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + float quant_value = scale * static_cast(in_vec[unroll_idx]); + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? 
min_bound : quant_value; + // TODO(Zhengzekang): if use int4 CacheKV, we may pass a int template named + // Quantbit, and 128.0f -> (1 << Quantbit) + quant_out_vec[unroll_idx] = static_cast(quant_value + 128.0f); + } + phi::Store(quant_out_vec, &quant_out[0]); +} + +template +__global__ void write_cache_k_int8_kernel(uint8_t *cache_k, + const T *k, + const int *seq_lens, + const int num_head, + const int dim_head, + const int seq_len, + const int prompt_num, + const int max_seq_len, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + const int bi = blockIdx.y; + const int seq_len_now = seq_len + prompt_num; + const int len = seq_lens ? seq_lens[bi] + prompt_num : seq_len_now; + if (len == 0) { + return; + } + + const int hi = blockIdx.z; + constexpr int X_ELEMS = VEC_16B / sizeof(T); + using Packed_Int8_t = typename packed_type::type; + + // [bsz, num_head, seq_len, dim_head/x, x] + auto k_src = reinterpret_cast( + k + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head); + // [bsz, num_head, dim_head/x, max_seq_len, x] + auto k_dst = reinterpret_cast( + cache_k + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + // vec size + int dim_head_div_x = dim_head / X_ELEMS; + + // FIXME(wangxi): num_head is not need? + // if (out_idx >= num_head * dim_head_div_x * max_seq_len) return; + if (out_idx >= dim_head_div_x * max_seq_len) return; + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + // idx = (idx - k_seq_len_id) / max_seq_len; + idx = idx / max_seq_len; + const int k_vec_id = idx % dim_head_div_x; + + if (k_seq_len_id < len) { + VectorizedQuant( + reinterpret_cast(k_src + + (k_seq_len_id * dim_head_div_x + k_vec_id)), + scale, + round_type, + max_bound, + min_bound, + reinterpret_cast(k_dst + out_idx)); + // k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id]; + } +} + +template +__global__ void write_cache_v_int8_kernel(uint8_t *cache_v, + const T *v, + const int *seq_lens, + const int num_head, + const int dim_head, + const int seq_len, + const int prompt_num, + const int max_seq_len, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + const int bi = blockIdx.y; + const int seq_len_now = seq_len + prompt_num; + const int len = seq_lens ? 
seq_lens[bi] + prompt_num : seq_len_now; + if (len == 0) { + return; + } + + const int hi = blockIdx.z; + + constexpr int X_ELEMS = VEC_16B / sizeof(T); + // [bsz, num_head, seq_len, dim_head/x, x] + using Packed_Int8_t = typename packed_type::type; + + auto v_src = reinterpret_cast( + v + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head); + // [bsz, num_head, max_seq_len, dim_head/x, x] + auto v_dst = reinterpret_cast( + cache_v + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + const int dim_head_div_x = dim_head / X_ELEMS; + + if (idx >= dim_head_div_x * len) return; + + VectorizedQuant(reinterpret_cast(v_src + idx), + scale, + round_type, + max_bound, + min_bound, + reinterpret_cast(v_dst + idx)); + // v_dst[idx] = v_src[idx]; +} + +template +void write_int8_cache_kv(const phi::GPUContext &dev_ctx, + uint8_t *cache_k, + uint8_t *cache_v, + const T *k, + const T *v, + const int *seq_lens, + const int bsz, + const int num_head, + const int seq_len, + const int prompt_num, + const int max_seq_len, + const int dim_head, + const int round_type, + const float max_bound, + const float min_bound, + const float cache_k_scale = -1.0, + const float cache_v_scale = -1.0) { + constexpr int block_sz = 128; + constexpr int x = VEC_16B / sizeof(T); + + assert(dim_head % x == 0); + PADDLE_ENFORCE_EQ( + dim_head % x, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); + + int max_size = max_seq_len * dim_head / x; + int size = (seq_len + prompt_num) * dim_head / x; + dim3 grid(div_up(max_size, block_sz), bsz, num_head); + dim3 grid_v(div_up(size, block_sz), bsz, num_head); + + // transpose [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, dim_head/x, max_seq_len, x] + int gk = bsz * num_head * max_seq_len; + write_cache_k_int8_kernel<<>>( + cache_k, + k, + seq_lens, + num_head, + dim_head, + seq_len, + prompt_num, + max_seq_len, + cache_k_scale, + round_type, + max_bound, + min_bound); + + // copy [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, max_seq_len, dim_head/x, x] + int gv = bsz * num_head * max_seq_len; + write_cache_v_int8_kernel<<>>( + cache_v, + v, + seq_lens, + num_head, + dim_head, + seq_len, + prompt_num, + max_seq_len, + cache_v_scale, + round_type, + max_bound, + min_bound); +} + +template +void write_int8_cache_kv(const phi::GPUContext &dev_ctx, + uint8_t *cache_k, + uint8_t *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head, + const int round_type, + const float max_bound, + const float min_bound, + const float cache_k_scale = -1.0, + const float cache_v_scale = -1.0) { + write_int8_cache_kv(dev_ctx, + cache_k, + cache_v, + k, + v, + nullptr, + bsz, + num_head, + seq_len, + 0, /*prompt_num*/ + max_seq_len, + dim_head, + round_type, + max_bound, + min_bound, + cache_k_scale, + cache_v_scale); +} + +template +void WriteInt8CacheKV(const phi::GPUContext &dev_ctx, + const phi::DenseTensor *pre_cache_kv_out, + phi::DenseTensor *cache_kv_out, + const phi::DenseTensor *kv_transpose_out, + const int *sequence_lengths_data, + const int cache_bsz, + const int bsz, + const int num_head, + const int seq_len, + const int dim_head, + const int cache_offset, + const int round_type, + const float max_bound, + const float min_bound, + const float cache_k_scale = -1.0, + const float cache_v_scale = -1.0) { + const T 
*k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out->data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + k_ptr = kv_transpose_out->data(); + v_ptr = k_ptr + k_size; + } + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + uint8_t *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = cache_bsz * num_head * max_seq_len * dim_head; + + uint8_t *cache_k_ptr = cache_kv_data; + uint8_t *cache_v_ptr = cache_kv_data + cache_k_size; + + // const int seq_len_tmp = seq_len + cache_offset; + write_int8_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + sequence_lengths_data, + bsz, + num_head, + seq_len, + cache_offset, // prompt_num + max_seq_len, + dim_head, + round_type, + max_bound, + min_bound, + cache_k_scale, + cache_v_scale); +} + template void write_cache_kv(const phi::GPUContext &dev_ctx, T *cache_k, T *cache_v, const T *k, const T *v, + const int *seq_lens, const int bsz, const int num_head, const int seq_len, + const int prompt_num, const int max_seq_len, const int dim_head) { constexpr int block_sz = 128; @@ -1391,19 +2995,211 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); int max_size = max_seq_len * dim_head / x; - int size = seq_len * dim_head / x; + int size = (seq_len + prompt_num) * dim_head / x; dim3 grid(div_up(max_size, block_sz), bsz, num_head); dim3 grid_v(div_up(size, block_sz), bsz, num_head); // transpose [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, dim_head/x, max_seq_len, x] - write_cache_k_kernel<<>>( - cache_k, k, num_head, dim_head, seq_len, max_seq_len); + write_cache_k_kernel<<>>(cache_k, + k, + seq_lens, + num_head, + dim_head, + seq_len, + prompt_num, + max_seq_len); // copy [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, max_seq_len, dim_head/x, x] - write_cache_v_kernel<<>>( - cache_v, v, num_head, dim_head, seq_len, max_seq_len); + write_cache_v_kernel<<>>(cache_v, + v, + seq_lens, + num_head, + dim_head, + seq_len, + prompt_num, + max_seq_len); +} + +template +void write_cache_kv(const phi::GPUContext &dev_ctx, + T *cache_k, + T *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head) { + write_cache_kv(dev_ctx, + cache_k, + cache_v, + k, + v, + nullptr, + bsz, + num_head, + seq_len, + 0, + max_seq_len, + dim_head); +} + +template +void WriteCacheKV(const phi::GPUContext &dev_ctx, + const phi::DenseTensor *pre_cache_kv_out, + phi::DenseTensor *cache_kv_out, + const phi::DenseTensor *kv_transpose_out, + const int *sequence_lengths_data, + const int cache_bsz, + const int bsz, + const int num_head, + const int seq_len, + const int dim_head, + const int cache_offset) { + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out->data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + k_ptr = kv_transpose_out->data(); 
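+ // The v slice starts k_size elements after k_ptr, i.e. the k and v slices sit back-to-back.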
+ v_ptr = k_ptr + k_size; + } + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = cache_bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + sequence_lengths_data, + bsz, + num_head, + seq_len, + cache_offset, + max_seq_len, + dim_head); +} + +template +__global__ void fusedQKV_transpose_split_kernel(T *q_buf, + T *kv_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int max_len_this_time, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = + batch_size * max_len_this_time * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + const int32_t seq_id = ori_token_idx % seq_len; + + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { + phi::Store( + src_vec, + &q_buf[target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id]); + } else { + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Store( + src_vec, + &kv_buf[kv_store_offset + + target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id]); + } + } +} + +template +__global__ void fusedQKV_transpose_split_kernel(T *q_buf, + T *k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + const int32_t write_idx = + token_idx * hidden_size + head_id * size_per_head + size_id; + if (qkv_id == 0) { + phi::Store(src_vec, &q_buf[write_idx]); + } else if (qkv_id == 1) { + phi::Store(src_vec, &k_buf[write_idx]); + } else { + phi::Store(src_vec, &v_buf[write_idx]); + } + } } template @@ -1487,6 +3283,86 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { return cudaSuccess; } +template +void qkv_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *kv_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int max_len_this_time, + const int seq_len, + const int size_per_head) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + fusedQKV_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + max_len_this_time, + seq_len, + token_num, + head_num, + size_per_head); +} + +template +void qkv_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *k_buf, + T *v_buf, + const T *qkv, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + fusedQKV_transpose_split_kernel + <<>>(q_buf, + k_buf, + v_buf, + qkv, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); +} + template void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, T *q_buf, @@ -1541,10 +3417,52 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, } } +/* old rope emb */ +template +__global__ void NeoXRotaryKernel(const T *input, + const float *cos_emb, + const float *sin_emb, + const int *sequence_lengths, + T *output, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return; + int half_lastdim = last_dim / 2; + for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * seq_len * last_dim + + hi * seq_len * last_dim + si * last_dim; + int left_idx = base_idx + ti; + const int right_idx = base_idx + ti + half_lastdim; + int emb_idx_left = bi * seq_len * last_dim + si * last_dim + ti; + 
int emb_idx_right = + bi * seq_len * last_dim + si * last_dim + ti + half_lastdim; + float input_left = static_cast(input[left_idx]); + float input_right = static_cast(input[right_idx]); + + float cos_tmp_left = cos_emb[emb_idx_left]; + float sin_tmp_left = sin_emb[emb_idx_left]; + float cos_tmp_right = cos_emb[emb_idx_right]; + float sin_tmp_right = sin_emb[emb_idx_right]; + + T res1 = + static_cast(input_left * cos_tmp_left - input_right * sin_tmp_left); + T res2 = static_cast(input_right * cos_tmp_right + + input_left * sin_tmp_right); + output[left_idx] = res1; + output[right_idx] = res2; + } +} + template -__global__ void RotrayKernel(const T *input, - const T *cos_emb, - const T *sin_emb, +__global__ void RotaryKernel(const T *input, + const float *cos_emb, + const float *sin_emb, const int *sequence_lengths, T *output, const int rotary_emb_dims, @@ -1562,15 +3480,51 @@ __global__ void RotrayKernel(const T *input, for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { int base_idx = bi * head_num * seq_len * last_dim + hi * seq_len * last_dim + si * last_dim; - int left_idx = base_idx + ti; - const int right_idx = base_idx + ti + half_lastdim; - int emb_idx = bi * seq_len * last_dim + si * last_dim + ti; - T input_left = input[left_idx]; - T input_right = input[right_idx]; - T cos_tmp = cos_emb[emb_idx]; - T sin_tmp = sin_emb[emb_idx]; - T res1 = input_left * cos_tmp - input_right * sin_tmp; - T res2 = input_right * cos_tmp + input_left * sin_tmp; + int left_idx = base_idx + 2 * ti; + const int right_idx = base_idx + 2 * ti + 1; + int emb_idx = bi * seq_len * last_dim + si * last_dim + 2 * ti; + float input_left = static_cast(input[left_idx]); + float input_right = static_cast(input[right_idx]); + float cos_tmp = cos_emb[emb_idx]; + float sin_tmp = sin_emb[emb_idx]; + T res1 = static_cast(input_left * cos_tmp - input_right * sin_tmp); + T res2 = static_cast(input_right * cos_tmp + input_left * sin_tmp); + output[left_idx] = res1; + output[right_idx] = res2; + } +} + +template +__global__ void RotaryKernel(const T *input, + const float *cos_emb, + const float *sin_emb, + const int *sequence_lengths, + T *output, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int max_len_this_time, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return; + int half_lastdim = last_dim / 2; + // Note(ZhenyuLi): Calculate the relevant data at one time, so that no + // additional space is required. 
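+ // Each thread rotates adjacent (even, odd) element pairs in place:
+ //   left'  = left  * cos - right * sin
+ //   right' = right * cos + left  * sin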
+ for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * max_len_this_time * last_dim + + hi * max_len_this_time * last_dim + si * last_dim; + int left_idx = base_idx + 2 * ti; + const int right_idx = base_idx + 2 * ti + 1; + int emb_idx = bi * seq_len * last_dim + si * last_dim + 2 * ti; + float input_left = static_cast(input[left_idx]); + float input_right = static_cast(input[right_idx]); + float cos_tmp = cos_emb[emb_idx]; + float sin_tmp = sin_emb[emb_idx]; + T res1 = static_cast(input_left * cos_tmp - input_right * sin_tmp); + T res2 = static_cast(input_right * cos_tmp + input_left * sin_tmp); output[left_idx] = res1; output[right_idx] = res2; } @@ -1582,20 +3536,22 @@ void rotary_qk(const phi::GPUContext &dev_ctx, T *k, // kv const T *q_input, // q const T *k_input, // kv - const T *rotary_emb, + const float *rotary_emb, const int *sequence_lengths, const int rotary_emb_dims, + const int rope_bsz, const int batch_size, const int head_num, + const int max_len_this_time, const int seq_len, const int dim_head) { - // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, - // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] - // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, - // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs, - // 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head / - // rotary_emb_dims] - dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + // q_transpose_out_data [bs, head_num, max_len_this_time, dim_head] -> [bs, + // head_num, max_len_this_time * rotary_emb_dims, dim_head / rotary_emb_dims] + // kv_transpose_out_data [bs, head_num, max_len_this_time, dim_head] -> [bs, + // head_num, max_len_this_time * rotary_emb_dims, dim_head / rotary_emb_dims] + // rotary_emb [2, bs, 1, seq_len, dim_head] -> [2, bs, 1, seq_len * + // rotary_emb_dims, dim_head / rotary_emb_dims] + dim3 grid(batch_size, head_num, max_len_this_time * rotary_emb_dims); const int last_dim = dim_head / rotary_emb_dims; auto getBlockSize = [](int dim) { if (dim > 256) { @@ -1611,9 +3567,9 @@ void rotary_qk(const phi::GPUContext &dev_ctx, } }; int BlockSize = getBlockSize(last_dim / 2); - const T *cos_emb = rotary_emb; - const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head; - RotrayKernel<<>>( + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + rope_bsz * seq_len * dim_head; + RotaryKernel<<>>( q_input, cos_emb, sin_emb, @@ -1622,9 +3578,10 @@ void rotary_qk(const phi::GPUContext &dev_ctx, rotary_emb_dims, batch_size, head_num, + max_len_this_time * rotary_emb_dims, seq_len * rotary_emb_dims, last_dim); - RotrayKernel<<>>( + RotaryKernel<<>>( k_input, cos_emb, sin_emb, @@ -1633,12 +3590,102 @@ void rotary_qk(const phi::GPUContext &dev_ctx, rotary_emb_dims, batch_size, head_num, + max_len_this_time * rotary_emb_dims, seq_len * rotary_emb_dims, last_dim); } +template +void rotary_qk(const phi::GPUContext &dev_ctx, + T *q, + T *k, // kv + const T *q_input, // q + const T *k_input, // kv + const float *rotary_emb, + const int *sequence_lengths, + const int rotary_emb_dims, + const int rope_bsz, + const int batch_size, + const int head_num, + const int seq_len, + const int dim_head, + const bool neox_rotary_style) { + // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] + // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, 
head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs, + // 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head / + // rotary_emb_dims] + dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + const int last_dim = dim_head / rotary_emb_dims; + auto getBlockSize = [](int dim) { + if (dim > 256) { + return 512; + } else if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } + }; + int BlockSize = getBlockSize(last_dim / 2); + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + batch_size * seq_len * dim_head; + if (!neox_rotary_style) { + RotaryKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + RotaryKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + } else { + NeoXRotaryKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + NeoXRotaryKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + } +} + __global__ void GetPaddingOffset(int *d_token_num, int *padding_offset, + int *cu_seqlens_data, const int *sequence_lengths, const int batch_size, const int max_seq_len) { @@ -1646,6 +3693,7 @@ __global__ void GetPaddingOffset(int *d_token_num, int total_seq_len = 0; int cum_offset = 0; int index = 0; + cu_seqlens_data[0] = 0; for (int i = 0; i < batch_size; i++) { const int seq_len = sequence_lengths[i]; for (int j = 0; j < seq_len; j++) { @@ -1654,6 +3702,7 @@ __global__ void GetPaddingOffset(int *d_token_num, } cum_offset += max_seq_len - seq_len; total_seq_len += seq_len; + cu_seqlens_data[i + 1] = cu_seqlens_data[i] + seq_len; } d_token_num[0] = total_seq_len; } @@ -1662,11 +3711,16 @@ void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, int *h_token_num, int *d_token_num, int *padding_offset, + int *cu_seqlens_data, const int *sequence_lengths, const int batch_size, const int max_seq_len) { - GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>( - d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len); + GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(d_token_num, + padding_offset, + cu_seqlens_data, + sequence_lengths, + batch_size, + max_seq_len); memory::Copy(platform::CPUPlace(), h_token_num, dev_ctx.GetPlace(), @@ -1675,6 +3729,221 @@ void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, dev_ctx.stream()); } +template +__global__ void RebuildPadding(T *output_data, + const T *input_data, + const int *cum_offsets, + const int *seq_len_decoder, + const int *seq_len_encoder, + const int seq_len, + const int dim_embed, + const int elem_nums) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int global_idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int i = global_idx * VecSize; i < elem_nums; + i += gridDim.x * blockDim.x * VecSize) { + const int bi = i / dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; + if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue; + // just encoder or stop, get last token; just decoder, get first token. 
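+ // Encoder-only rows copy their last token (seq_id = len - 1); decoder rows keep
+ // seq_id = 0, which is their single current token.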
+ if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] != 0) + seq_id = seq_len_encoder[bi] - 1; + const int ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + phi::Load(&input_data[src_offset], &src_vec); + phi::Store(src_vec, &output_data[i]); + } +} + +template +void InvokeRebuildPadding(const phi::GPUContext &dev_ctx, + T *output_data, + const T *input_data, + const int *cum_offsets, + const int *seq_len_decoder, + const int *seq_len_encoder, + const int seq_len, + const int token_num, + const int dim_embed, + const int64_t elem_nums) { + // src: [token_num, dim_embed] + // dst: [batch_size, 1, dim_embed] + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(dim_embed % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_embed=%d must be divisible by vec_size=%d", + dim_embed, + PackSize)); + int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + RebuildPadding + <<>>(output_data, + input_data, + cum_offsets, + seq_len_decoder, + seq_len_encoder, + seq_len, + dim_embed, + elem_nums); +} + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return max(a, b); + } +}; + +template +__global__ void GetMaxLenKernel(const int *seq_lens, + int *max_len, + const int batch_size) { + const int tid = threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int max_len_this_thread = 0; + for (int i = tid; i < batch_size; i += blockDim.x) { + max_len_this_thread = max(seq_lens[i], max_len_this_thread); + } + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + if (tid == 0) { + *max_len = total; + } +} + +int GetMaxLen(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &seq_lens_tensor, + phi::DenseTensor *max_len_tensor, + const int batch_size) { + constexpr int blockSize = 128; + int max_len_cpu = 0; + GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( + seq_lens_tensor.data(), max_len_tensor->data(), batch_size); + memory::Copy(platform::CPUPlace(), + &max_len_cpu, + dev_ctx.GetPlace(), + max_len_tensor->data(), + sizeof(int), + dev_ctx.stream()); + return max_len_cpu; +} + +template +__global__ void GetDecoderTensorKernel(const T *qkv_out, + const int *cum_offsets, + T *qkv_out_decoder, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int dim_head, + const int elem_nums) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int32_t hidden_size = head_num * dim_head; + const int32_t fused_hidden_size = 3 * hidden_size; + const int global_idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int i = global_idx * VecSize; i < elem_nums; + i += gridDim.x * blockDim.x * VecSize) { + const int bi = i / fused_hidden_size; + const int bias_idx = i % fused_hidden_size; + const int ori_token_idx = bi * seq_len - cum_offsets[bi]; + const int qkv_id = bias_idx / hidden_size; + const int head_id = (i % hidden_size) / dim_head; + const int size_id = i % dim_head; + const int src_offset = ori_token_idx * fused_hidden_size + + qkv_id * hidden_size + head_id * dim_head + size_id; + phi::Load(&qkv_out[src_offset], &src_vec); + phi::Store(src_vec, &qkv_out_decoder[i]); + } +} + +template +__global__ void GetDecoderRoPEKernel(const T *rope_emb, + T *rope_out_emb, + const int rope_bsz, + const int batch_size, + const int seq_len, + const int 
dim_head, + const int elem_nums) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const T *rope_cos_emb = rope_emb; + const T *rope_sin_emb = rope_emb + rope_bsz * seq_len * dim_head; + T *cos_emb = rope_out_emb; + T *sin_emb = rope_out_emb + batch_size * dim_head; + const int global_idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int i = global_idx * VecSize; i < elem_nums; + i += gridDim.x * blockDim.x * VecSize) { + const int bi = i / dim_head; + const int src_offset = bi * seq_len * dim_head + i % dim_head; + phi::Load(&rope_cos_emb[src_offset], &src_vec); + phi::Store(src_vec, &cos_emb[i]); + phi::Load(&rope_sin_emb[src_offset], &src_vec); + phi::Store(src_vec, &sin_emb[i]); + } +} + +template +void GetDecoderTensor(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_out, + const phi::DenseTensor *rope_emb, + const int *cum_offsets, + phi::DenseTensor *qkv_out_decoder, + phi::DenseTensor *rope_out_emb, + const int token_num, + const int batch_size, + const int num_head, + const int seq_len, + const int dim_head) { + // qkv_out: [token_num, 3, num_head, dim_head] -> [bs, 1, 3, num_head, + // dim_head] rope: [2, bsz, 1, seq_len, dim_head] -> [2, bsz, 1, 1, dim_head] + int elem_nums = qkv_out_decoder->numel(); + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ( + dim_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", dim_head, PackSize)); + int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + GetDecoderTensorKernel + <<>>( + qkv_out.data(), + cum_offsets, + qkv_out_decoder->data(), + token_num, + batch_size, + num_head, + seq_len, + dim_head, + elem_nums); + if (rope_out_emb) { + elem_nums = rope_out_emb->numel() / 2; + pack_num = elem_nums / PackSize; + GetNumBlocks(pack_num, &grid_size); + GetDecoderRoPEKernel + <<>>( + rope_emb->data(), + rope_out_emb->data(), + rope_emb->dims()[1], + batch_size, + seq_len, + dim_head, + elem_nums); + } +} + template __global__ void RemovePadding(T *output_data, const T *input_data, @@ -1731,264 +4000,359 @@ void InvokeRebuildPadding(const phi::GPUContext &dev_ctx, output_data, input_data, padding_offset, dim_embed); } -#if CUDA_VERSION >= 11060 -// Only Used in Inference -template -class CublasFusedMLP { - public: - // (m, n, k) = bsz_seq, hidden_feature, in_feature - explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) { - cudaDataType_t mat_type = CUDA_R_32F; - cudaDataType_t scale_type = CUDA_R_32F; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - if (FLAGS_gemm_use_half_precision_compute_type) { - // This option default value is true, it tends to result NaN, but get - // better inference speed. you can turn off by using `export - // FLAGS_gemm_use_half_precision_compute_type=0`. 
- compute_type = CUBLAS_COMPUTE_16F; - scale_type = CUDA_R_16F; - } - } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; - } - if (std::is_same::value) { - mat_type = CUDA_R_64F; - scale_type = CUDA_R_64F; - compute_type = CUBLAS_COMPUTE_64F; +template +__global__ void InitOutValueKernel(T *output_data, + const int64_t numel, + const T init_value) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + int64_t global_thread_idx = bid * blockDim.x + tid; + + for (int linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < numel; + linear_index += step) { + for (int i = 0; i < VecSize; i++) { + output_data[linear_index + i] = init_value; } + } +} + +template +void InitValue(const phi::GPUContext &dev_ctx, + T *output_data, + const int64_t numel, + const T init_value) { + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ( + numel % PackSize, + 0, + platform::errors::PreconditionNotMet( + "numel=%d must be divisible by vec_size=%d", numel, PackSize)); + const int pack_num = numel / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + InitOutValueKernel + <<>>( + output_data, numel, init_value); +} - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( - &operation_desc_, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc_, mat_type, 1, 1, 1)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &w_desc_, mat_type, 1, 1, 1)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &out_desc_, mat_type, 1, 1, 1)); - } - ~CublasFusedMLP() { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(operation_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(x_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(w_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_)); - } - - void Setup(const phi::DDim &x_shape, - const phi::DDim &w_shape, - bool trans_x, - bool trans_w) { - int64_t M = trans_x ? x_shape[1] : x_shape[0]; - int64_t K = trans_w ? w_shape[1] : w_shape[0]; - int64_t N = trans_w ? w_shape[0] : w_shape[1]; - - cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublas_transB = trans_w ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_transA, - sizeof(cublas_transA))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_transB, - sizeof(cublas_transB))); - - SetCublasMatrixLayout(x_desc_, trans_x, M, K); - SetCublasMatrixLayout(w_desc_, trans_w, K, N); - SetCublasMatrixLayout(out_desc_, false, M, N); - } - - void ComputeForward(const phi::DenseTensor *x, - const phi::DenseTensor *weight, - const phi::DenseTensor *bias, - phi::DenseTensor *residual, - phi::DenseTensor *output, - const std::string &activation) { - T *out_data = output->data(); - - const bool add_residual = (residual == nullptr) ? false : true; - const bool add_bias = (bias == nullptr) ? 
false : true; - - const T *bias_data = nullptr; - if (add_bias) { - bias_data = bias->data(); +template +__global__ void ActFFNGlu(const T *bias, + Functor act_functor, + const int token_num, + const int hid_dim, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec1; + LoadT src_vec2; + LoadT bias_vec1; + LoadT bias_vec2; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int bi = i / hid_dim; + int idx = i % hid_dim; + // const T *input_this_thread = input + bi * hid_dim * 2; + // T *output_this_thread = output + bi * hid_dim; + // phi::Load(&input_this_thread[idx], &src_vec1); + // phi::Load(&input_this_thread[idx + hid_dim], &src_vec2); + + load_func.template load(&src_vec1, bi * hid_dim * 2 + idx); + load_func.template load(&src_vec2, + bi * hid_dim * 2 + idx + hid_dim); + + if (bias) { + phi::Load(&bias[idx], &bias_vec1); + phi::Load(&bias[idx + hid_dim], &bias_vec2); } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); - - cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, add_bias); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func, - sizeof(epiloque_func))); - - T *residual_data = add_residual ? residual->data() : out_data; - - cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle(); - size_t workspace_size = static_cast(4) * 1024 * 1024; - cudaStream_t stream = dev_ctx_.stream(); - memory::allocation::AllocationPtr workspace = memory::Alloc( - dev_ctx_.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(dev_ctx_.stream()))); - - // if add_residual, we compute result + 1.0 * residual, - // else result + 0.0 * out. - double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0; - float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f; - half alpha16 = static_cast(1.0), - beta16 = - add_residual ? 
static_cast(1.0) : static_cast(0.0); - - void *alpha = &alpha32, *beta = &beta32; - if (std::is_same::value) { - alpha = &alpha64; - beta = &beta64; +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec1[j] += bias_vec1[j]; + src_vec2[j] += bias_vec2[j]; + } + src_vec1[j] = act_functor(src_vec1[j]); + src_vec1[j] *= src_vec2[j]; } + // phi::Store(src_vec1, &output_this_thread[idx]); + store_func.template store(src_vec1, bi * hid_dim + idx); + } +} - if (std::is_same::value && - FLAGS_gemm_use_half_precision_compute_type) { - alpha = &alpha16; - beta = &beta16; - } +template +void LaunchActFFNGlu(const phi::GPUContext &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + ActFFNGlu + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + ActFFNGlu<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; + } +} - const auto *x_data = x->data(); - const auto *w_data = weight->data(); - - auto algo = phi::funcs::GemmEpilogueAlgoCache::Instance().GetGemmAlgo( - lt_handle, - operation_desc_, - w_desc_, - x_desc_, - out_desc_, - alpha, - beta, - w_data, - x_data, - out_data, - stream, - workspace->ptr(), - workspace_size); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - operation_desc_, - alpha, - w_data, - w_desc_, - x_data, - x_desc_, - beta, - residual_data, - out_desc_, - out_data, - out_desc_, - algo, - workspace->ptr(), - workspace_size, - stream)); - } - - private: - cublasLtEpilogue_t GetEpilogueType(const std::string &activation, - const bool add_bias) { - if (activation == "relu") { - if (add_bias) { - return CUBLASLT_EPILOGUE_RELU_BIAS; - } else { - return CUBLASLT_EPILOGUE_RELU; - } - } else if (activation == "gelu") { - if (add_bias) { - return CUBLASLT_EPILOGUE_GELU_BIAS; - } else { - return CUBLASLT_EPILOGUE_GELU; - } - } else if (activation == "none") { - if (add_bias) { - return CUBLASLT_EPILOGUE_BIAS; - } else { - return CUBLASLT_EPILOGUE_DEFAULT; +template +__global__ void BiasAct(const T *bias, + Functor act_functor, + const int rows, + const int cols, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + +// Zero Initialize BiasVec. +#pragma unroll + for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + bias_vec[unroll_idx] = 0; + } + + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int row_idx = i / cols; + int col_idx = i % cols; + int linear_idx = row_idx * cols + col_idx; + // phi::Load(&input[linear_idx], &src_vec); + load_func.template load(&src_vec, linear_idx); + if (bias) { + phi::Load(&bias[col_idx], &bias_vec); + } +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec[j] += bias_vec[j]; } - } else { - PADDLE_ENFORCE_EQ( - true, - false, - platform::errors::InvalidArgument( - "The activation attribute of fused_gemm_epilogue op should be" - " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
- "But received activation=%s.", - activation)); + src_vec[j] = act_functor(src_vec[j]); } + // phi::Store(src_vec, &output[linear_idx]); + store_func.template store(src_vec, linear_idx); } +} - void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc, - const bool transpose, - const uint64_t cublas_row, - const uint64_t cublas_col) { - cudaDataType_t mat_type = CUDA_R_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; +template +void LaunchBiasAct(const phi::GPUContext &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + BiasAct + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + BiasAct<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; + } +} + +template +__global__ void fused_transpose_split_kernel( + T *q_out, // [total, num_head, head_dim] + T *k_out, // [total, num_head, head_dim] + T *v_out, // [total, num_head, head_dim] + const T *q_input, // [bsz, num_head, seq_len, head_dim] + const T *kv_input, // [2, bsz, num_head, seq_len, head_dim] + const int *padding_offset, + const int *seq_lens, + const int32_t elem_cnt, + const int batch_size, + const int max_len_this_time, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = + batch_size * max_len_this_time * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + + int q_size = token_num * hidden_size; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + int32_t bias_idx = linear_index % fused_hidden_size; + int32_t current_token = linear_index / fused_hidden_size; + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 
0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + if (seq_lens[target_batch_id] == 0) continue; + const int32_t seq_id = ori_token_idx % seq_len; + + // equal to: + // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size; + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { // read q + phi::Load( + &q_input[target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id], + &src_vec); + } else { // read k/v + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Load( + &kv_input[kv_store_offset + + target_batch_id * head_num * max_len_this_time * + size_per_head + + head_id * max_len_this_time * size_per_head + + seq_id * size_per_head + size_id], + &src_vec); } - if (std::is_same::value) { - mat_type = CUDA_R_64F; + int32_t write_index = + linear_index - (qkv_id + 2 * current_token) * hidden_size; + if (qkv_id == 0) { + phi::Store(src_vec, &q_out[write_index]); + } else if (qkv_id == 1) { + phi::Store(src_vec, &k_out[write_index]); + } else if (qkv_id == 2) { + phi::Store(src_vec, &v_out[write_index]); } + } +} - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_TYPE, - &mat_type, - sizeof(mat_type))); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_ROWS, - transpose ? &cublas_row : &cublas_col, - sizeof(cublas_row))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_COLS, - transpose ? &cublas_col : &cublas_row, - sizeof(cublas_col))); - int64_t cublas_ld = transpose ? 
cublas_row : cublas_col; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_LD, - &cublas_ld, - sizeof(cublas_ld))); - } - - const phi::GPUContext &dev_ctx_; - cublasLtMatmulDesc_t operation_desc_ = NULL; - cublasLtMatrixLayout_t x_desc_ = NULL; - cublasLtMatrixLayout_t w_desc_ = NULL; - cublasLtMatrixLayout_t out_desc_ = NULL; -}; +template +void TransposeSplit(const phi::GPUContext &dev_ctx, + T *q_out, + T *k_out, + T *v_out, + const T *q_input, + const T *kv_input, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int max_len_this_time, + const int seq_len, + const int size_per_head) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + fused_transpose_split_kernel + <<>>(q_out, + k_out, + v_out, + q_input, + kv_input, + padding_offset, + seq_lens, + elem_cnt, + batch_size, + max_len_this_time, + seq_len, + token_num, + head_num, + size_per_head); +} -#endif // PADDLE_FLUID_OPERATORS_FUSED_FUSED_MULTI_TRANSFORMER_OP_CU_H_ +template +void TransposeSplit(const phi::GPUContext &dev_ctx, + T *q_out, + T *k_out, + T *v_out, + const T *q_input, + const T *kv_input, + const int *padding_offset, + const int *seq_lens, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head) { + TransposeSplit(dev_ctx, + q_out, + k_out, + v_out, + q_input, + kv_input, + padding_offset, + seq_lens, + token_num, + batch_size, + head_num, + seq_len, + seq_len, + size_per_head); +} } // namespace diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 43c25bcaf10cb..abce94864c541 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -227,6 +227,7 @@ __global__ void FusedResidualDropoutBias( /** * @brief dst = residual + dropout(src + bias); */ +// TODO(@wufeisheng): residual_alpha template +// TODO(@wufeisheng): mask_broadcast_num_heads template void LaunchFusedSoftmaxMaskKernel(const T* src, const T* mask, @@ -155,6 +156,7 @@ void LaunchFusedSoftmaxMaskKernel(const T* src, const int batch_size, const int head_num, const int seq_len, + const bool mask_broadcast_num_heads, cudaStream_t stream) { PADDLE_ENFORCE_EQ( seq_len > 0 && seq_len <= 4096, diff --git a/paddle/fluid/operators/fused/mmha_util.cu.h b/paddle/fluid/operators/fused/mmha_util.cu.h new file mode 100644 index 0000000000000..e69582f300383 --- /dev/null +++ b/paddle/fluid/operators/fused/mmha_util.cu.h @@ -0,0 +1,2372 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifndef MMHA_UTIL_CU_H_ +#define MMHA_UTIL_CU_H_ +#ifdef __NVCC__ + +#include +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" + +// #include +// "paddle/fluid/operators/fused/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/operators/fused/datatype_traits.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +#define ENABLE_BF16 +#include +#endif + + +namespace paddle { +namespace operators { + +namespace { // NOLINT +namespace plat = paddle::platform; +using float16 = plat::float16; +// using float16 = half; +using bfloat16 = plat::bfloat16; +// using bfloat16 = __nv_bfloat16; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +struct Float4_ { + float2 x; + float2 y; +}; + +#ifdef ENABLE_BF16 +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; +#endif // ENABLE_BF16 + +//------------------------------------ +template +struct num_elems; +template <> +struct num_elems { + static constexpr int value = 1; +}; +template <> +struct num_elems { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +// lyq::todo float8 +// lyq::todo uint16_t +template <> +struct num_elems { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 8; +}; +#ifdef ENABLE_BF16 +template <> +struct num_elems<__nv_bfloat162> { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 8; +}; +#endif // ENABLE_BF16 + +//------------------------------------ +template +struct packed_type; +template +struct packed_type { + using type = T; +}; +template <> +struct packed_type { + using type = uint16_t; +}; +template <> +struct packed_type { + using type = uint32_t; +}; +template <> +struct packed_type { + using type = uint64_t; +}; +template <> +struct packed_type { + 
using type = float2; +}; +template <> +struct packed_type { + using type = float4; +}; +template <> +struct packed_type { + using type = Float8_; +}; + +//------------------------------------ +template +struct Qk_vec_ {}; +template <> +struct Qk_vec_ { + using Type = float; +}; +template <> +struct Qk_vec_ { + using Type = float2; +}; +template <> +struct Qk_vec_ { + using Type = float4; +}; +template <> +struct Qk_vec_ { + using Type = float4; +}; +template <> +struct Qk_vec_ { + using Type = uint32_t; +}; +template <> +struct Qk_vec_ { + using Type = uint32_t; +}; +template <> +struct Qk_vec_ { + using Type = uint2; +}; +template <> +struct Qk_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template <> +struct Qk_vec_ { + using Type = __nv_bfloat162; +}; +template <> +struct Qk_vec_ { + using Type = __nv_bfloat162; +}; +template <> +struct Qk_vec_ { + using Type = bf16_4_t; +}; +template <> +struct Qk_vec_ { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// RoPE Type +template +struct Qk_vec_RoPE_ {}; +template <> +struct Qk_vec_RoPE_ { + using Type = float2; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float2; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float4; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = Float8_; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float2; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float4; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float4; +}; +#ifdef ENABLE_BF16 +template <> +struct Qk_vec_RoPE_ { + using Type = float2; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float2; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = float4; +}; +template <> +struct Qk_vec_RoPE_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +//------------------------------------ + +template +struct K_vec_ {}; +template <> +struct K_vec_ { + using Type = float; +}; +template <> +struct K_vec_ { + using Type = float2; +}; +template <> +struct K_vec_ { + using Type = float4; +}; +template <> +struct K_vec_ { + using Type = uint32_t; +}; +template <> +struct K_vec_ { + using Type = uint2; +}; +template <> +struct K_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template <> +struct K_vec_ { + using Type = __nv_bfloat162; +}; +template <> +struct K_vec_ { + using Type = bf16_4_t; +}; +template <> +struct K_vec_ { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +//------------------------------------ + +template +struct K_vec_I_ { + using Type = uint8_t; +}; + +#ifdef ENABLE_BF16 +template <> +struct K_vec_I_ { + using Type = uint16_t; +}; +template <> +struct K_vec_I_ { + using Type = uint32_t; +}; +template <> +struct K_vec_I_ { + using Type = uint64_t; +}; +#endif // ENABLE_BF16 + +template +struct V_vec_ {}; +template <> +struct V_vec_ { + using Type = float; +}; +template <> +struct V_vec_ { + using Type = float2; +}; +template <> +struct V_vec_ { + using Type = float4; +}; +template <> +struct V_vec_ { + using Type = uint32_t; +}; +template <> +struct V_vec_ { + using Type = uint2; +}; +template <> +struct V_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template <> +struct V_vec_ { + using Type = __nv_bfloat162; +}; +template <> +struct V_vec_ { + using Type = bf16_4_t; +}; +template <> +struct V_vec_ { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, + const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, + const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); +#else + return __hmul(x, y); +#endif +} +#endif // ENABLE_BF16 + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? + float zero = 0.f; + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, + const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { + __nv_bfloat162 val_; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + val_ = __float22bfloat162_rn(val); +#else + val_.x = __float2bfloat16_rn(val.x); + val_.y = __float2bfloat16_rn(val.y); +#endif + return val_; +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, + const __nv_bfloat162 y, + const __nv_bfloat162 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} +#endif // ENABLE_BF16 + +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = 
add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { + return a + b; +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { + return bf16hadd2(a, b); +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float add(float a, __nv_bfloat16 b) { + return a + __bfloat162float(b); +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} +#endif // ENABLE_BF16 + +template +inline __device__ void mul_pointer_v2(T* c, float a, IntT* b); + +template <> +inline __device__ void mul_pointer_v2(float4* c, float a, uint8_t* b) { + c->x = a * (static_cast(b[0]) - 128.0); + c->y = a * (static_cast(b[1]) - 128.0); + c->z = a * (static_cast(b[2]) - 128.0); + c->w = a * (static_cast(b[3]) - 128.0); +} + +template <> +inline __device__ void mul_pointer_v2(float4* c, float a, uint32_t* b) { + uint8_t* b_tmp = reinterpret_cast(b); + c->x = a * (static_cast(b_tmp[0]) - 128.0); + c->y = a * (static_cast(b_tmp[1]) - 128.0); + c->z = a * (static_cast(b_tmp[2]) - 128.0); + c->w = a * (static_cast(b_tmp[3]) - 128.0); +} + +template <> +inline __device__ void mul_pointer_v2(float2* c, float a, uint8_t* b) { + c->x = a * (static_cast(b[0]) - 128.0); + c->y = a * (static_cast(b[1]) - 128.0); +} + +template <> +inline __device__ void mul_pointer_v2(float* c, float a, uint8_t* b) { + c[0] = a * (static_cast(b[0]) - 128.0); +} + +template <> +inline __device__ void mul_pointer_v2(uint32_t* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 2; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + +template <> +inline __device__ void 
mul_pointer_v2(uint2* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 4; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + +template <> +inline __device__ void mul_pointer_v2(uint4* c, float a, uint8_t* b) { + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 8; ++i) { + tmp_fp16[i] = a_prime * (static_cast(b[i]) - offset); + } +} + +template <> +inline __device__ void mul_pointer_v2(uint4* c, float a, uint64_t* b) { + uint8_t* tmp_b = reinterpret_cast(b); + float16* tmp_fp16 = reinterpret_cast(c); + float16 a_prime = static_cast(a); + float16 offset = static_cast(128.0); +#pragma unroll + for (int i = 0; i < 8; ++i) { + tmp_fp16[i] = a_prime * (static_cast(tmp_b[i]) - offset); + } +} + + +#ifdef ENABLE_BF16 +inline __device__ static void convert_(__nv_bfloat16* result, + uint32_t const& source) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +#endif +} + +template <> +inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint8_t* b) { +#if __CUDA_ARCH__ >= 800 + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + __nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, static_cast(*reinterpret_cast(b))); +#pragma unroll + for (int i = 0; i < 2; ++i) { + c_prime[i] *= a_prime; + } +#endif +} + +template <> +inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint16_t* b) { +#if __CUDA_ARCH__ >= 800 + using Packed_Int8_t = typename packed_type::type; + Packed_Int8_t int8_vec_4_val = *reinterpret_cast(b); + uint8_t* int8_vec_pointer = reinterpret_cast(&int8_vec_4_val); + + uint32_t* bf16_result_ptr = reinterpret_cast(c); + uint32_t const i8s = int8_vec_4_val; + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[2]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + fp32_intermediates[ii] -= (8388608.f + 128.f); + } + + bf16_result_ptr[0] = __byte_perm( + fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + __nv_bfloat16 scale = static_cast<__nv_bfloat16>(a); + c->x *= scale; + c->y *= scale; +#endif +} + +template <> +inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint8_t* b) { +#if __CUDA_ARCH__ >= 800 + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + 
__nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, *reinterpret_cast(b)); +#pragma unroll + for (int i = 0; i < 4; ++i) { + c_prime[i] *= a_prime; + } +#endif +} + +template <> +inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint32_t* b) { +#if __CUDA_ARCH__ >= 800 + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + __nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, *b); +#pragma unroll + for (int i = 0; i < 4; ++i) { + c_prime[i] *= a_prime; + } +#endif +} + +template <> +inline __device__ void mul_pointer_v2(bf16_8_t* c, float a, uint8_t* b) { + bf16_4_t* tmp_c = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 2; ++i) { + mul_pointer_v2(tmp_c + i, a, b + 4 * i); + } +} + +template <> +inline __device__ void mul_pointer_v2(bf16_8_t* c, float a, uint64_t* b) { + bf16_4_t* tmp_c = reinterpret_cast(c); + uint64_t bb = *b; + uint32_t* tmp_b = reinterpret_cast(&bb); +#pragma unroll + for (int i = 0; i < 2; ++i) { + mul_pointer_v2(tmp_c + i, a, tmp_b + i); + } +} +#endif // ENABLE_BF16 + +template +inline __device__ Acc mul(A a, B b); + +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + + +#ifdef ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 mul(float a, __nv_bfloat162 b) { + __nv_bfloat162 ret; + ret.x = static_cast<__nv_bfloat16>(a) * b.x; + ret.y = static_cast<__nv_bfloat16>(a) * b.y; + return ret; +} + +template <> +inline __device__ bf16_4_t mul(float a, bf16_4_t b) { + bf16_4_t ret; + ret.x = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.x); + ret.y = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.y); + return ret; +} + +template <> +inline __device__ bf16_8_t mul(float a, bf16_8_t b) { + bf16_8_t ret; + ret.x = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.x); + ret.y = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.y); + ret.z = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.z); + ret.w = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.w); + return ret; +} +#endif // ENABLE_BF16 + +template <> +inline __device__ uint32_t mul(float a, uint32_t b) { + union { + float16 out[2]; + uint32_t t_out; + }; + + union { + float16 in[2]; + uint32_t t_in; + }; + t_in = b; +#pragma unroll + for (int i = 0; i < 2; ++i) { + out[i] = static_cast(a) * in[i]; + } + return t_out; +} + +template <> +inline __device__ float16 mul(float a, float16 b) { + return static_cast(a) * b; +} + +template <> +inline __device__ uint2 mul(float a, uint2 b) { + union { + uint2 tmp_in; + float16 tmp_in_fp16[4]; + }; + tmp_in = b; + union { + uint2 ret; + float16 tmp_out_fp16[4]; + }; + +#pragma unroll + for (int i = 0; i < 4; ++i) { + tmp_out_fp16[i] = mul(a, tmp_in_fp16[i]); + } + return ret; +} + +template <> +inline __device__ uint4 mul(float a, uint4 b) { + union { + uint4 tmp_in; + float16 tmp_in_fp16[8]; + }; + tmp_in = b; + union { + uint4 ret; + float16 tmp_out_fp16[8]; + }; +#pragma unroll + for (int i = 0; i < 8; ++i) { + tmp_out_fp16[i] = mul(a, tmp_in_fp16[i]); + } + return ret; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +template <> +inline __device__ Float8_ mul(float a, Float8_ b) { + Float8_ c; + c.x = mul(a, b.x); + c.y = mul(a, b.y); + c.z = mul(a, b.z); + c.w = mul(a, b.w); + return c; +} + +template <> +inline 
__device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 tmp_res; + tmp_res.x = tmp.x * b; + tmp_res.y = tmp.y * b; + uint32_t res = float2_to_half2(tmp_res); + return res; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, float2 b) { + float2 tmp = half2_to_float2(a); + float2 tmp_res; + tmp_res.x = tmp.x * b.x; + tmp_res.y = tmp.y * b.y; + uint32_t res = float2_to_half2(tmp_res); + return res; +} + +template <> +inline __device__ float2 mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 res; + res.x = tmp.x * b; + res.y = tmp.y * b; + return res; +} + +template <> +inline __device__ uint2 mul(uint2 a, float b) { + uint2 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + return res; +} + +template <> +inline __device__ uint2 mul(uint2 a, float4 b) { + Float4_& b_ = *reinterpret_cast(&b); + uint2 res; + res.x = mul(a.x, b_.x); + res.y = mul(a.y, b_.y); + return res; +} + +template <> +inline __device__ uint4 mul(uint4 a, float b) { + uint4 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + res.z = mul(a.z, b); + res.w = mul(a.w, b); + return res; +} + +template <> +inline __device__ uint4 mul(uint4 a, Float8_ b) { + uint4 res; + res.x = mul(a.x, b.x); + res.y = mul(a.y, b.y); + res.z = mul(a.z, b.z); + res.w = mul(a.w, b.w); + return res; +} + +template <> +inline __device__ float2 mul(float2 a, float b) { + float2 res; + res.x = a.x * b; + res.y = a.y * b; + return res; +} + +template <> +inline __device__ float2 mul(float2 a, uint32_t b) { + float2 tmp_b = half2_to_float2(b); + float2 res; + res.x = a.x * tmp_b.x; + res.y = a.y * tmp_b.y; + return res; +} + +template <> +inline __device__ float4 mul(float4 a, float b) { + float4 res; + res.x = a.x * b; + res.y = a.y * b; + res.z = a.z * b; + res.w = a.w * b; + return res; +} + +#ifdef ENABLE_BF16 +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmul(a, b); +#else + return bf16hmul(a, b); +#endif +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + return bf16hmul2(a, b); +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + 
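For reference, the int8 paths above decode a stored byte q back to a real value as scale * (q - 128), and convert_() reaches the same q - 128 term without an integer-to-float instruction: each byte is spliced into the mantissa of the float 2^23 (the 0x4B000000 base), which yields 8388608 + q, and (8388608 + 128) is then subtracted. The standalone host-side sketch below, with helper names of my own choosing, checks that the two formulations agree; it illustrates the convention and is not code from this patch.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a 32-bit pattern as float, the host-side analogue of what the
// kernel does in registers.
static float BitsToFloat(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Dequantization as written in mul_pointer_v2: scale * (q - 128).
static float DequantReference(uint8_t q, float scale) {
  return scale * (static_cast<float>(q) - 128.0f);
}

// The same value via the trick used by convert_(): splice the byte into the
// mantissa of 2^23 (0x4B000000) to get 8388608 + q, then subtract
// (8388608 + 128) so no explicit int-to-float conversion is needed.
static float DequantMagicBase(uint8_t q, float scale) {
  const uint32_t fp32_base = 0x4B000000u;
  const float biased = BitsToFloat(fp32_base | q);  // == 8388608.0f + q
  return scale * (biased - (8388608.0f + 128.0f));
}

int main() {
  const float scale = 0.05f;
  for (int q = 0; q <= 255; q += 51) {
    std::printf("q=%3d  reference=%8.3f  magic=%8.3f\n",
                q,
                DequantReference(static_cast<uint8_t>(q), scale),
                DequantMagicBase(static_cast<uint8_t>(q), scale));
  }
  return 0;
}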
+template <> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template <> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = static_cast(a); + float fb = static_cast(b); + return fa * fb; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, float b) { + return __bfloat162float(a) * b; +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, float b) { + __nv_bfloat162 res; + __nv_bfloat162 _bf16 = __float2bfloat162_rn(b); + res = bf16hmul2(a, _bf16); + return res; +} + +template <> +inline __device__ __nv_bfloat162 mul(float2 a, float2 b) { + float2 res = mul(a, b); + __nv_bfloat162 bf16_res = float22bf162(res); + return bf16_res; +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, float2 b) { + float2 a_ = bf1622float2(a); + float2 res = mul(a_, b); + __nv_bfloat162 bf16_res = float22bf162(res); + return bf16_res; +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, float b) { + __nv_bfloat162 s = __float2bfloat162_rn(b); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, s); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, s); + return c; +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, float4 b) { + Float4_& b_ = *reinterpret_cast(&b); + float2 a1 = bf1622float2(a.x); + float2 a2 = bf1622float2(a.y); + + bf16_4_t c; + c.x = mul<__nv_bfloat162, float2, float2>(a1, b_.x); + c.y = mul<__nv_bfloat162, float2, float2>(a2, b_.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, float b) { + __nv_bfloat162 s = __float2bfloat162_rn(b); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, s); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, s); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, s); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, s); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, Float8_ b) { + float2 a1 = bf1622float2(a.x); + float2 a2 = bf1622float2(a.y); + float2 a3 = bf1622float2(a.z); + float2 a4 = bf1622float2(a.w); + + bf16_8_t c; + c.x = mul<__nv_bfloat162, float2, float2>(a1, b.x); + c.y = mul<__nv_bfloat162, float2, float2>(a2, b.y); + c.z = mul<__nv_bfloat162, float2, float2>(a3, b.z); + c.w = mul<__nv_bfloat162, float2, float2>(a4, b.w); + return c; +} + +template <> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + 
return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template <> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} +#endif // ENABLE_BF16 + +template +inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, + Qk_vec input_right, + Qk_vec_RoPE cos_emb, + Qk_vec_RoPE sin_emb, + float alpha) { + Qk_vec res1 = mul(input_left, cos_emb); + Qk_vec res2 = mul(input_right, sin_emb); + res2 = mul(res2, alpha); + return add(res1, res2); +} + +inline __device__ float sum(float v) { return v; } +inline __device__ float sum(float2 v) { return v.x + v.y; } +inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } +inline __device__ float sum(uint16_t v) { return half_to_float(v); } +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +#ifdef ENABLE_BF16 +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } + +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} +#endif // ENABLE_BF16 + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +inline __device__ constexpr uint32_t shfl_mask(int threads) { + return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; +} + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { + float2 tmp_b = half2_to_float2(b); + float2 d = fma(a, tmp_b, c); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(float a, float2 b, __nv_bfloat162 c) { + return bf16hfma2(__float2bfloat162_rn(a), float22bf162(b), c); +} + +inline __device__ bf16_4_t fma(float a, Float4_ b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} +#endif // ENABLE_BF16 + +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +#ifdef ENABLE_BF16 + +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, + __nv_bfloat162 b, + __nv_bfloat162 c) { + return bf16hfma2(a, b, c); +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, + __nv_bfloat162 b, + __nv_bfloat162 c) { + return bf16hfma2(bf162bf162(a), b, c); +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline 
__device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} +#endif // ENABLE_BF16 + +inline __device__ float cast_to_float(float u) { return u; } + +inline __device__ float2 cast_to_float(float2 u) { return u; } + +inline __device__ float4 cast_to_float(float4 u) { return u; } + +inline __device__ float2 cast_to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ cast_to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ cast_to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +inline __device__ Float4_ cast_to_float(Float4_ u) { return u; } + +inline __device__ Float8_ cast_to_float(Float8_ u) { return u; } + +#ifdef ENABLE_BF16 +inline __device__ float cast_to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +inline __device__ float2 cast_to_float(__nv_bfloat162 u) { + return bf1622float2(u); +} + +inline __device__ Float4_ cast_to_float(bf16_4_t u) { + Float4_ tmp; + tmp.x = bf1622float2(u.x); + tmp.y = bf1622float2(u.y); + return tmp; +} + +inline __device__ Float8_ cast_to_float(bf16_8_t u) { + Float8_ tmp; + tmp.x = bf1622float2(u.x); + tmp.y = bf1622float2(u.y); + tmp.z = bf1622float2(u.z); + tmp.w = bf1622float2(u.w); + return tmp; +} +#endif // ENABLE_BF16 + +template +inline __device__ T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? 
xLower + : xUpper); +} + +template +inline __device__ T round_tmp(D val); + +template <> +inline __device__ uint8_t round_tmp(float val) { + float quant_value = roundWithTiesToEven(val); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} + +template <> +inline __device__ uint8_t round_tmp(float16 val) { + float quant_value = roundWithTiesToEven(static_cast(val)); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} + +#ifdef ENABLE_BF16 +template <> +inline __device__ uint8_t round_tmp(__nv_bfloat16 val) { + float quant_value = + static_cast(roundWithTiesToEven(static_cast(val))); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} +#endif // ENABLE_BF16 + + +template <> +inline __device__ uint16_t round_tmp(float2 val) { + union { + uint16_t ret; + uint8_t tmp[2]; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(float4 val) { + union { + uint32_t ret; + uint8_t tmp[4]; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + tmp[2] = round_tmp(val.z); + tmp[3] = round_tmp(val.w); + return ret; +} + +template <> +inline __device__ uint16_t round_tmp(uint32_t val) { + union { + uint8_t int8[2]; + uint16_t ret; + }; + union { + float16 fp16[2]; + uint32_t tmp; + }; + tmp = val; + +#pragma unroll + for (int i = 0; i < 2; ++i) { + int8[i] = round_tmp(fp16[i]); + } + + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(uint2 val) { + union { + uint8_t int8[4]; + uint32_t ret; + }; + + union { + uint2 ui2; + float16 tmp_fp16[4]; + }; + ui2 = val; + +#pragma unroll + for (int i = 0; i < 4; ++i) { + int8[i] = round_tmp(tmp_fp16[i]); + } + return ret; +} + +template <> +inline __device__ uint64_t round_tmp(uint4 val) { + union { + uint8_t int8[8]; + uint64_t ret; + }; + + union { + uint4 ui4; + float16 tmp_fp16[8]; + }; + ui4 = val; + +#pragma unroll + for (int i = 0; i < 8; ++i) { + int8[i] = round_tmp(tmp_fp16[i]); + } + return ret; +} + + +#ifdef ENABLE_BF16 +template <> +inline __device__ uint16_t round_tmp(__nv_bfloat162 val) { + union { + uint8_t tmp[2]; + uint16_t ret; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(bf16_4_t val) { + union { + uint16_t tmp[2]; + uint32_t ret; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint64_t round_tmp(bf16_8_t val) { + union { + uint16_t int16[4]; + uint64_t int64; + }; + int16[0] = round_tmp(val.x); + int16[1] = round_tmp(val.y); + int16[2] = round_tmp(val.z); + int16[3] = round_tmp(val.w); + return int64; +} +#endif // ENABLE_BF16 + +inline __device__ float2 rotary_embedding_coefficient(const int zid, + const int rot_embed_dim, + const float t_step) { + const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim); + return {cos(inv_freq), sin(inv_freq)}; +} + +inline __device__ float2 rotary_embedding_transform(const float2 v, + const float2 coef) { + float2 rot_v; + rot_v.x = coef.x * v.x - coef.y * v.y; + rot_v.y = coef.x * v.y + coef.y * v.x; + return rot_v; +} + +inline __device__ float2 rotary_embedding_transform(const float2 v, 
+                                                     const float2 cos,
+                                                     const float2 sin) {
+  float2 rot_v;
+  rot_v.x = v.x * cos.x - v.y * sin.x;
+  rot_v.y = v.y * cos.y + v.x * sin.y;
+  return rot_v;
+}
+
+inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
+                                                       const float2 coef) {
+  float2 fv = half2_to_float2(v);
+  float2 rot_fv = rotary_embedding_transform(fv, coef);
+  return float2_to_half2(rot_fv);
+}
+
+inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
+                                                       const uint32_t cos,
+                                                       const uint32_t sin) {
+  float2 fv = half2_to_float2(v);
+  float2 fcos = half2_to_float2(cos);
+  float2 fsin = half2_to_float2(sin);
+  float2 rot_fv = rotary_embedding_transform(fv, fcos, fsin);
+  return float2_to_half2(rot_fv);
+}
+
+inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
+                                                       const float2 cos,
+                                                       const float2 sin) {
+  float2 fv = half2_to_float2(v);
+  float2 rot_fv = rotary_embedding_transform(fv, cos, sin);
+  return float2_to_half2(rot_fv);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ __nv_bfloat162
+rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) {
+  float2 fv = bf1622float2(v);
+  float2 rot_fv = rotary_embedding_transform(fv, coef);
+  return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
+}
+
+inline __device__ __nv_bfloat162
+rotary_embedding_transform(const __nv_bfloat162 v,
+                           const __nv_bfloat162 cos,
+                           const __nv_bfloat162 sin) {
+  float2 fv = bf1622float2(v);
+  float2 fcos = bf1622float2(cos);
+  float2 fsin = bf1622float2(sin);
+  float2 rot_fv = rotary_embedding_transform(fv, fcos, fsin);
+  return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
+}
+
+inline __device__ __nv_bfloat162 rotary_embedding_transform(
+    const __nv_bfloat162 v, const float2 cos, const float2 sin) {
+  float2 fv = bf1622float2(v);
+  float2 rot_fv = rotary_embedding_transform(fv, cos, sin);
+  return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
+}
+
+#endif
+
+inline __device__ void apply_rotary_embedding(float& q,
+                                              float& k,
+                                              float& cos,
+                                              float& sin) {
+  return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q,
+                                              float2& k,
+                                              float2& cos,
+                                              float2& sin) {
+  q = rotary_embedding_transform(q, cos, sin);
+  k = rotary_embedding_transform(k, cos, sin);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q,
+                                              float4& k,
+                                              float4& cos,
+                                              float4& sin) {
+  Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+  Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
+  Float4_& cos_ = *reinterpret_cast<Float4_*>(&cos);
+  Float4_& sin_ = *reinterpret_cast<Float4_*>(&sin);
+  q_.x = rotary_embedding_transform(q_.x, cos_.x, sin_.x);
+  k_.x = rotary_embedding_transform(k_.x, cos_.x, sin_.x);
+  q_.y = rotary_embedding_transform(q_.y, cos_.y, sin_.y);
+  k_.y = rotary_embedding_transform(k_.y, cos_.y, sin_.y);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q,
+                                              uint32_t& k,
+                                              uint32_t& cos,
+                                              uint32_t& sin) {
+  q = rotary_embedding_transform(q, cos, sin);
+  k = rotary_embedding_transform(k, cos, sin);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q,
+                                              uint32_t& k,
+                                              float2& cos,
+                                              float2& sin) {
+  q = rotary_embedding_transform(q, cos, sin);
+  k = rotary_embedding_transform(k, cos, sin);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q,
+                                              uint2& k,
+                                              uint2& cos,
+                                              uint2& sin) {
+  q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
+  k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
+  q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
+  k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q,
+                                              uint2& k,
+                                              float4& cos,
+                                              float4& sin) {
+  Float4_& cos_ = *reinterpret_cast<Float4_*>(&cos);
+  Float4_& sin_ = *reinterpret_cast<Float4_*>(&sin);
+  q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x);
+  k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x);
+  q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y);
+  k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q,
+                                              uint4& k,
+                                              uint4& cos,
+                                              uint4& sin) {
+  q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
+  k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
+  q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
+  k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
+  q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
+  k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
+  q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
+  k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q,
+                                              uint4& k,
+                                              Float8_& cos,
+                                              Float8_& sin) {
+  q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
+  k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
+  q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
+  k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
+  q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
+  k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
+  q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
+  k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
+}
+
+inline __device__ void apply_rotary_embedding(
+    float& q, int zid, int rot_embed_dim, int t_step, float compression_ratio) {
+  return;
+}
+
+inline __device__ void apply_rotary_embedding(float& q,
+                                              float& k,
+                                              int zid,
+                                              int rot_embed_dim,
+                                              int t_step,
+                                              float compression_ratio) {
+  return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q,
+                                              int tid,
+                                              int rot_embed_dim,
+                                              int t_step,
+                                              float compression_ratio) {
+  if (2 * tid >= rot_embed_dim) {
+    return;
+  }
+  float float_t_step = static_cast<float>(t_step);
+  float_t_step /= compression_ratio;
+  const auto coef =
+      rotary_embedding_coefficient(2 * tid, rot_embed_dim, float_t_step);
+  q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float2& q,
+                                              float2& k,
+                                              int tid,
+                                              int rot_embed_dim,
+                                              int t_step,
+                                              float compression_ratio) {
+  if (2 * tid >= rot_embed_dim) {
+    return;
+  }
+  float float_t_step = static_cast<float>(t_step);
+  float_t_step /= compression_ratio;
+  const auto coef =
+      rotary_embedding_coefficient(2 * tid, rot_embed_dim, float_t_step);
+  q = rotary_embedding_transform(q, coef);
+  k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q,
+                                              int tid,
+                                              int rot_embed_dim,
+                                              int t_step,
+                                              float compression_ratio) {
+  if (4 * tid >= rot_embed_dim) {
+    return;
+  }
+  float float_t_step = static_cast<float>(t_step);
+  float_t_step /= compression_ratio;
+  Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+  const auto coef0 =
+      rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step);
+  q_.x = rotary_embedding_transform(q_.x, coef0);
+  const auto coef1 =
+      rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step);
+  q_.y = rotary_embedding_transform(q_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q,
+                                              float4& k,
+                                              int tid,
+                                              int rot_embed_dim,
+                                              int t_step,
+                                              float compression_ratio) {
+  if (4 * tid >= rot_embed_dim) {
+    return;
+  }
+  float float_t_step = static_cast<float>(t_step);
+  float_t_step /= compression_ratio;
+  Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+  Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
+  const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step); + q_.x = rotary_embedding_transform(q_.x, coef0); + k_.x = rotary_embedding_transform(k_.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step); + q_.y = rotary_embedding_transform(q_.y, coef1); + k_.y = rotary_embedding_transform(k_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (2 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef = + rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, + uint32_t& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (2 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef = + rotary_embedding_coefficient(2 * tid, rot_embed_dim, float_t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding( + uint2& q, int tid, int rot_embed_dim, int t_step, float compression_ratio) { + if (4 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint2& q, + uint2& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (4 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding( + uint4& q, int tid, int rot_embed_dim, int t_step, float compression_ratio) { + if (8 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(8 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = + rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, float_t_step); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = + rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, float_t_step); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(uint4& q, + uint4& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (8 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + 
rotary_embedding_coefficient(8 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = + rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, float_t_step); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = + rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, float_t_step); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} + +#ifdef ENABLE_BF16 +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, + __nv_bfloat162& k, + __nv_bfloat162& cos, + __nv_bfloat162& sin) { + q = rotary_embedding_transform(q, cos, sin); + k = rotary_embedding_transform(k, cos, sin); +} + +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, + __nv_bfloat162& k, + float2& cos, + float2& sin) { + q = rotary_embedding_transform(q, cos, sin); + k = rotary_embedding_transform(k, cos, sin); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, + bf16_4_t& k, + bf16_4_t& cos, + bf16_4_t& sin) { + q.x = rotary_embedding_transform(q.x, cos.x, sin.x); + k.x = rotary_embedding_transform(k.x, cos.x, sin.x); + q.y = rotary_embedding_transform(q.y, cos.y, sin.y); + k.y = rotary_embedding_transform(k.y, cos.y, sin.y); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, + bf16_4_t& k, + float4& cos, + float4& sin) { + Float4_& cos_ = *reinterpret_cast(&cos); + Float4_& sin_ = *reinterpret_cast(&sin); + q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x); + k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x); + q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y); + k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, + bf16_8_t& k, + bf16_8_t& cos, + bf16_8_t& sin) { + q.x = rotary_embedding_transform(q.x, cos.x, sin.x); + k.x = rotary_embedding_transform(k.x, cos.x, sin.x); + q.y = rotary_embedding_transform(q.y, cos.y, sin.y); + k.y = rotary_embedding_transform(k.y, cos.y, sin.y); + q.z = rotary_embedding_transform(q.z, cos.z, sin.z); + k.z = rotary_embedding_transform(k.z, cos.z, sin.z); + q.w = rotary_embedding_transform(q.w, cos.w, sin.w); + k.w = rotary_embedding_transform(k.w, cos.w, sin.w); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, + bf16_8_t& k, + Float8_& cos, + Float8_& sin) { + q.x = rotary_embedding_transform(q.x, cos.x, sin.x); + k.x = rotary_embedding_transform(k.x, cos.x, sin.x); + q.y = rotary_embedding_transform(q.y, cos.y, sin.y); + k.y = rotary_embedding_transform(k.y, cos.y, sin.y); + q.z = rotary_embedding_transform(q.z, cos.z, sin.z); + k.z = rotary_embedding_transform(k.z, cos.z, sin.z); + q.w = rotary_embedding_transform(q.w, cos.w, sin.w); + k.w = rotary_embedding_transform(k.w, cos.w, sin.w); +} + +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (2 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef = + rotary_embedding_coefficient(2 * tid, rot_embed_dim, float_t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void 
apply_rotary_embedding(__nv_bfloat162& q, + __nv_bfloat162& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (2 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef = + rotary_embedding_coefficient(2 * tid, rot_embed_dim, float_t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (4 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, + bf16_4_t& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (4 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(4 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (8 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(8 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = + rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, float_t_step); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = + rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, float_t_step); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, + bf16_8_t& k, + int tid, + int rot_embed_dim, + int t_step, + float compression_ratio) { + if (8 * tid >= rot_embed_dim) { + return; + } + float float_t_step = static_cast(t_step); + float_t_step /= compression_ratio; + const auto coef0 = + rotary_embedding_coefficient(8 * tid, rot_embed_dim, float_t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = + rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, float_t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = + rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, float_t_step); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = + rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, float_t_step); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // 
ENABLE_BF16 + +} // namespace + +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index c3425ac604858..1e877854a048c 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -1,12 +1,9 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -39,7 +36,7 @@ namespace dynload { extern DynLoad__##__name __name // APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -61,7 +58,33 @@ namespace dynload { __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/fluid/pybind/eager_generator.h b/paddle/fluid/pybind/eager_generator.h index 03b8690569c22..2e2771f945ecd 100644 --- a/paddle/fluid/pybind/eager_generator.h +++ b/paddle/fluid/pybind/eager_generator.h @@ -60,25 +60,13 @@ std::map> op_ins_map = { "OutLinearWeight", "OutLinearBias"}}, {"fused_multi_transformer", - {"X", - "LnScale", - "LnBias", - "QKVW", - "QKVBias", - "CacheKV", - "PreCaches", - "RotaryPosEmb", - "TimeStep", - "SeqLengths", - "SrcMask", - "OutLinearW", - "OutLinearBias", - "FFNLnScale", - "FFNLnBias", - "FFN1Weight", - "FFN1Bias", - "FFN2Weight", - "FFN2Bias"}}, + {"X", "LnScale", "LnBias", + "QKVW", "QKVBias", "CacheKV", + "PreCaches", "RotaryPosEmb", "BeamCacheOffset", + "TimeStep", "SeqLengths", "SrcMask", + "OutLinearW", "OutLinearBias", "FFNLnScale", + "FFNLnBias", "FFN1Weight", "FFN1Bias", + "FFN2Weight", "FFN2Bias"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", diff --git 
a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 0546dd84b6882..4683ff0db3fae 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -182,6 +182,9 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { case PaddleDType::FLOAT16: dt = py::dtype::of(); break; + // case PaddleDType::BFLOAT16: + // dt = py::dtype::of(); + // break; case PaddleDType::UINT8: dt = py::dtype::of(); break; @@ -194,7 +197,7 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type. Now only supports INT32, INT64, FLOAT64, " - "FLOAT32, FLOAT16, INT8, UINT8 and BOOL.")); + "FLOAT32, FLOAT16, BFLOAT16, INT8, UINT8 and BOOL.")); } return dt; @@ -262,6 +265,11 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT static_cast(input_tensor.data()), shape, ToPaddleInferPlace(input_tensor.place().GetType())); + } else if (input_tensor.dtype() == phi::DataType::BFLOAT16) { + tensor.ShareExternalData( + static_cast(input_tensor.data()), + shape, + ToPaddleInferPlace(input_tensor.place().GetType())); } else if (input_tensor.dtype() == phi::DataType::FLOAT16) { tensor.ShareExternalData( static_cast(input_tensor.data()), @@ -280,7 +288,7 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type. Now share_external_data only supports INT32, " - "INT64, FLOAT64, FLOAT32 and FLOAT16.")); + "INT64, FLOAT64, FLOAT32, BFLOAT16 and FLOAT16.")); } } @@ -301,6 +309,11 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor, // NOLINT static_cast(paddle_tensor.data()), shape, ToPaddleInferPlace(paddle_tensor.place().GetType())); + } else if (paddle_tensor.dtype() == phi::DataType::BFLOAT16) { + tensor.ShareExternalData( + static_cast(paddle_tensor.data()), + shape, + ToPaddleInferPlace(paddle_tensor.place().GetType())); } else if (paddle_tensor.dtype() == phi::DataType::FLOAT16) { tensor.ShareExternalData( static_cast( @@ -320,7 +333,7 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor, // NOLINT } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type. 
Now share_external_data only supports INT32, " - "INT64, FLOAT32 and FLOAT16.")); + "INT64, FLOAT32, BFLOAT16 and FLOAT16.")); } } @@ -355,6 +368,9 @@ size_t PaddleGetDTypeSize(PaddleDType dt) { case PaddleDType::FLOAT16: size = sizeof(phi::dtype::float16); break; + case PaddleDType::BFLOAT16: + size = sizeof(paddle_infer::bfloat16); + break; case PaddleDType::INT8: size = sizeof(int8_t); break; @@ -395,6 +411,10 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT tensor.copy_to_cpu( static_cast(array.mutable_data())); break; + case PaddleDType::BFLOAT16: + tensor.copy_to_cpu( + static_cast(array.mutable_data())); + break; case PaddleDType::UINT8: tensor.copy_to_cpu(static_cast(array.mutable_data())); break; @@ -435,6 +455,10 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT tensor.CopyToCpu( static_cast(array.mutable_data())); break; + case PaddleDType::BFLOAT16: + tensor.CopyToCpu( + static_cast(array.mutable_data())); + break; case PaddleDType::UINT8: tensor.CopyToCpu(static_cast(array.mutable_data())); break; @@ -529,6 +553,7 @@ void BindPaddleDType(py::module *m) { .value("FLOAT64", PaddleDType::FLOAT64) .value("FLOAT32", PaddleDType::FLOAT32) .value("FLOAT16", PaddleDType::FLOAT16) + .value("BFLOAT16", PaddleDType::BFLOAT16) .value("INT64", PaddleDType::INT64) .value("INT32", PaddleDType::INT32) .value("UINT8", PaddleDType::UINT8) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 72dee765d6157..5035a027986f8 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -453,6 +453,15 @@ func : cumprod backward : cumprod_grad +- op : cutlass_fused_multihead_attention + args : (Tensor query, Tensor key, Tensor value, Tensor mask, float scale, bool causal) + output : Tensor + infer_meta : + func : FusedMultiHeadAttentionInferMeta + kernel : + func : fused_multihead_attention + optional : mask + - op : det args : (Tensor x) output : Tensor @@ -775,6 +784,16 @@ func : frame backward : frame_grad +- op : fused_multihead_attention_variable + args : (Tensor query, Tensor key, Tensor value, Tensor seq_lens, Tensor mask, float scale, bool causal) + output : Tensor + infer_meta : + func : FusedMultiHeadAttentionVariableInferMeta + kernel : + func : fused_multihead_attention_variable + data_type : query + optional : mask + - op : gather_nd args : (Tensor x, Tensor index) output : Tensor diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 90492ff4ba69d..52196ee9d5715 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -1,12 +1,9 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
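Both cublasLt dynload headers touched by this patch use the same X-macro idiom: CUBLASLT_BLAS_ROUTINE_EACH lists the symbols once, and a declare-wrap macro is applied to every entry, so the CUDA_VERSION >= 11010 gate only switches which list is active. Below is a minimal standalone sketch of that idiom; the fake resolver and names are illustrative stand-ins, not Paddle's real dlopen/dlsym-based dynload machinery.

// X-macro sketch: the symbol list is written once, the wrapper shape once.
#include <cstdio>
#include <string>

void* FakeLoadSymbol(const std::string& name) {  // stand-in for dlsym()
  std::printf("resolving %s\n", name.c_str());
  return nullptr;
}

#define DECLARE_LAZY_WRAP(__name)                       \
  struct DynLoad__##__name {                            \
    void* Resolve() { return FakeLoadSymbol(#__name); } \
  };                                                    \
  DynLoad__##__name __name

// Gating this list on a CUDA version (as the patch does) only changes which
// routines get wrapped; the declaration machinery stays the same.
#define ROUTINE_EACH(__macro) \
  __macro(cublasLtCreate);    \
  __macro(cublasLtDestroy);   \
  __macro(cublasLtMatmul)

ROUTINE_EACH(DECLARE_LAZY_WRAP);

int main() {
  cublasLtCreate.Resolve();  // each wrapper resolves its own symbol by name
  cublasLtMatmul.Resolve();
  return 0;
}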
@@ -54,6 +51,7 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -75,7 +73,33 @@ extern void *cublasLt_dso_handle; __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index d425f74349b78..4695a694a186d 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1283,3 +1283,21 @@ PADDLE_DEFINE_EXPORTED_int64(alloc_fill_value, -1, "Whether to fill fixed value after allocation. 
" "This is usefull for debugging."); + +/** + * The fmha mode in FusedMultiTransformer + * Value Range: string, {naive, cutlass, flash_attention_v2} + */ +PADDLE_DEFINE_EXPORTED_string(fmha_mode, + "cutlass", + "The mode of fmha in FusedMultiTransformer."); + +PADDLE_DEFINE_EXPORTED_bool(print_matrix, false, ""); +PADDLE_DEFINE_EXPORTED_bool(fuse_softmax, false, ""); + +PADDLE_DEFINE_EXPORTED_int64(custom_allreduce_one_shot_threshold, + -1, + ""); // 393216 +PADDLE_DEFINE_EXPORTED_int64(custom_allreduce_two_shot_threshold, + -1, + ""); // 50331648 diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 4bfce8d4bb45e..035c96f12b392 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3269,5 +3269,64 @@ void MoeInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& seq_lens, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out) { + const int64_t query_batch_size = query.dims()[0]; + const int64_t query_seq_length = query.dims()[2]; + const int64_t query_num_head = query.dims()[1]; + const int64_t value_head_size = value.dims()[3]; + std::vector out_dims( + {query_batch_size, query_num_head, query_seq_length, value_head_size}); + out->set_dims(phi::make_ddim(out_dims)); + out->share_lod(query); + out->set_dtype(query.dtype()); + out->set_layout(query.layout()); +} + +void FusedMultiHeadAttentionInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + query.dims().size(), + 4, + phi::errors::InvalidArgument("Query should be a 4-D tensor" + "But received Query dimension(%s)", + query.dims().size())); + PADDLE_ENFORCE_EQ( + key.dims().size(), + 4, + phi::errors::InvalidArgument("Key should be a 4-D tensor" + "But received Key dimension(%s)", + key.dims().size())); + PADDLE_ENFORCE_EQ( + value.dims().size(), + 4, + phi::errors::InvalidArgument("Value should be a 4-D tensor" + "But received Value dimension(%s)", + value.dims().size())); + const int64_t query_batch_size = query.dims()[0]; + const int64_t query_num_head = query.dims()[1]; + const int64_t query_seq_length = query.dims()[2]; + const int64_t value_head_size = value.dims()[3]; + + std::vector out_dims( + {query_batch_size, query_num_head, query_seq_length, value_head_size}); + + out->set_dims(phi::make_ddim(out_dims)); + out->share_lod(query); + out->set_dtype(query.dtype()); + out->set_layout(query.layout()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 90e4f8cb47391..7e274dfd042ce 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -612,5 +612,19 @@ void FusedRopeInferMeta(const MetaTensor& q, MetaTensor* out_q, MetaTensor* out_k, MetaTensor* out_v); - +void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& seq_lens, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out); +void FusedMultiHeadAttentionInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out); } // namespace phi diff --git 
a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index fe8548f9fb5fc..f558453ac0527 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -131,6 +131,13 @@ if(WITH_CUTLASS) --cuda_arch "${NVCC_ARCH_BIN}" RESULT_VARIABLE memory_efficient_attention_gen_res) + execute_process( + COMMAND + ${PYTHON_EXECUTABLE} + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/generate_variable_forward_kernels.py + --cuda_arch "${NVCC_ARCH_BIN}" + RESULT_VARIABLE memory_efficient_attention_gen_res) + if(NOT memory_efficient_attention_gen_res EQUAL 0) message( FATAL_ERROR @@ -138,9 +145,14 @@ if(WITH_CUTLASS) ) endif() - file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu" - "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu" - "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu") + file( + GLOB + cutlass_cu + "fusion/cutlass/conv2d/generated/*.cu" + "fusion/cutlass/conv2d/*.cu" + "fusion/cutlass/*.cu" + "fusion/cutlass/fused_multi_head_attention/autogen/impl/*.cu" + "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu") add_definitions("-DPADDLE_WITH_MEMORY_EFFICIENT_ATTENTION") list(APPEND kernel_cu ${cutlass_cu}) endif() diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/debug_utils.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/debug_utils.h new file mode 100644 index 0000000000000..08f189eed5be0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/debug_utils.h @@ -0,0 +1,207 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (int _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_T0_L0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) 
\ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf("[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_T0_L0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_T0_L0("%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \ + for (int _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_T0_L0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, " \ + "%f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + 
PRINT_T0_L0("%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum(AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/default_fmha_grouped.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/default_fmha_grouped.h new file mode 100644 index 0000000000000..2f42865682f54 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/default_fmha_grouped.h @@ -0,0 +1,302 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. 
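The note above about CUTLASS epilogues targeting row-major output rests on the identity (A*B)^T = B^T * A^T: a column-major C occupies the same storage as a row-major C^T, so swapping and transposing the operands lets a row-major kernel produce a column-major result. A small standalone check of that identity, with plain loops rather than CUTLASS (illustrative only):

// Writing A*B in column-major order gives the same bytes as writing
// (B^T * A^T) in row-major order, shown here for a 2x2 example.
#include <array>
#include <cassert>

int main() {
  std::array<float, 4> A = {1, 2, 3, 4};  // row-major 2x2
  std::array<float, 4> B = {5, 6, 7, 8};  // row-major 2x2
  auto at = [](const std::array<float, 4>& M, int r, int c) {
    return M[r * 2 + c];
  };

  std::array<float, 4> C_colmajor{};  // A*B, stored column-major
  std::array<float, 4> D_rowmajor{};  // B^T * A^T, stored row-major
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      float c = 0.f, d = 0.f;
      for (int k = 0; k < 2; ++k) {
        c += at(A, i, k) * at(B, k, j);  // C = A * B
        d += at(B, k, i) * at(A, j, k);  // D = B^T * A^T
      }
      C_colmajor[j * 2 + i] = c;
      D_rowmajor[i * 2 + j] = d;
    }
  }
  assert(C_colmajor == D_rowmajor);  // identical storage
  return 0;
}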
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "gemm/attention_scaling_coefs_updater.h" +#include "gemm/find_default_mma.h" +#include "gemm/fmha_grouped.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + bool maskIsAligned_, + int QueriesPerBlock_, + int KeysPerBlock_, + bool SingleValueIteration_, + GroupScheduleMode GroupScheduleMode_, + bool AddMask, + bool MaskBroadcastRow> +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static bool const kAddMask = AddMask; + static bool const kMaskBroadcastRow = MaskBroadcastRow; + static bool const kSingleValueIteration = SingleValueIteration_; + static int const kKeysPerBlock = KeysPerBlock_; + static bool const kMaskIsAligned = maskIsAligned_; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = + QueriesPerBlock_ * KeysPerBlock_ / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = + gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator>; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator>::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Updater; + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // used for efficient load of mask_ tile Bij from global to shared memory + using MaskLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + kMaskIsAligned ? 128 / cutlass::sizeof_bits::value : 1>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator>; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + /// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_pipelined.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000..eacb359c182b1 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,617 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template ::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase { + public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename 
OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array; + using SourceAccessType = Array; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()(OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& 
warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = {(pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / \ + Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& 
warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = {(pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed>::push(iter, + accum_fragment_iterator, + this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], + aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_(destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + 
reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = + ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000..b199dd268db1e --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,237 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
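The compute_source_needed_ path above is the "pipelined" part of this epilogue: two source fragments are kept in flight, and the fragment for iteration iter + 1 is loaded while iteration iter is consumed, so the global-memory read can overlap the accumulator work. A host-side sketch of the same double-buffering pattern, with plain vectors standing in for the tile-iterator and fragment types (an illustration, not the CUTLASS code path):

#include <cstdio>
#include <vector>

int main() {
  const int kIterations = 4;
  // One "tile" of source data per epilogue iteration.
  std::vector<std::vector<float>> tiles(kIterations,
                                        std::vector<float>(8, 1.0f));

  std::vector<float> source_fragment[2];  // two buffers, as in the epilogue
  source_fragment[0] = tiles[0];          // prime the pipeline

  for (int iter = 0; iter < kIterations; ++iter) {
    // Kick off the load for the *next* iteration before consuming the
    // current buffer; on the GPU this hides the load latency.
    if (iter + 1 < kIterations) {
      source_fragment[(iter + 1) % 2] = tiles[iter + 1];
    }

    // Consume the current buffer (stand-in for output_op(accum, source)).
    float sum = 0.f;
    for (float v : source_fragment[iter % 2]) sum += v;
    std::printf("iter %d consumed, sum = %.1f\n", iter, sum);
  }
  return 0;
}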
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. 
+// output <- alpha * accumulator + beta * source
+// with:
+//   alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
+//   beta = alpha * m_prime (renormalize the output when the max changes)
+//   source is the current output
+template ,
+          ///< but we use 64 or 32 sometimes when there are not
+          ///< enough data to store
+          typename ElementAccumulator_,  ///< Accumulator data type
+          typename ElementCompute_,      ///< Data type used to compute linear
+                                         ///< combination
+          bool isFirst,
+          bool isLast,
+          typename FragmentAlphaBeta_,
+          FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
+class MemoryEfficientAttentionNormalize {
+ public:
+  using ElementOutput = ElementOutput_;
+  using ElementSource = ElementSource_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+
+  static int const kCount = Count;
+
+  using FragmentOutput = Array;
+  using FragmentSource = Array;
+  using FragmentAccumulator = Array;
+  using ComputeFragment = Array;
+  using FragmentAlphaBeta = FragmentAlphaBeta_;
+
+  static FloatRoundStyle const kRound = Round;
+
+ private:
+  //
+  // Data members
+  //
+
+  FragmentAlphaBeta const& s_prime_;
+  FragmentAlphaBeta const& m_prime_;
+
+ public:
+  /// Constructs the function object, possibly loading from pointers in host
+  /// memory
+  CUTLASS_HOST_DEVICE
+  MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
+                                    FragmentAlphaBeta const& m_prime)
+      : s_prime_(s_prime), m_prime_(m_prime) {}
+
+  /// Returns true if source is needed
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const { return !isFirst; }
+
+  /// Functionally required for serial reduction in the epilogue
+  CUTLASS_HOST_DEVICE
+  void set_k_partition(int k_partition, int k_partition_count) {}
+
+  /// Computes linear scaling: D = alpha * accumulator + beta * source
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(int row,
+                            FragmentAccumulator const& accumulator,
+                            FragmentSource const& source) const {
+    assert(!isFirst);
+
+    // Convert source to internal compute numeric type
+    NumericArrayConverter
+        source_converter;
+    NumericArrayConverter
+        accumulator_converter;
+
+    // Convert to destination numeric type
+    NumericArrayConverter
+        destination_converter;
+
+    ComputeFragment converted_source = source_converter(source);
+    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+    // Perform binary operations
+    ComputeFragment intermediate;
+
+    multiplies mul_add_source;
+    multiply_add mul_add_accumulator;
+
+    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+    ElementCompute beta = alpha * m_prime_[row];
+
+    intermediate = mul_add_source(beta, converted_source);  // X = beta * C
+
+    intermediate = mul_add_accumulator(
+        alpha, converted_accumulator, intermediate);  // D = alpha * Accum + X
+
+    return destination_converter(intermediate);
+  }
+
+  /// Computes linear scaling: D = alpha * accumulator
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(int row,
+                            FragmentAccumulator const& accumulator) const {
+    assert(isFirst);
+
+    // Convert source to internal compute numeric type
+    NumericArrayConverter
+        accumulator_converter;
+
+    // Convert to destination numeric type
+    NumericArrayConverter
+        destination_converter;
+
+    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+    ComputeFragment intermediate;
+    multiplies mul_accumulator;
+
+    ElementCompute alpha = isLast ?
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000..fda4a07887f40 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,170 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
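Stepping back to the ApplyEpilogueOp specialization added above: the pipelined epilogue always hands the output functor a row id, and the partial specialization decides whether that id is consumed (MemoryEfficientAttentionNormalize needs it to index m_prime / s_prime; ordinary epilogue ops ignore it). A toy sketch of that dispatch pattern with made-up functor types, not the CUTLASS ones:

#include <cstdio>

struct PlainScale {  // row-agnostic output op
  float operator()(float accum) const { return 0.5f * accum; }
};

struct RowAwareScale {  // needs the row to pick its per-row coefficient
  const float* per_row_scale;
  float operator()(int row, float accum) const {
    return per_row_scale[row] * accum;
  }
};

template <typename Op>
struct ApplyOp {  // generic case: drop the row id
  static float apply(const Op& op, int /*row*/, float accum) {
    return op(accum);
  }
};

template <>
struct ApplyOp<RowAwareScale> {  // specialization: forward the row id
  static float apply(const RowAwareScale& op, int row, float accum) {
    return op(row, accum);
  }
};

int main() {
  float scales[2] = {0.25f, 2.0f};
  PlainScale plain;
  RowAwareScale row_aware{scales};
  std::printf("%.2f %.2f\n",
              ApplyOp<PlainScale>::apply(plain, 1, 8.0f),          // 4.00
              ApplyOp<RowAwareScale>::apply(row_aware, 1, 8.0f));  // 16.00
  return 0;
}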
+*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/attention_scaling_coefs_updater.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/attention_scaling_coefs_updater.h new file mode 100644 index 0000000000000..0fb02cb751191 --- /dev/null +++ 
b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/attention_scaling_coefs_updater.h @@ -0,0 +1,508 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../gemm_kernel_utils.h" +#include "../kernel_forward.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* Iterates on the accumulator and corresponding position on result matrix + +(1) Update `mi[r]` to the max value of the row `r` +(2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + +All of this is done on registers, before we store all of this +on shared memory for the next matmul with Value. + +We have multiple implementations, because each configuration has a different way +of iterating in the accumulators. 
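+
+A rough per-row sketch of the update above (illustration only; the kernel also
+folds in the softmax `scaling` factor and evaluates the exponentials in base 2
+via exp2f):
+
+  m_prime[r]  <- mi[r]                                   // save previous max
+  mi[r]       <- max(mi[r], max_n accum[r][n])           // (1)
+  accum[r][n] <- exp(accum[r][n] - mi[r])                // (2a)
+  m_prime[r]  <- exp(m_prime[r] - mi[r])                 // (2b)
+  s_prime[r]  <- s_prime[r] * m_prime[r] + sum_n accum[r][n]   // (2c)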
+*/ + +template +struct RegisterOps { + template + CUTLASS_DEVICE static void update( + typename T::Fragment& frag_o, // output so far + typename T::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename T::TensorCoord const& tile_offset, + float scaling) { + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? 
exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + BASE::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (BASE::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + addition_storage[accum_m + kQueriesPerBlock * + tile_offset.column()] = total_row; + } + }); + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + total_row = s_prime[id]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kWarpN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + } +}; + +template +struct AttentionScalingCoefsUpdaterSm80 + : RegisterOps, + T, + accum_t, + kWarpSize> { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + // printf("mma_m: %d, mma_n: %d, row: %d, col: %d, idx: %d\n", + // mma_m, mma_n, row, col, idx); + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AttentionScalingCoefsUpdaterVolta + : RegisterOps, + T, + accum_t, + kWarpSize> { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static 
int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert(cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AttentionScalingCoefsUpdaterSimt + : RegisterOps, + T, + accum_t, + kWarpSize> { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + 
static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = + n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert(cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultAttentionScalingCoefsUpdater; + +// Simt +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaSimtTileIterator, + accum_t, + kWarpSize> { + using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Updater = + AttentionScalingCoefsUpdaterSimt; +}; + +// TensorOp - Volta +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Updater = + AttentionScalingCoefsUpdaterVolta; +}; + +// TensorOp - Sm75+ +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Updater = + AttentionScalingCoefsUpdaterSm80; +}; diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/base_grouped.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/base_grouped.h new file mode 100644 index 0000000000000..8ce3e0dca66f6 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/base_grouped.h @@ -0,0 +1,483 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Base device-level grouped kernel. 
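+
+  A grouped GEMM launches a fixed number of threadblocks that walk a list of
+  independent GEMM problems through the kernel's ProblemVisitor, each block
+  repeatedly claiming the next output tile until every problem is covered.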
+*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class BaseGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = + typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count( + const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = + cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return 
Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, + int32_t tile_count, + void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, const std::vector& indices) { + // For now, simply create a copy of the data and then copy over to the + // original. + std::vector copy(indices.size()); + for (int i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseGrouped() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + return BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } else { + return 0; + } + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
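+  ///
+  /// Rough host-side usage sketch (array names are illustrative, not part of
+  /// this header):
+  ///
+  ///   BaseGrouped<Kernel>::sort_problems(problem_count,
+  ///                                      problem_sizes.data(),
+  ///                                      lda.data(), ldb.data(),
+  ///                                      ldc.data(), ldd.data(),
+  ///                                      offset_A.data(), offset_B.data(),
+  ///                                      offset_C.data(), offset_D.data());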
+ static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), + indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > + problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, + int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + // printf("custom base\n"); + static cudaDeviceProp properties; + static bool count = true; + if (count) { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; + } + + result = cudaGetDeviceProperties(&properties, device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error " + << cudaGetErrorString(result)); + return 0; + } + } + count = false; + + bool override_sm_count = + (available_sm_count < 0 || + available_sm_count > properties.multiProcessorCount); + if (override_sm_count) { + available_sm_count = properties.multiProcessorCount; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return + // total_tiles unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating + // through problem sizes to determine that they have no work to do. This + // competes for cycles with those threadblocks that are assigned tiles to + // compute. + return min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. 
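+  ///
+  /// `workspace` must hold at least get_workspace_size(args) bytes; if a
+  /// workspace is required and `workspace` is null,
+  /// Status::kErrorWorkspaceNull is returned.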
+ Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_.update(args, workspace, tile_count); + } else { + params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + // + // Configure grid and block dimensions + // + + if (!params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + dim3 grid(params_.threadblock_count, 1, 1); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + // + // Launch kernel + // + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Initializes and runs the kernel. 
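+  ///
+  /// Rough usage sketch (the grouped kernel type, `args`, and `stream` are
+  /// placeholders):
+  ///
+  ///   BaseGrouped<SomeGroupedKernel> gemm;
+  ///   size_t bytes = BaseGrouped<SomeGroupedKernel>::get_workspace_size(args);
+  ///   void* workspace = nullptr;
+  ///   cudaMalloc(&workspace, bytes);
+  ///   cutlass::Status status = gemm(args, workspace, stream);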
+ Status operator()(Arguments const& args, + void* workspace, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma.h new file mode 100644 index 0000000000000..eb3b47dc257cd --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma.h @@ -0,0 +1,100 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "../utils.h" +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min(Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage; +}; + +template +struct MakeCustomMma, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined; +}; diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_base.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000000..1eb41391f5e40 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_base.h @@ -0,0 +1,177 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { return TensorRef{buffer.data(), Layout()}; } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + using SharedStorageA = OperandSharedStorage; + using SharedStorageB = OperandSharedStorage; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000000..2a3ccb69351ed --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,740 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "../utils.h" +#include "custom_mma_base.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
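+/// This variant differs from cutlass::gemm::threadblock::MmaMultistage in two
+/// ways: the global->shared prologue can be issued ahead of time and flagged
+/// via set_prologue_done(), and when the whole K extent fits in shared memory
+/// (kMaxK <= Shape::kK * Stages) the mainloop skips further global->shared
+/// copies instead of cycling the circular buffer.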
+template <
+    /// Size of the Gemm problem - concept: gemm::GemmShape<>
+    typename Shape_,
+    /// Iterates over tiles of A operand in global memory
+    // (concept: ReadableTileIterator | ForwardTileIterator |
+    // MaskedTileIterator)
+    typename IteratorA_,
+    /// Iterates over tiles of A operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorA_,
+    /// Cache operation for operand A
+    cutlass::arch::CacheOperation::Kind CacheOpA,
+    /// Iterates over tiles of B operand in global memory
+    // (concept: ReadableTileIterator | ForwardTileIterator |
+    // MaskedTileIterator)
+    typename IteratorB_,
+    /// Iterates over tiles of B operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorB_,
+    /// Cache operation for operand B
+    cutlass::arch::CacheOperation::Kind CacheOpB,
+    /// Data type of accumulator matrix
+    typename ElementC_,
+    /// Data type of accumulator matrix
+    typename LayoutC_,
+    /// Policy describing tuning details (concept: MmaPolicy)
+    typename Policy_,
+    /// Number of stages,
+    int Stages,
+    /// Use zfill or predicate for out-of-bound cp.async
+    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
+    /// Upper bound on the K dimension
+    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
+    /// Used for partial specialization
+    typename Enable = bool>
+class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
+ public:
+  ///< Base class
+  using Base = CustomMmaBase<Shape_, Policy_, Stages>;
+  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+  using Shape = Shape_;
+  ///< Iterates over tiles of A operand in global memory
+  using IteratorA = IteratorA_;
+  ///< Iterates over tiles of B operand in global memory
+  using IteratorB = IteratorB_;
+  ///< Data type of accumulator matrix
+  using ElementC = ElementC_;
+  ///< Layout of accumulator matrix
+  using LayoutC = LayoutC_;
+  ///< Policy describing tuning details
+  using Policy = Policy_;
+
+  using SmemIteratorA = SmemIteratorA_;
+  using SmemIteratorB = SmemIteratorB_;
+
+  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+  //
+  // Dependent types
+  //
+
+  /// Fragment of accumulator tile
+  using FragmentC = typename Policy::Operator::FragmentC;
+
+  /// Warp-level Mma
+  using Operator = typename Policy::Operator;
+
+  /// Minimum architecture is Sm80 to support cp.async
+  using ArchTag = arch::Sm80;
+
+  /// Complex transform on A operand
+  static ComplexTransform const kTransformA = Operator::kTransformA;
+
+  /// Complex transform on B operand
+  static ComplexTransform const kTransformB = Operator::kTransformB;
+
+  /// Internal structure exposed for introspection.
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + 
int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { prologue_done_ = value; } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static 
void _prologue(IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue(iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue(iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
+ if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma(tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma(accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
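// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): a counter model of the cp.async
// commit/wait pairing used in this mainloop. cp_async_fence() closes the
// current group of in-flight copies, and cp_async_wait<N>() blocks until at
// most N groups remain outstanding (N = kStages - 2 here), which is what keeps
// the pipeline exactly kStages deep.
struct AsyncCopyGroupModel {
  int committed = 0;  // groups closed by fence()
  int retired = 0;    // groups drained by wait()
  void fence() { ++committed; }                    // cp_async_fence()
  void wait(int max_outstanding) {                 // cp_async_wait<max_outstanding>()
    while (committed - retired > max_outstanding) {
      ++retired;                                   // the oldest group completes
    }
  }
  int outstanding() const { return committed - retired; }
};
// ---------------------------------------------------------------------------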
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000000..05d5dd7458b4b --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,393 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
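// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): the circular shared-memory staging
// used by the multistage mainloop above and, with kStages == 2, by the
// double-buffered CustomMmaPipelined defined in this file. When a stage index
// reaches kStages - 1, the iterator is rewound with a negative tile offset so
// that its net offset stays inside the circular buffer.
struct StageCursor {
  int stage_idx;
  int tile_offset;  // what add_tile_offset() has accumulated, in stage-sized tiles
  void advance(int k_stages) {
    ++tile_offset;                       // add_tile_offset({0, 1}) / ({1, 0})
    if (stage_idx == k_stages - 1) {
      tile_offset -= k_stages;           // add_tile_offset({0, -kStages}) rewind
      stage_idx = 0;
    } else {
      ++stage_idx;
    }
  }
};
// Example: with kStages == 3 the write cursor starts at {kStages - 1, kStages - 1}
// (the prologue already advanced it), and advance() keeps tile_offset in [0, 3).
// ---------------------------------------------------------------------------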
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = + NumericArrayConverter, + /// + /// Transformation applied to B operand + typename TransformB_ = + NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // 
Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int 
thread_idx, + int problem_size_k) { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
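// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): the two-fragment ping-pong this
// mainloop is built on, reduced to a scalar dot product. While iteration k
// does math on frag[k % 2], the operands for iteration k + 1 are loaded into
// frag[(k + 1) % 2], so shared-memory loads and math overlap.
#include <cstddef>
#include <vector>

float pipelined_dot(const std::vector<float>& a, const std::vector<float>& b) {
  float accum = 0.0f;
  if (a.empty()) return accum;
  float frag_a[2] = {a[0], 0.0f};        // prologue: load fragment 0
  float frag_b[2] = {b[0], 0.0f};
  for (std::size_t k = 0; k < a.size(); ++k) {
    if (k + 1 < a.size()) {              // prefetch the next fragment into the other slot
      frag_a[(k + 1) % 2] = a[k + 1];
      frag_b[(k + 1) % 2] = b[k + 1];
    }
    accum += frag_a[k % 2] * frag_b[k % 2];   // "warp_mma" on the current slot
  }
  return accum;
}
// ---------------------------------------------------------------------------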
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma(accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/find_default_mma.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000000..bcfa875e487a0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/find_default_mma.h @@ -0,0 +1,171 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instanciate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. 
+ This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = + cutlass::gemm::threadblock::DefaultMma; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped.h new file mode 100644 index 0000000000000..138ac2b30e6e8 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped.h @@ -0,0 +1,974 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "../epilogue/epilogue_rescale_output.h" +#include "../gemm_kernel_utils.h" +#include "attention_scaling_coefs_updater.h" +#include "fmha_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHAGrouped { + public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = + !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + 
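// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): the streaming-softmax recurrence that
// the running statistics below (mi, m_prime, s_prime) and the epilogue rescale
// implement in warp-distributed form. For a single query row processed one key
// block at a time, previously accumulated output is rescaled by
// exp(m_old - m_new) whenever the running max grows, and the final result is
// divided by the running sum. Block size and names here are illustrative.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

float streaming_softmax_dot(const std::vector<float>& scores,   // q . k_j for one query row
                            const std::vector<float>& values,   // scalar stand-in for V rows
                            std::size_t keys_per_block) {
  float m = -INFINITY;  // running max          (m_prime / mi)
  float s = 0.0f;       // running exp-sum      (s_prime)
  float o = 0.0f;       // unnormalized output  (accum_o)
  for (std::size_t start = 0; start < scores.size(); start += keys_per_block) {
    const std::size_t end = std::min(start + keys_per_block, scores.size());
    float m_new = m;
    for (std::size_t i = start; i < end; ++i) m_new = std::max(m_new, scores[i]);
    const float correction = std::exp(m - m_new);  // rescale what was accumulated so far
    s *= correction;
    o *= correction;
    for (std::size_t i = start; i < end; ++i) {
      const float p = std::exp(scores[i] - m_new);
      s += p;
      o += p * values[i];
    }
    m = m_new;
  }
  return o / s;  // the epilogue's final 1 / s_prime normalization
}
// ---------------------------------------------------------------------------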
using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementM = scalar_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutM = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = + (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + static int64_t const kAlignmentM = kMaskIsAligned ? kAlignmentQ : 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + using ProblemVisitor = FMHAGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord *problem_sizes0; + GemmCoord *problem_sizes1; + + int problem_count; + int threadblock_count; + int num_heads; + + ElementQ *ptr_Q; + ElementK *ptr_K; + ElementP *ptr_P; + ElementM *ptr_M; + ElementV *ptr_V; + ElementO *ptr_O; + ElementOAccum *ptr_O_accum; + + typename LayoutQ::Stride::LongIndex ldq; + typename LayoutK::Stride::LongIndex ldk; + typename LayoutK::Stride::LongIndex ldm; + typename LayoutP::Stride::LongIndex ldv; + typename LayoutO::Stride::LongIndex ldo; + + int64_t kElementQ; + int64_t kElementK; + int64_t kElementM; + int64_t kElementV; + int64_t kElementO; + + // Scale + ElementAccumulator scale; + + // Whether causal masking is to be performed + bool causal; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + num_heads(0), + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_M(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(0), + ldk(0), + ldm(0), + ldv(0), + ldo(0), + scale(0), + kElementQ(0), + kElementK(0), + kElementM(0), + kElementV(0), + kElementO(0), + causal(false), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + int num_heads, + ElementQ *ptr_Q, + ElementK *ptr_K, + ElementM *ptr_M, + ElementV *ptr_V, + ElementO *ptr_O, + ElementOAccum *ptr_O_accum, + typename LayoutQ::Stride::LongIndex ldq, + typename LayoutK::Stride::LongIndex ldk, + typename LayoutM::Stride::LongIndex ldm, + typename LayoutV::Stride::LongIndex ldv, + typename LayoutO::Stride::LongIndex ldo, + int64_t kElementQ, + int64_t kElementK, + int64_t kElementM, + int64_t kElementV, + int64_t kElementO, + bool causal, + ElementAccumulator scale, + GemmCoord *host_problem_sizes = nullptr) + : problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + 
problem_count(problem_count), + threadblock_count(threadblock_count), + num_heads(num_heads), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_M(ptr_M), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum + : (accum_t *)ptr_O), // wip + ldq(ldq), + ldk(ldk), + ldm(ldm), + ldv(ldv), + ldo(ldo), + kElementQ(kElementQ), + kElementK(kElementK), + kElementM(kElementM), + kElementV(kElementV), + kElementO(kElementO), + causal(causal), + scale(scale), + host_problem_sizes(host_problem_sizes) {} + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + if (ptr_M != nullptr) { + CHECK_ALIGNED_PTR(ptr_M, kAlignmentM); + XFORMERS_CHECK(ldm % kAlignmentM == 0, + "attn_mask is not correctly aligned"); + } + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int num_heads; + + ElementQ *ptr_Q; + ElementK *ptr_K; + ElementP *ptr_P; + ElementM *ptr_M; + ElementV *ptr_V; + ElementO *ptr_O; + ElementOAccum *ptr_O_accum; + + typename LayoutQ::Stride::LongIndex ldq; + typename LayoutK::Stride::LongIndex ldk; + typename LayoutM::Stride::LongIndex ldm; + typename LayoutP::Stride::LongIndex ldv; + typename LayoutO::Stride::LongIndex ldo; + + int64_t kElementQ; + int64_t kElementK; + int64_t kElementM; + int64_t kElementV; + int64_t kElementO; + + ElementAccumulator scale; + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_M(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(0), + ldk(0), + ldm(0), + ldv(0), + ldo(0), + kElementQ(0), + kElementK(0), + kElementM(0), + kElementV(0), + kElementO(0), + causal(false), + scale(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, + tile_count), + threadblock_count(args.threadblock_count), + num_heads(args.num_heads), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_M(args.ptr_M), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum + : (accum_t *)args.ptr_O), + ldq(args.ldq), + ldk(args.ldk), + ldm(args.ldm), + ldv(args.ldv), + ldo(args.ldo), + kElementQ(args.kElementQ), + kElementK(args.kElementK), + kElementM(args.kElementM), + kElementV(args.kElementV), + kElementO(args.kElementO), + causal(args.causal), + scale(args.scale) {} + + // CUTLASS_HOST_DEVICE + void update(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, + tile_count); + threadblock_count = args.threadblock_count; + num_heads = args.num_heads; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_M = args.ptr_M; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer ? 
args.ptr_O_accum + : (accum_t *)args.ptr_O; + ldq = args.ldq; + ldk = args.ldk; + ldm = args.ldm; + ldv = args.ldv; + ldo = args.ldo; + kElementQ = args.kElementQ; + kElementK = args.kElementK; + kElementM = args.kElementM; + kElementV = args.kElementV; + kElementO = args.kElementO; + causal = args.causal; + scale = args.scale; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage & + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage & + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + private: + // Parameters to be used by an individual tile + struct TileParams { + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, + const GemmCoord &problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, + const GemmCoord &problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, + const GemmCoord &problem_size0, + bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min( + int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x; } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t lane_id() { + 
return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + auto &m_prime = shared_storage.m_prime; + auto &s_prime = shared_storage.s_prime; + [[maybe_unused]] auto &si = shared_storage.after_mm0.si; + auto &mi = shared_storage.mi; + + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = + int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + const int32_t batch_idx = problem_idx / params.num_heads; + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O + problem_idx * params.kElementQ + + TileParams::query_start(threadblock_idx) * params.ldo; + ElementOAccum *ptr_O_accum = + params.ptr_O_accum + problem_idx * params.kElementO + + TileParams::query_start(threadblock_idx) * params.ldo; + const int num_queries = + TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo}, + ptr_O, + typename OutputTileIterator::TensorCoord{num_queries, + problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = + [&](int col) -> typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{num_queries, + problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = + TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min((int32_t)kKeysPerBlock, + num_keys - iter_key_start); + int32_t const &problem_size_0_k = problem_size0.k(); + int32_t const &problem_size_1_n = problem_size1.n(); + int32_t const &problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)}, + params.ptr_V + problem_idx * params.kElementV + + iter_key_start * params.ldv, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue(shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and + // `m_prime` updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // Computes the block-matrix 
product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ *ptr_Q = params.ptr_Q + problem_idx * params.kElementQ + + TileParams::query_start(threadblock_idx) * params.ldq; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(params.ldq)), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk)), + params.ptr_K + problem_idx * params.kElementK + + iter_key_start * params.ldk, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = {(warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM)}; + + // apply attention mask if applicable + if (kAddMask) { + accum = cutlass::multiplies()( + params.scale, accum); + // load mask tile Bij into shared memory + typename MM0::MaskLoader::GmemTileIterator mask_iter( + {cutlass::layout::RowMajor(params.ldm)}, + // attn_mask_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + params.ptr_M + batch_idx * params.kElementM + + TileParams::query_start(threadblock_idx) * params.ldm + + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef + mask_tensor_ref( + shared_storage.after_mm0.mask.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::MaskLoader::SmemTileIterator smem_tile_iter( + mask_tensor_ref, thread_id()); + MM0::MaskLoader::load(mask_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += mask_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also updates `accum` with accum[i] <- + // 
exp(accum[i] * scale + // - mi) + MM0::ScalingCoefsUpdater::update< + kQueriesPerBlock, + MM0::MmaCore::WarpCount::kCount, + MM0::MmaCore::WarpCount::kN, + kFullColumns, + kIsFirst, + kKeepOutputInRF>( + accum_o, + accum, + mi, + m_prime, + s_prime, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iteratorC_tile_offset, + kAddMask ? 1.0f : params.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % (MM0::Mma::Base::WarpCount::kM * + MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = + cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kKeepOutputInRF ? 1 + : ceil_div((int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = (problem_size_1_k + MM1::Mma::Shape::kK - 1) / + MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in + // accum (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)}, + params.ptr_V + problem_idx * params.kElementV + + iter_key_start * params.ldv, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = + typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue:: + threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>:: + type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + 
DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1:: + OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = gemm_kernel_utils::call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + const bool kIsFirst = true; + const bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue::thread:: + MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped_problem_visitor.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped_problem_visitor.h new file mode 100644 index 0000000000000..663107d6821c0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/fmha_grouped_problem_visitor.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape( + const cutlass::gemm::GemmCoord &problem) { + // FMHA only partitions tiles across the M dimension. 
+ return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + 1, + 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord &problem) {} + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord &grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor + : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + using ProblemSizeHelper = + detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : problem_sizes0(nullptr), + problem_sizes1(nullptr), + problem_count(0), + workspace(nullptr), + tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0) + : problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams( // Set problem_sizes as problem_sizes1 because these + // determine shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor(Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx) + : Base(params_.to_base(), shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/gemm_grouped.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/gemm_grouped.h new file mode 100644 index 0000000000000..935d354a8b4d3 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/gemm_grouped.h @@ -0,0 +1,62 @@ 
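// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): the tile bookkeeping performed by the
// problem visitor above. Each problem is partitioned only along M, so its tile
// count is a ceil-division by the threadblock's M extent, and persistent
// threadblocks walk a flat (problem, tile) index space. kThreadblockM and the
// problem list here are illustrative.
#include <cstdint>
#include <vector>

struct TileRef {
  int32_t problem_idx;
  int32_t threadblock_idx;  // threadblock_idx * kQueriesPerBlock == query_start
};

inline int32_t tiles_for_problem(int32_t problem_m, int32_t kThreadblockM) {
  return (problem_m + kThreadblockM - 1) / kThreadblockM;  // grid_shape().m()
}

std::vector<TileRef> enumerate_tiles(const std::vector<int32_t>& problem_m,
                                     int32_t kThreadblockM) {
  std::vector<TileRef> tiles;
  for (int32_t p = 0; p < static_cast<int32_t>(problem_m.size()); ++p) {
    for (int32_t t = 0; t < tiles_for_problem(problem_m[p], kThreadblockM); ++t) {
      tiles.push_back({p, t});
    }
  }
  return tiles;  // tiles.size() matches the sum of per-problem tile counts
}
// ---------------------------------------------------------------------------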
+/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Device-level grouped GEMM. +*/ + +#pragma once + +#include "base_grouped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class GemmGrouped : public BaseGrouped { + public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000000..74a08930d671d --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,351 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
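// ---------------------------------------------------------------------------
// Standalone sketch (illustration only): a host model of the quad-wide
// reduction performed by the Sm80 reduceSameRow() helper defined in this file,
// using addition as the example reduction. On sm80 tensor-core accumulators,
// the four lanes of a quad (lane ids differing only in bits 0-1) hold elements
// of the same accumulator row, and a two-step XOR butterfly
// (__shfl_xor_sync with distances 1 and 2) leaves the row-wise result in every
// lane of the quad.
#include <array>

std::array<float, 32> quad_row_reduce(std::array<float, 32> lanes) {
  for (int dist : {1, 2}) {                           // __shfl_xor_sync(0xffffffff, v, dist)
    std::array<float, 32> shuffled{};
    for (int lane = 0; lane < 32; ++lane) shuffled[lane] = lanes[lane ^ dist];
    for (int lane = 0; lane < 32; ++lane) lanes[lane] += shuffled[lane];
  }
  return lanes;  // (lane_id & 3) == 0 is the designated writer, as in the device code
}
// ---------------------------------------------------------------------------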
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. +*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord 
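Note on reduceSameRow() in AccumLambdaIteratorSm80 above: it leans on the sm80 tensor-op accumulator layout, where the four lanes of a quad (lane_id >> 2) hold elements of the same accumulator row. Two xor-shuffles over lane bits 0 and 1 therefore reduce the whole row, and the lane with lane_in_quad == 0 ends up owning the result. A host-only C++ simulation of that butterfly (plain arrays stand in for __shfl_xor_sync, std::max stands in for the fn callback):

#include <algorithm>
#include <cstdio>

int main() {
  float value[32];
  for (int lane = 0; lane < 32; ++lane) value[lane] = static_cast<float>(lane % 7);

  // Step 1: otherV = __shfl_xor_sync(mask, myValue, 1); myValue = fn(myValue, otherV);
  float step1[32];
  for (int lane = 0; lane < 32; ++lane) step1[lane] = std::max(value[lane], value[lane ^ 1]);

  // Step 2: otherV = __shfl_xor_sync(mask, myValue, 2); myValue = fn(myValue, otherV);
  float step2[32];
  for (int lane = 0; lane < 32; ++lane) step2[lane] = std::max(step1[lane], step1[lane ^ 2]);

  for (int lane = 0; lane < 32; ++lane) {
    if ((lane & 3) == 0)  // lane_in_quad == 0 owns the reduced value, as in reduceSameRow()
      std::printf("quad %d reduced value = %g\n", lane >> 2, step2[lane]);
  }
  return 0;
}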
CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert(cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + 
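Note on the Sm70 path: it is harder to read than the Sm80 one because Volta's mma.sync.m8n8k4 interleaves quads across both rows and columns. Tabulating get_lane_offset() on the host makes the layout visible. This standalone C++ sketch reproduces what is assumed to be the float-accumulator branch, with the two constants taken from the iterator above (spelled without the "Patials" typo here):

#include <cstdio>

int main() {
  // Assumed constants from AccumLambdaIteratorSm70 above.
  const int kElementsPerPartial = 4;
  const int kAccumulatorPartials = 2;

  for (int lane = 0; lane < 32; ++lane) {
    int quad = lane >> 2;
    int lane_in_quad = lane & 3;
    // Row comes from quad bits 2 and 0 plus lane_in_quad bit 0,
    // column from quad bit 1 plus lane_in_quad bit 1.
    int accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
    int accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPartials +
                  (lane_in_quad & 2);
    std::printf("lane %2d -> base (row %2d, col %2d)\n", lane, accum_m, accum_n);
  }
  return 0;
}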
CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = + n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert(cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_from_smem.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000000..642eac0e7bffe --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/mma_from_smem.h @@ -0,0 +1,2007 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
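Note on DefaultMmaAccumLambdaIterator at the end of mma_accum_lambda_iterator.h: it is a default-traits dispatcher. The primary template is only declared, and each specialization pattern-matches on the warp accumulator tile iterator type (SIMT, Volta tensor op, sm75+ tensor op) to pick the matching lambda iterator. A stripped-down standalone C++ illustration of the pattern, using invented stand-in types and full specializations where the real code specializes partially over the iterator's shape and element parameters:

#include <cstdio>

// Stand-ins for the three warp accumulator tile iterator families.
struct SimtTileIterator {};
struct VoltaTensorOpTileIterator {};
struct TensorOpTileIterator {};

struct AccumLambdaIteratorSimt { static const char* name() { return "Simt"; } };
struct AccumLambdaIteratorSm70 { static const char* name() { return "Sm70"; } };
struct AccumLambdaIteratorSm80 { static const char* name() { return "Sm80"; } };

// Primary template: intentionally left undefined, so an unsupported iterator fails to compile.
template <typename WarpIterator>
struct DefaultAccumLambdaIterator;

// One specialization per supported accumulator layout.
template <>
struct DefaultAccumLambdaIterator<SimtTileIterator> { using Iterator = AccumLambdaIteratorSimt; };
template <>
struct DefaultAccumLambdaIterator<VoltaTensorOpTileIterator> { using Iterator = AccumLambdaIteratorSm70; };
template <>
struct DefaultAccumLambdaIterator<TensorOpTileIterator> { using Iterator = AccumLambdaIteratorSm80; };

int main() {
  using Chosen = DefaultAccumLambdaIterator<VoltaTensorOpTileIterator>::Iterator;
  std::printf("selected iterator: %s\n", Chosen::name());  // prints "Sm70"
  return 0;
}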
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "../utils.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass::MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator 
+ cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum value for K + int kMaxK, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID 
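Note on MmaBaseFromSharedMemory above: two compile-time quantities drive the scheduling that follows. kWarpGemmIterations is how many warp-level MMAs cover one threadblock K slice, and kSmemContainsEntireB says whether the staged shared-memory buffer already holds the full K extent of B, in which case the circular-buffer bookkeeping can be skipped. A constexpr restatement in standalone C++ with assumed example shapes; the real values come from Shape_, Policy::Operator and kMaxK:

// Assumed example values.
constexpr int kThreadblockK = 32;   // Shape::kK
constexpr int kWarpGemmK = 32;      // WarpGemm::kK
constexpr int kMmaShapeK = 16;      // Operator::Policy::MmaShape::kK
constexpr int kStages = 3;
constexpr int kMaxK = 64;           // largest K this mma will ever see

constexpr int kWarpGemmIterations = kWarpGemmK / kMmaShapeK;               // 2
constexpr bool kSmemContainsEntireB = (kMaxK <= kThreadblockK * kStages);  // 64 <= 96 -> true

static_assert(kWarpGemmIterations == 2, "two warp-level MMAs per K slice");
static_assert(kSmemContainsEntireB, "B fits entirely in the staged smem buffer");

int main() { return 0; }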
within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // using size 1 is kind of a hack to get around arrays of zero-sized objects + // not being allowed. the compiler is probably smart enough to wipe it out + // anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { return *this; } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = + cutlass::NumericArrayConverter()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { return frag; } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
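Note on NoOpWarpIteratorScale and FragmentElementwiseScaler above: they implement optional operand-A scaling purely at compile time. When scaling is disabled the scale fragment collapses to a one-element dummy and apply() returns its input untouched, so the whole path folds away in the optimizer. A condensed standalone C++ version of the idiom, with std::array in place of cutlass::Array and a plain element-wise loop in place of NumericArrayConverter plus multiplies:

#include <array>
#include <cstdio>

template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
struct ElementwiseScaler;

// Scaling enabled: multiply element-wise by the scale fragment.
template <typename Fragment, typename FragmentScale>
struct ElementwiseScaler<Fragment, FragmentScale, true> {
  static Fragment apply(Fragment frag, const FragmentScale& scale) {
    for (std::size_t i = 0; i < frag.size(); ++i) frag[i] *= scale[i];
    return frag;
  }
};

// Scaling disabled: a no-op the compiler can delete entirely.
template <typename Fragment, typename FragmentScale>
struct ElementwiseScaler<Fragment, FragmentScale, false> {
  static Fragment apply(Fragment frag, const FragmentScale&) { return frag; }
};

int main() {
  std::array<float, 4> frag{1.f, 2.f, 3.f, 4.f};
  std::array<float, 4> scale{2.f, 2.f, 2.f, 2.f};
  auto scaled = ElementwiseScaler<decltype(frag), decltype(scale), true>::apply(frag, scale);
  auto same = ElementwiseScaler<decltype(frag), decltype(scale), false>::apply(frag, scale);
  std::printf("scaled[0]=%g  unscaled[0]=%g\n", scaled[0], same[0]);  // 2 and 1
  return 0;
}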
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + // END smem + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = + NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory + : public MmaBaseFromSharedMemory { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
+ using FragmentAScaler = FragmentElementwiseScaler; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + 
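Note on the constructors above and below: they decompose the linear warp_idx into (m, n, k) coordinates with the same mod/div arithmetic before offsetting the warp tile iterators. Printing the numbers for one concrete arrangement makes the add_tile_offset lines easier to follow; this host sketch assumes a 2x2x2 warp arrangement and two warp GEMM iterations:

#include <cstdio>

int main() {
  // Assumed warp arrangement within the threadblock: WarpCount::{kM, kN, kK}.
  const int kWarpCountM = 2, kWarpCountN = 2, kWarpCountK = 2;
  const int kWarpGemmIterations = 2;

  for (int warp_idx = 0; warp_idx < kWarpCountM * kWarpCountN * kWarpCountK; ++warp_idx) {
    int warp_idx_mn = warp_idx % (kWarpCountM * kWarpCountN);
    int warp_idx_k = warp_idx / (kWarpCountM * kWarpCountN);
    int warp_idx_m = warp_idx_mn % kWarpCountM;
    int warp_idx_n = warp_idx_mn / kWarpCountM;
    // The A iterator is offset by {m, kWarpGemmIterations * k}, the B iterator by
    // {kWarpGemmIterations * k, n}, exactly as in the constructors.
    std::printf("warp %d -> m=%d n=%d k=%d, A offset {%d,%d}, B offset {%d,%d}\n",
                warp_idx, warp_idx_m, warp_idx_n, warp_idx_k,
                warp_idx_m, kWarpGemmIterations * warp_idx_k,
                kWarpGemmIterations * warp_idx_k, warp_idx_n);
  }
  return 0;
}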
this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async tranfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma(accum, + FragmentAScaler::apply(warp_frag_A[warp_mma_k % 2], + warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
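Note on the pipelined mainloop that just closed: it is a two-stage software pipeline. Register fragments live in two buffers indexed by warp_mma_k % 2, the next tile's loads are issued while the current tile's MMAs run, and smem_write_stage_idx ^= 1 ping-pongs between the two shared-memory stages. A host-side C++ skeleton of just the ordering, with no real loads or MMAs and an assumed iteration count:

#include <cstdio>

int main() {
  const int kWarpGemmIterations = 4;  // assumed; the real loop requires more than one iteration
  int smem_write_stage_idx = 1;

  for (int gemm_k_iterations = 3; gemm_k_iterations > 0; --gemm_k_iterations) {
    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
      bool last_group = (warp_mma_k == kWarpGemmIterations - 1);
      if (last_group) {
        // The staged global->shared copy of B is committed here and the write
        // stage ping-pongs, mirroring `smem_write_stage_idx ^= 1` above.
        smem_write_stage_idx ^= 1;
      }
      bool has_next = !last_group || gemm_k_iterations > 1;
      int compute_buf = warp_mma_k % 2;     // fragments consumed by warp_mma
      int load_buf = (warp_mma_k + 1) % 2;  // fragments being refilled meanwhile
      std::printf("k-iter %d, warp_mma_k %d: compute buf %d, %s buf %d, write stage %d\n",
                  gemm_k_iterations, warp_mma_k, compute_buf,
                  has_next ? "load into" : "skip load of", load_buf,
                  smem_write_stage_idx);
    }
  }
  return 0;
}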
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = + WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLDGSTSIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + 
warp_tile_iterator_A1_(accumulator_shared_storage.accum_ref(), + lane_idx), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { prologue_done_ = value; } + + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue(iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue(IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + 
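Note on Detail::kAccessesPerGroupB1 used by copy_tiles_and_advance_1 above: it is the per-warp-MMA slice of one stage's cp.async traffic. The TBLDGSTSIterationsB1 copies needed for a full stage of B are split evenly (rounding up) across the warp GEMM iterations, so loads overlap with math instead of bunching up. A standalone C++ sketch of that split with assumed counts; the guard mirrors the bounds check in copy_tiles_and_advance_1:

#include <cstdio>

int main() {
  // Assumed example values; the real ones come from IteratorB1::ThreadMap and the warp shape.
  const int kTBLoadIterationsB1 = 8;   // cp.async copies needed for one full stage of B
  const int kWarpGemmIterations1 = 3;  // warp-level MMAs per threadblock K slice

  const int kAccessesPerGroupB1 =
      (kTBLoadIterationsB1 + kWarpGemmIterations1 - 1) / kWarpGemmIterations1;  // 3

  int issued = 0;
  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations1; ++warp_mma_k) {
    int group_start = warp_mma_k * kAccessesPerGroupB1;
    int in_this_group = 0;
    // The last group may issue fewer copies, matching the `group_start_B1 + j < ...` guard.
    for (int j = 0; j < kAccessesPerGroupB1 && group_start + j < kTBLoadIterationsB1; ++j)
      ++in_this_group;
    issued += in_this_group;
    std::printf("warp_mma_k %d issues %d copies\n", warp_mma_k, in_this_group);  // 3, 3, 2
  }
  std::printf("total copies issued: %d\n", issued);  // 8
  return 0;
}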
smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
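Note on the DEPBAR+SYNC wait just above the fragment loads: it relies on cp.async commit-group semantics. cp_async_fence() closes the group of copies issued so far, and the wait blocks until at most N committed groups remain in flight, which here means the oldest prologue stage has landed in shared memory. A toy host model of those semantics only (not the PTX); the wait depth of kNumStagesConcurrentLoad - 1 is assumed from the surrounding constants:

#include <cstdio>
#include <deque>

// Toy model of cp.async commit groups: fence() closes the group of copies issued so
// far; wait_for(n) blocks until at most n committed groups are still in flight.
struct AsyncCopyModel {
  std::deque<int> in_flight;  // committed but not yet completed groups
  int next_group = 0;

  void fence() { in_flight.push_back(next_group++); }

  void wait_for(int n) {
    while (static_cast<int>(in_flight.size()) > n) {
      std::printf("completing copy group %d\n", in_flight.front());
      in_flight.pop_front();
    }
  }
};

int main() {
  const int kStages = 3;
  const int kNumStagesConcurrentLoad = kStages - 1;  // as when kSmemContainsEntireB is false

  AsyncCopyModel copies;
  // Prologue: issue kNumStagesConcurrentLoad full stages, one fence per stage.
  for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage) copies.fence();

  // Before the first MMA, wait until at most (kNumStagesConcurrentLoad - 1) groups
  // are outstanding, i.e. the oldest stage is resident in shared memory.
  copies.wait_for(kNumStagesConcurrentLoad - 1);
  std::printf("groups still in flight: %zu\n", copies.in_flight.size());  // 1
  return 0;
}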
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply(warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma1(tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1(accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +template +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory, + 
RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + using RegularMma = MmaPipelined; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + 
cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + + using RegularMma = MmaMultistage; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = + typename platform::conditional::type; + + static int constexpr kMaxK = kIsTransposedA + ? AccumulatorSharedStorage_::Shape::kM + : AccumulatorSharedStorage_::Shape::kN; + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) 
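Note on the multistage DefaultMmaFromSharedMemory specialization above: it clamps the requested stage count. When the accumulator kept in shared memory only spans kMaxK columns, pipelining more than ceil(kMaxK / Shape::kK) stages buys nothing, so kStages = min(Stages, kStagesMax). A constexpr restatement in standalone C++ with assumed numbers:

constexpr int const_min_int(int a, int b) { return a < b ? a : b; }

// Assumed example: the first GEMM's accumulator is 64 wide, the second GEMM's
// threadblock K slice is 32, and 4 stages were requested.
constexpr int kMaxK = 64;          // AccumulatorSharedStorage::Shape::kN (or kM if transposed)
constexpr int kThreadblockK = 32;  // Shape_::kK
constexpr int kRequestedStages = 4;

constexpr int kStagesMax = (kMaxK + kThreadblockK - 1) / kThreadblockK;  // 2
constexpr int kStages = const_min_int(kRequestedStages, kStagesMax);     // 2

static_assert(kStages == 2, "only two stages are ever useful here");

int main() { return 0; }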
+template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? 
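// Note: "applying LSE" here means recomputing the softmax probabilities from
// the raw attention logits using the stored row-wise log-sum-exp,
//   P_ij = exp(S_ij - LSE_i),
// which is the form the backward pass needs. The iterator defined next
// streams the per-row LSE values so the epilogue can perform this
// subtraction and exponentiation while writing the tile to shared memory.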
+ using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE + accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE + accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue(minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< + WarpShape, + cutlass::gemm::GemmShape<32, 32, 4>, + scalar_t, + SmemAccumulatorLayout>; + + // // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + 16, + 32>, // typename SmemIteratorD0::TensorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + + using OutputLayout = + 
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE + accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset(tile_coords * + cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast(ref_.data() + + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE + accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way 
to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = + accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaSimtTileIterator, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE + accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset(tile_coords * + cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = + n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE + accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& 
tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = + accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm_kernel_utils.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000..72744b7f2b0cb --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm_kernel_utils.h @@ -0,0 +1,234 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cutlass/arch/mma.h" + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, \ + "Your device is too old. 
We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.stride(-1) == 1, \ + #TENSOR ": last dimension must be contiguous"); + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + XFORMERS_CHECK(uint64_t(PTR) % ALIGNMENT == 0, \ + #PTR " is not correctly aligned") +#define XFORMERS_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << #COND " failed\n"; \ + return false; \ + } +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK(B < std::numeric_limits::max(), \ + #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
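// Illustrative sketch (not from the original sources): how a dispatch macro
// in the style of DISPATCH_ARCHTAG above is typically consumed. The macro
// declares a local `ArchTag` alias and then invokes the callback, so callers
// pass a parenthesized lambda that instantiates architecture-specific
// templates from that alias. `dispatch_example` is a hypothetical name.
inline bool dispatch_example(int compute_capability) {
  DISPATCH_ARCHTAG(compute_capability, ([&] {
                     // Here `ArchTag` names the selected architecture, e.g.
                     // cutlass::arch::Sm80; a real caller would instantiate
                     // its kernel template with it.
                     static_assert(sizeof(ArchTag) > 0, "ArchTag is visible");
                   }));
  return true;
}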
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { + return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/generate_variable_forward_kernels.py b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/generate_variable_forward_kernels.py new file mode 100644 index 0000000000000..8b5e6c3355265 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/generate_variable_forward_kernels.py @@ -0,0 +1,577 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# Generates combination of kernels - implementations and registry + +# Kernels are ordered (see `sort_index`), and when dispatching, +# we select the first kernel in the list that supports the inputs + +import argparse +import collections +import itertools +import os +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple, TypeVar + +DEFAULT_ARCH = [50, 70, 75, 80] +MAX_ARCH = 90 +ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION" + +assert sorted(DEFAULT_ARCH) == DEFAULT_ARCH + + +def find_arch_range(min_arch, max_arch): + assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH + assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH + assert min_arch <= max_arch + n = len(DEFAULT_ARCH) + + start_idx = n - 1 + for i in range(n - 1): + if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]: + start_idx = i + break + + end_idx = n + for i in range(n - 1): + if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]: + end_idx = i + 1 + + return DEFAULT_ARCH[start_idx:end_idx] + + +def find_max_arch(arch): + arch = sorted(arch) + idx = DEFAULT_ARCH.index(arch[-1]) + if idx == len(DEFAULT_ARCH) - 1: + return MAX_ARCH + else: + return DEFAULT_ARCH[idx + 1] + + +def convert_to_arch_list(arch): + arch = arch.lower().strip() + if arch == "all": + return DEFAULT_ARCH + + arch = [int(s.strip()) for s in arch.split(';') if s.strip()] + arch = list(set(arch)) + arch.sort() + return find_arch_range(arch[0], arch[-1]) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The argument for generating the memory efficient kernels." 
+ ) + parser.add_argument( + "--dst_path", + type=str, + default=str(Path(__file__).parent), + help="The destination path to save the generated files.", + ) + parser.add_argument( + "--cuda_arch", + type=convert_to_arch_list, + default=convert_to_arch_list("All"), + help="The CUDA architecture to be generated.", + ) + args = parser.parse_args() + args.max_arch = find_max_arch(args.cuda_arch) + return args + + +args = parse_args() + +DTYPES = { + "f32": "float", + "f16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = args.cuda_arch + +KERNEL_IMPL_TEMPLATE = """ + +void {NAME}({CPP_CLASS} default_fmha, Params ¶ms, const phi::GPUContext& ctx) {{ + using AttentionKernel = typename decltype(default_fmha)::FMHAKernel; + using FMHA = cutlass::gemm::device::GemmGrouped; + using scalar_t = typename FMHA::GemmKernel::scalar_t; + using accum_t = typename FMHA::GemmKernel::accum_t; + using output_t = typename FMHA::GemmKernel::output_t; + using output_accum_t = typename FMHA::GemmKernel::output_accum_t; + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementM = scalar_t; + using ElementAccumulator = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + + int problem_count = params.num_batches * params.num_heads; + + std::vector problem_sizes1; + problem_sizes1.reserve(problem_count); + + phi::Allocator::AllocationPtr problem_sizes_device0{{nullptr}}; + phi::Allocator::AllocationPtr problem_sizes_device1{{nullptr}}; + problem_sizes_device0 = phi::memory_utils::Alloc( + ctx.GetPlace(), + problem_count * sizeof(GemmCoord), + phi::Stream(reinterpret_cast(ctx.stream()))); + problem_sizes_device1 = phi::memory_utils::Alloc( + ctx.GetPlace(), + problem_count * sizeof(GemmCoord), + phi::Stream(reinterpret_cast(ctx.stream()))); + GemmCoord* problem0_device = + reinterpret_cast(problem_sizes_device0->ptr()); + GemmCoord* problem1_device = + reinterpret_cast(problem_sizes_device1->ptr()); + get_problem_sizes<<>>( + params.seq_lens, + problem0_device, + problem1_device, + params.num_batches, + params.num_heads, + params.key_value_seq_len, + params.head_size, + params.value_head_size, + params.prompt_num); + phi::memory_utils::Copy(phi::CPUPlace(), + problem_sizes1.data(), + ctx.GetPlace(), + problem1_device, + sizeof(GemmCoord) * problem_count, + ctx.stream()); + if (AttentionKernel::kNeedsOutputAccumulatorBuffer) {{ + const int64_t output_size = params.num_batches * params.num_heads * + params.query_seq_len * params.value_head_size; + phi::Allocator::AllocationPtr tmp_output_accum_buffer_ptr{{nullptr}}; + tmp_output_accum_buffer_ptr = phi::memory_utils::Alloc( + ctx.GetPlace(), + output_size * sizeof(ElementOAccum), + phi::Stream(reinterpret_cast(ctx.stream()))); + params.output_accum_ptr = tmp_output_accum_buffer_ptr->ptr(); + }} + int threadblock_count = + FMHA::sufficient(problem_sizes1.data(), problem_count); + typename FMHA::Arguments args( + problem0_device, + problem1_device, + problem_count, + threadblock_count, + params.num_heads, + const_cast(reinterpret_cast(params.query_ptr)), + const_cast(reinterpret_cast(params.key_ptr)), + params.mask_ptr + ? const_cast(reinterpret_cast(params.mask_ptr)) + : nullptr, + const_cast(reinterpret_cast(params.value_ptr)), + reinterpret_cast(params.output_ptr), + AttentionKernel::kNeedsOutputAccumulatorBuffer + ? 
reinterpret_cast(params.output_accum_ptr) + : nullptr, + params.ldq, + params.ldk, + params.ldm, + params.ldv, + params.ldo, + params.ElementQ, + params.ElementK, + params.ElementM, + params.ElementV, + params.ElementO, + params.causal, + params.scale, + problem_sizes1.data()); + + FMHA fmha; + cutlass::Status status; + size_t workspace_size = fmha.get_workspace_size(args); + phi::DenseTensor workspace; + workspace.Resize(phi::make_ddim({{static_cast(workspace_size)}})); + ctx.template Alloc(&workspace); + status = fmha.initialize(args, workspace.data()); + if (status != cutlass::Status::kSuccess) {{ + PADDLE_THROW(phi::errors::Unimplemented( + "Failed to initialize CUTLASS Grouped FMHA kernel.")); + }} + status = fmha.run(ctx.stream()); + if (status != cutlass::Status::kSuccess) {{ + PADDLE_THROW(phi::errors::Unimplemented( + "Failed to run CUTLASS Grouped FMHA kernel.")); + }} +}} +""" + + +@dataclass(order=True) +class FwdKernel: + sort_index: Tuple[int, ...] = field(init=False, repr=False) + aligned: bool + mask_aligned: bool + dtype: str + sm_range: Tuple[int, int] + q: int + k: int + single_value_iter: bool + support_mask: bool = True + mask_broadcast: bool = False + dispatch_cond: Optional[str] = None + + def __post_init__(self) -> None: + # Set kernel selection priority + # The lowest value that matches inputs + # will be selected + self.sort_index = ( + # First select aligned kernel + 0 if self.aligned else 1, + 0 if self.support_mask else 1, + # Then keep output in RF + 0 if self.single_value_iter else 1, + self.q, + 0 if self.mask_aligned else 1, + 0 if self.mask_broadcast else 1, + ) + + @property + def _aligned_suffix(self) -> str: + return "aligned" if self.aligned else "notaligned" + + @property + def _mask_aligned_suffix(self) -> str: + return "ma" if self.mask_aligned else "mua" + + @property + def _mask_support_suffix(self) -> str: + return "sm" if self.support_mask else "usm" + + @property + def _mask_broadcast_suffix(self) -> str: + return "mb" if self.mask_broadcast else "mnb" + + @property + def _single_value_suffix(self) -> str: + return "rf" if self.single_value_iter else "urf" + + @property + def name(self) -> str: + return f"fmha_cutlassF_variable_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{self._single_value_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_sm{self.sm_range[0]}" + + @property + def cpp_class(self) -> str: + template_args = ", ".join( + [ + DTYPES[self.dtype], + f"cutlass::arch::Sm{self.sm_range[0]}", + "true" if self.aligned else "false", + "true" if self.mask_aligned else "false", + str(self.q), + str(self.k), + "true" if self.single_value_iter else "false", + "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", + "true" if self.support_mask else "false", + "false", + ] + ) + return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>" + + @property + def impl_group(self) -> str: + # Maps to file which will contain the implementation + return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}" + + @property + def cpp_impl(self) -> str: + return KERNEL_IMPL_TEMPLATE.format( + CPP_CLASS=self.cpp_class, NAME=self.name + ) + + @classmethod + def get_all(cls) -> List["FwdKernel"]: + kernels: List[FwdKernel] = [] + for aligned, dtype, (sm, sm_max) in itertools.product( + [True, False], DTYPES.keys(), zip(SM, SM[1:] + [args.max_arch]) + ): + # Remove some kernels we don't use + if dtype == "bf16" 
and sm < 80: + continue + if not aligned and sm >= 80: + continue + for q, k, single_value_iter in [ + (32, 128, True), + (32, 128, False), + (64, 64, True), + ]: + for support_mask, mask_aligned in [ + (False, False), + (True, False), + (True, True), + ]: + kernels.append( + cls( + aligned=aligned, + dtype=dtype, + sm_range=(sm, sm_max), + q=q, + k=k, + single_value_iter=single_value_iter, + support_mask=support_mask, + mask_aligned=mask_aligned, + mask_broadcast=False, + ) + ) + return kernels + + +T = TypeVar("T", bound=FwdKernel) + + +def write_decl_impl( + kernels: List[T], family_name: str, impl_file: str, enable_def: str +) -> None: + cpp_file_header = """// This file is auto-generated. See "generate_variable_forward_kernels.py" +""" + + kernels.sort() + + implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list) + cat_to_kernels: Dict[ + Tuple[str, int, int], List[T] + ] = collections.defaultdict(list) + + dispatch_all = "" + declarations = cpp_file_header + "#pragma once\n" + declarations += f"#ifdef {enable_def}\n" + declarations += f"""#include "{impl_file}"\n""" + declarations += "namespace phi {\n" + + # Declaration of kernel functions + for k in kernels: + implfile_to_kernels[k.impl_group].append(k) + cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k) + + for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items(): + declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n" + declarations += "\n".join( + k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels + ) + dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}" + declarations += ( + f"\n\ntemplate void {dispatch_category_fn}(T cb) {{\n" + ) + for k in kernels: + _call = f"cb({k.cpp_class}(), {k.name});\n" + if k.dispatch_cond is not None: + _call = f"if ({k.dispatch_cond}) {_call}" + declarations += f" {_call}" + declarations += "}\n\n" + dispatch_all += f""" + if (std::is_same::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{ + {dispatch_category_fn}(cb); + }}""" + + declarations += f""" +template +void dispatch_{family_name}(const ::phi::GPUContext &ctx, T cb) {{ + auto cc = ctx.GetComputeCapability(); + PADDLE_ENFORCE_GE( + cc, + 70, + phi::errors::InvalidArgument("the Nvidia GPU's Compute Capability must be greater or equal than 70")); + + using DT = typename ::phi::CutlassTrait::Type; +{dispatch_all} +}} +""" + declarations += "} // namespace phi\n" + declarations += f"#endif // {enable_def}\n" + + autogen_dir = Path(args.dst_path) / "autogen" + os.makedirs(autogen_dir, exist_ok=True) + declaration_path = autogen_dir / f"{family_name}.h" + declaration_path.write_text(declarations) + + for f, f_kernels in implfile_to_kernels.items(): + impl_cu = cpp_file_header + impl_cu += f"#ifdef {enable_def}\n" + impl_cu += f"""#include "{impl_file}"\n""" + impl_cu += "namespace phi {\n" + for k in f_kernels: + impl_cu += k.cpp_impl + impl_cu += "} // namespace phi\n" + impl_cu += f"#endif // {enable_def}\n" + impl_path = autogen_dir / "impl" + os.makedirs(impl_path, exist_ok=True) + (impl_path / f"{family_name}_{f}.cu").write_text(impl_cu) + + +def write_main_header(): + main_header_content = ''' +#pragma once + +#ifdef {} + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "cutlass/util/device_memory.h" 
+#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/default_fmha_grouped.h" +#include "paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/gemm/gemm_grouped.h" + +namespace phi {{ + +using GemmCoord = cutlass::gemm::GemmCoord; + +struct Params {{ + // meta params + phi::DataType datatype; + + // [bs, nh, seq_len, dh] + const void* query_ptr; + const void* key_ptr; + const void* value_ptr; + + // and it can be broadcasted in axis0, 1, 2. + const void* mask_ptr = nullptr; + + const int* seq_lens = nullptr; + + // Output tensors + void* output_ptr; // [num_batches, num_heads, query_seq_len, head_size] + void* output_accum_ptr = + nullptr; // [num_batches, num_heads, query_seq_len, head_size] + + // Scale + float scale; + + // Dimensions/strides + int32_t num_batches; + int32_t num_heads; + int32_t query_seq_len; + int32_t key_value_seq_len; + int32_t head_size; + int32_t value_head_size; + + int64_t ldq; + int64_t ldk; + int64_t ldm; + int64_t ldv; + int64_t ldo; + + int64_t ElementQ; + int64_t ElementK; + int64_t ElementM; + int64_t ElementV; + int64_t ElementO; + + int prompt_num = 0; + + bool causal; + bool mask_broadcast_row; +}}; + +__global__ static void get_problem_sizes(const int* seq_lens, + GemmCoord* problem_sizes0, + GemmCoord* problem_sizes1, + const int bs, + const int num_head, + const int kv_seq_len, + const int head_size, + const int value_head_size, + const int prompt_num) {{ + int bi = blockIdx.x; + int hi = threadIdx.x; + if (bi < bs && hi < num_head) {{ + int id = bi * num_head + hi; + int m = seq_lens[bi]; + int mkv = m + prompt_num; + int k0 = head_size; + int k1 = value_head_size; + GemmCoord problem0(m, mkv, k0); + GemmCoord problem1(m, k1, mkv); + problem_sizes0[id] = problem0; + problem_sizes1[id] = problem1; + }} +}} + +template +struct CutlassTrait {{ + using Type = T; +}}; + +template <> +struct CutlassTrait {{ + using Type = cutlass::half_t; +}}; + +template <> +struct CutlassTrait {{ + using Type = cutlass::bfloat16_t; +}}; + + +template +struct ToPhiDTypeTrait {{ + private: + using NonConstT = typename std::remove_const::type; + static constexpr bool kIsFP16 = std::is_same::value; + static constexpr bool kIsBF16 = std::is_same::value; + + public: + using Type = typename std::conditional::type>::type; +}}; + +}} // namespace phi + +#include "./cutlass_forward.h" + +#endif +'''.format( + ENABLE_MACRO + ) + + path = Path(args.dst_path) / "autogen" + os.makedirs(path, exist_ok=True) + path = Path(path) / "memory_efficient_variable_attention.h" + path.write_text(main_header_content) + + +if os.path.exists(Path(args.dst_path) / "autogen"): + shutil.rmtree(Path(args.dst_path) / "autogen") +forward_impl = "paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/autogen/memory_efficient_variable_attention.h" + +write_main_header() + +write_decl_impl( + FwdKernel.get_all(), + "cutlass_forward", + impl_file=forward_impl, + enable_def=ENABLE_MACRO, +) diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000..56c01e0961a7c --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,713 @@ +/*************************************************************************************************** + * Copyright (c) 
2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue iterator that supports prefetching + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = + Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)(( + void*)&memory_pointer[column * 
ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int 
cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses 
guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/make_residual_last.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000..b4967c01ca194 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,74 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast; +}; + +template +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000..545d9b45f51e0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,1951 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. 
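The "compute the predicates twice" idea from the file comment above can be modelled outside of CUTLASS in a few lines: one guard mask is built for the partial residual tile, one for the full steady-state tiles, and the steady-state mask is reused for every remaining tile. The sketch below is host-only C++ with made-up names and sizes (make_guard_mask, an 8-element tile), intended only to illustrate the masking scheme and not taken from the patch.

```cpp
// Host-side model of the "compute predicates twice" scheme described above.
#include <cstdint>
#include <cstdio>

// Build a guard bit-mask for one tile: bit i is set when element i of the
// tile lies inside the valid extent.
static std::uint32_t make_guard_mask(int valid_elements, int tile_size) {
  std::uint32_t mask = 0;
  for (int i = 0; i < tile_size; ++i) {
    if (i < valid_elements) mask |= (1u << i);
  }
  return mask;
}

int main() {
  const int extent = 21;               // logical extent along the advance dimension
  const int tile = 8;                  // tile size
  const int residual = extent % tile;  // 5 -> the first visited tile is partial

  // Predicates are computed exactly twice: once for the partial residual
  // tile and once for the full steady-state tiles, which all share one mask.
  std::uint32_t residual_mask = make_guard_mask(residual ? residual : tile, tile);
  std::uint32_t steady_mask = make_guard_mask(tile, tile);

  int offset = 0;
  const int num_tiles = (extent + tile - 1) / tile;
  for (int t = 0; t < num_tiles; ++t) {
    bool is_residual = (t == 0) && (residual != 0);
    std::uint32_t mask = is_residual ? residual_mask : steady_mask;
    std::printf("tile %d at offset %d guard mask 0x%02x\n",
                t, offset, static_cast<unsigned>(mask));
    offset += is_residual ? residual : tile;
  }
  return 0;
}
```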
+/// +template +class PredicatedTileAccessIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = + PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the " + "access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params const& params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset seperated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. 
+ // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value / 8) + + the_predicates.iteration_vector_; + int strided_index = + gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * + sizeof_bits::value / 8; + + return reinterpret_cast(pointer_ + contiguous_offset + + strided_offset); + } + + return reinterpret_cast(pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. 
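In `get()` above, pointer arithmetic is carried in bytes (`sizeof_bits<Element>::value * offset / 8`), and when `Gather` is enabled the strided component of the address is looked up through `indices[]` instead of being advanced linearly. A stand-alone sketch of that address computation, with hypothetical helper and parameter names and the per-access vector index omitted:

```cpp
#include <cstdint>

// Model of the gather addressing used by get(): the contiguous part of the
// address is computed directly, while the strided part is redirected through
// an index array. `element_bits` stands in for sizeof_bits<Element>::value;
// every name here is illustrative.
static const char* gather_address(const char* base,
                                  const int* indices,      // gather indices
                                  int iteration_contiguous,
                                  int delta_contiguous,    // ThreadMap::Delta::kContiguous
                                  int strided_index,       // gathered row to visit
                                  std::int64_t stride,     // elements per strided step
                                  int element_bits) {
  std::int64_t contiguous_offset =
      std::int64_t(iteration_contiguous) * delta_contiguous * element_bits / 8;
  std::int64_t strided_offset =
      std::int64_t(indices[strided_index]) * stride * element_bits / 8;
  return base + contiguous_offset + strided_offset;
}

int main() {
  float data[64] = {};
  int indices[4] = {7, 2, 5, 0};  // rows are visited in gathered order
  const char* p = gather_address(reinterpret_cast<const char*>(data),
                                 indices,
                                 /*iteration_contiguous=*/1,
                                 /*delta_contiguous=*/4,
                                 /*strided_index=*/2,
                                 /*stride=*/8,
                                 /*element_bits=*/32);
  // p now points at element (row indices[2] = 5, column 4) of an 8x8 tile.
  (void)p;
  return 0;
}
```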
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
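The column-major specialization above, like the row-major one further below, adds no addressing logic of its own: it only re-expresses matrix coordinates and extents in pitch-linear (contiguous, strided) terms and forwards everything to the pitch-linear iterator. A minimal model of the two coordinate mappings, with illustrative names only:

```cpp
#include <cassert>

struct PitchLinearCoord2 {
  int contiguous;
  int strided;
};

// Column-major: the row index varies fastest, so rows map to the contiguous
// dimension and columns to the strided dimension.
static PitchLinearCoord2 column_major_to_pitch_linear(int row, int column) {
  return {row, column};
}

// Row-major: the column index varies fastest, so the roles are swapped.
static PitchLinearCoord2 row_major_to_pitch_linear(int row, int column) {
  return {column, row};
}

int main() {
  PitchLinearCoord2 a = column_major_to_pitch_linear(3, 5);
  PitchLinearCoord2 b = row_major_to_pitch_linear(3, 5);
  assert(a.contiguous == 3 && a.strided == 5);
  assert(b.contiguous == 5 && b.strided == 3);
  return 0;
}
```

The advance rank is remapped the same way, as the UnderlyingIterator definitions show: the column-major wrapper passes `(kAdvanceRank == 0 ? 0 : 1)` while the row-major wrapper passes `(kAdvanceRank == 0 ? 1 : 0)`.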
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = + PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the " + "access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = + inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = + inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const& params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
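The Params constructor above folds all pointer stepping into a handful of byte-valued increments derived from the two affine strides. The host-side check below uses illustrative stand-in values for the ThreadMap and Shape parameters and assumes advancement along the strided dimension; it spells out what `inc_next_strided_` and `inc_next_` are chosen to guarantee, namely that one jump returns the pointer to the first access of the next row, or of the next tile, once a full row or tile has been walked.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Illustrative stand-ins for the template parameters used above.
  const std::int64_t stride_c = 1;    // layout.stride(0), in elements
  const std::int64_t stride_s = 128;  // layout.stride(1), in elements
  const int element_bytes = 2;        // e.g. a 16-bit element type
  const int delta_contiguous = 8;     // ThreadMap::Delta::kContiguous
  const int delta_strided = 4;        // ThreadMap::Delta::kStrided
  const int iters_contiguous = 4;     // ThreadMap::Iterations::kContiguous
  const int iters_strided = 2;        // ThreadMap::Iterations::kStrided
  const int shape_strided = 32;       // Shape::kStrided (advance along strided)

  // Same formulas as the Params constructor above, written out in bytes.
  std::int64_t inc_contiguous = stride_c * delta_contiguous * element_bytes;
  std::int64_t inc_strided = stride_s * delta_strided * element_bytes;
  std::int64_t inc_next_strided =
      inc_strided - std::int64_t(iters_contiguous - 1) * inc_contiguous;
  std::int64_t inc_advance =
      std::int64_t(shape_strided) * stride_s * element_bytes;
  std::int64_t inc_next = inc_advance -
                          std::int64_t(iters_contiguous - 1) * inc_contiguous -
                          std::int64_t(iters_strided - 1) * inc_strided;

  // After (iters_contiguous - 1) contiguous steps, one inc_next_strided jump
  // lands exactly one strided delta below the first access of the row.
  assert((iters_contiguous - 1) * inc_contiguous + inc_next_strided ==
         inc_strided);

  // After walking a full tile, one inc_next jump amounts to exactly one tile
  // advance along the strided dimension.
  assert((iters_contiguous - 1) * inc_contiguous +
             (iters_strided - 1) * inc_strided + inc_next ==
         inc_advance);
  return 0;
}
```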
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
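The `operator++` that follows is a three-level counter: the vector index within an access wraps first, then the contiguous access index, then the strided access index; only when all three wrap does the pointer move to the next tile (and, in the code below, immediately steps back by `inc_advance_` so the tile advance is applied explicitly later). A plain C++ model of that ordering with made-up iteration counts; the trailing comments name the byte increment the real iterator applies at each stage.

```cpp
#include <cstdio>

int main() {
  // Illustrative stand-ins for kAccessesPerVector and ThreadMap::Iterations.
  const int accesses_per_vector = 2;
  const int iterations_contiguous = 2;
  const int iterations_strided = 3;

  int vector = 0, contiguous = 0, strided = 0, tile = 0;
  const int steps =
      2 * accesses_per_vector * iterations_contiguous * iterations_strided;
  for (int step = 0; step < steps; ++step) {
    std::printf("tile %d strided %d contiguous %d vector %d\n",
                tile, strided, contiguous, vector);
    // Equivalent of one operator++ call.
    if (++vector < accesses_per_vector) continue;
    vector = 0;
    if (++contiguous < iterations_contiguous) continue;  // += inc_contiguous_
    contiguous = 0;
    if (++strided < iterations_strided) continue;        // += inc_next_strided_
    strided = 0;
    ++tile;                                              // += inc_next_
  }
  return 0;
}
```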
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
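The two affine rank-2 wrappers that follow (column-major here, row-major further down) only reorder the stride pair and the coordinate pair before delegating to the AffineRankN<2> iterator above; the linear offset they describe is unchanged. A small illustrative check of that reordering invariance, with hypothetical stride values and assuming an offset of the form contiguous * stride[0] + strided * stride[1]:

```cpp
#include <cassert>
#include <cstdint>

// Offset formula used by the generic rank-2 affine iterator (in elements).
static std::int64_t affine_offset(std::int64_t contiguous, std::int64_t strided,
                                  std::int64_t stride0, std::int64_t stride1) {
  return contiguous * stride0 + strided * stride1;
}

int main() {
  const std::int64_t row = 3, column = 7;
  const std::int64_t row_stride = 2, column_stride = 640;

  // Column-major wrapper: (contiguous, strided) = (row, column),
  // stride pair passed as (row_stride, column_stride).
  std::int64_t via_column_major =
      affine_offset(row, column, row_stride, column_stride);

  // Row-major wrapper: (contiguous, strided) = (column, row),
  // stride pair passed as (column_stride, row_stride).
  std::int64_t via_row_major =
      affine_offset(column, row, column_stride, row_stride);

  // Both orderings reach the same element offset.
  assert(via_column_major == via_row_major);
  assert(via_column_major == row * row_stride + column * column_stride);
  return 0;
}
```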
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds 
a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000..2a274944b42a3 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,1946 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. 
+// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_(params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + 
PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset(frag, + pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * + (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset(frag, + pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void 
store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * + (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
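+  ///
+  /// A typical mainloop keeps the residual (partially full) tile for the very
+  /// last iteration by toggling set_residual_tile(). Illustrative sketch only;
+  /// `iter` and `tiles_remaining` are placeholder names:
+  ///
+  ///   Fragment frag;
+  ///   for (int k = tiles_remaining; k > 0; --k) {
+  ///     iter.set_residual_tile(k == 1);  // residual predicates only for the final tile
+  ///     iter.load(frag);
+  ///     ++iter;                          // cheap pointer advance otherwise
+  ///   }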
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
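+  ///
+  /// Note: the postfix form below returns a copy of the iterator taken before
+  /// the advance; performance-critical loops should prefer the prefix form,
+  /// which avoids that extra copy of iterator state.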
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
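+/// Both ranks of an AffineRankN<2> layout carry an explicit stride, so the
+/// precomputed Params hold two strides and every access goes through the
+/// underlying predicated access iterator; gather/scatter indices are not
+/// supported by this specialization.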
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, pointer, extent, thread_id, threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset(frag, + pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * + (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset(frag, + pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * + (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; 
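+          // Only accesses whose predicate is valid are written back below;
+          // this is what keeps the out-of-bounds lanes of the residual
+          // (partially full) tile safe.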
+ AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, 
///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. 
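+/// Internally this wraps the AffineRankN<2> specialization with the two
+/// strides exchanged (see the Params constructor below), so row-major
+/// traversal reuses the same affine address arithmetic.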
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. 
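+/// An extent of (rows, columns) is presented to the underlying pitch-linear
+/// iterator as (rows * kInterleavedK, columns / kInterleavedK); for example,
+/// with kInterleavedK == 32 a 128 x 64 extent becomes a 4096 x 2 pitch-linear
+/// extent.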
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units 
of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. 
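+/// This mirrors the column-major interleaved specialization with the roles of
+/// rows and columns exchanged: the contiguous pitch-linear extent becomes
+/// columns * kInterleavedK and the strided extent becomes rows / kInterleavedK.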
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of 
Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/transpose_warp_iterator.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000000..07954557c2f31 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
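+
+// TransposeWarpIterator is a small compile-time traits helper: the primary
+// template reports kSupportsTranspose == false, while the specialization for
+// WarpIteratorFromSmem reports kSupportsTranspose == true and exposes the
+// iterator type to use for the transposed access. A caller might query it
+// roughly as follows (operand and element choices here are illustrative only):
+//
+//   using WarpIt = cutlass::gemm::warp::WarpIteratorFromSmem<
+//       cutlass::gemm::Operand::kA, cutlass::half_t, /*kTranspose=*/false>;
+//   static_assert(TransposeWarpIterator<WarpIt>::kSupportsTranspose, "");
+//   using TransposedIt = typename TransposeWarpIterator<WarpIt>::Iterator;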
+ +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp::WarpIteratorFromSmem> { + using Iterator = + cutlass::gemm::warp::WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000..b36dbd89bd749 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,275 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. 
+ The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + This is only implemented for the specific shapes that are interesting for us +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert(kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B " + "operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = cutlass::MatrixShape<16, 8>; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = + MatrixShape; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; + ++inst_m_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; + ++access_m_idx) { + int access_idx = + access_m_idx + kTilesPerInstruction * + (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + } else { + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset(inner_idx * 4 * kElementsPerAccess, + inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm(access_ptr[0], + ref_.data() + ref_.offset(offset)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/kernel_forward.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/kernel_forward.h new file mode 100644 index 0000000000000..151d644745c16 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/kernel_forward.h @@ -0,0 +1,1067 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "debug_utils.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSm() { + return (Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + bool maskIsAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration, // = `value.shape[-1] <= kKeysPerBlock` + bool kAddMask, + // This is quite faster when mask need broadcast at row axis + bool kMaskBroadcastRow = true> +struct AttentionKernel { + using scalar_t = scalar_t_; + using accum_t = float; + + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kMaskIsAligned = maskIsAligned_; + + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = + !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSm() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + scalar_t* mask_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* output_ptr; // [num_queries, num_heads, head_dim_value] + output_accum_t* + output_accum_ptr; // [num_queries, num_heads, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim; + int32_t head_dim_value; + int32_t num_queries; + int32_t num_keys; + + bool causal; + bool mask_broadcast_row; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t mask_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int32_t o_strideH = 0; + int32_t mask_strideH = 0; + + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int64_t o_strideB; + int32_t mask_strideB = 0; + + int32_t num_batches; + int32_t num_heads; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + // return head_dim_value * num_heads; + return head_dim_value; + } + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + int64_t q_start, k_start; + // 
Advance to current batch - in case of different sequence lengths + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + q_start = cu_seqlens_q_ptr[0]; + k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + // output_ptr += int64_t(batch_id * num_queries) * o_strideM(); + output_ptr += batch_id * o_strideB; + ; + + if (output_accum_ptr != nullptr) { + // output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM(); + output_accum_ptr += batch_id * o_strideB; + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + value_ptr += k_start * v_strideM + head_id * v_strideH; + // output_ptr += int64_t(q_start + query_start) * o_strideM() + + // head_id * head_dim_value; + output_ptr += + int64_t(q_start + query_start) * o_strideM() + head_id * o_strideH; + + if (mask_ptr != nullptr) { + mask_ptr += (batch_id * mask_strideB) + (head_id * mask_strideH); + } + if (output_accum_ptr != nullptr) { + // output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + + // head_id * head_dim_value; + + output_accum_ptr += + int64_t(q_start + query_start) * o_strideM() + head_id * o_strideH; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + num_queries -= query_start; + if (causal) { + num_keys = cutlass::fast_min(int32_t(query_start + kQueriesPerBlock), + num_keys); + } + num_batches = 0; // no longer used after + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + mask_ptr = warp_uniform(mask_ptr); + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. 
+ While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but + // that uses too much smem + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of mask_ tile Bij from global to shared memory + using MaskLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + kMaskIsAligned ? 128 / cutlass::sizeof_bits::value : 1>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK == + kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + static constexpr int64_t kAlignmentM = kMaskIsAligned ? 
kAlignmentQ : 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + CHECK_ALIGNED_PTR(p.mask_ptr, kAlignmentM); + if (p.mask_ptr != nullptr) { + CHECK_ALIGNED_PTR(p.mask_ptr, kAlignmentM); + XFORMERS_CHECK(p.mask_strideB % kAlignmentM == 0, + "attn_mask is not correctly aligned"); + XFORMERS_CHECK(p.mask_strideH % kAlignmentM == 0, + "attn_mask is not correctly aligned"); + XFORMERS_CHECK(p.mask_strideM % kAlignmentM == 0, + "attn_mask is not correctly aligned"); + } + XFORMERS_CHECK(p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned"); + XFORMERS_CHECK(p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned"); + XFORMERS_CHECK(p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned"); + XFORMERS_CHECK(p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned"); + XFORMERS_CHECK(p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned"); + XFORMERS_CHECK(p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = 
[&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{p.num_queries, + p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{p.num_queries, + p.head_dim_value}, + thread_id(), + {0, col}); + }; + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min(int32_t(kKeysPerBlock), + p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue(shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / 
MM0::Mma::WarpCount::kM)}; + + if (kAddMask) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + int32_t mask_iter_m = kMaskBroadcastRow ? 1 : problem_size_0_m; + + // apply attention mask if applicable + if (kAddMask) { + // load mask tile Bij into shared memory + typename MM0::MaskLoader::GmemTileIterator mask_iter( + {cutlass::layout::RowMajor(p.mask_strideM)}, + // attn_mask_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.mask_ptr + query_start * p.mask_strideM + iter_key_start, + {mask_iter_m, problem_size_0_n}, + thread_id()); + + cutlass::TensorRef mask_tensor_ref( + shared_storage.after_mm0.mask.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::MaskLoader::SmemTileIterator smem_tile_iter( + mask_tensor_ref, thread_id()); + MM0::MaskLoader::load(mask_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + if (kMaskBroadcastRow) { + accum[idx] += mask_tensor_ref.at({0, accum_n}); + } else { + accum[idx] += mask_tensor_ref.at({accum_m, accum_n}); + } + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = query_start + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + // (p.mask_ptr != nullptr) ? 1.0f : p.scale); + kAddMask ? 1.0f : p.scale); + // 1.0f); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % (MM0::Mma::Base::WarpCount::kM * + MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = + cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kSingleValueIteration ? 
1 + : ceil_div((int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = + call_conditional:: + apply( + createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + 
kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = + accum_t(mi[thread_id()]) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... 
+ atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + addition_storage[accum_m + kQueriesPerBlock * + tile_offset.column()] = total_row; + } + }); + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + total_row = s_prime[id]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + } + + static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; } + static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/transform/tile_smem_loader.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000..7dee7c97a8ba3 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,71 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
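To make the arithmetic in `iterative_softmax` above easier to follow: per query row it maintains the running maximum `mi`, the running denominator `s_prime`, and the correction factor `m_prime = exp(old_max - new_max)`, which also rescales the partial output accumulator. Below is a scalar, single-row sketch of that recurrence; it is a simplified illustration (no fragments, warp reductions, atomics, or the `exp2f`/`kLog2e` trick), and the local names are placeholders.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar, single-row sketch of the online-softmax update performed by
// iterative_softmax (illustrative only).
struct SoftmaxRowState {
  float mi = -INFINITY;   // running max of the scaled logits for this row
  float s_prime = 0.0f;   // running sum of exp(scaled logit - mi)
  float m_prime = 1.0f;   // factor by which the previous partial output decays
};

// `chunk` holds the raw QK^T values for one key block; `scaling` plays the role
// of p.scale. On return, `chunk` holds exp(scaling * x - mi), ready to be
// multiplied against the corresponding V block and added to the output row.
inline void update_row(SoftmaxRowState& st, std::vector<float>& chunk, float scaling) {
  const float mi_old = st.mi;
  for (float x : chunk) st.mi = std::max(st.mi, x * scaling);  // new row max

  st.m_prime = std::exp(mi_old - st.mi);  // how much the old max "shrank"
  st.s_prime *= st.m_prime;               // rescale the old denominator

  for (float& x : chunk) {
    x = std::exp(x * scaling - st.mi);
    st.s_prime += x;                      // extend the denominator
  }
  // The caller then does: output_row = output_row * st.m_prime + chunk . V_block,
  // the epilogue (MemoryEfficientAttentionNormalize) normalizes by s_prime at the
  // end, and logsumexp is mi + log(s_prime), matching the kernel above.
}
```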
+ +#pragma once + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load(GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/utils.h b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/utils.h new file mode 100644 index 0000000000000..3081755fd5a0b --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/utils.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
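For context, this loader is what the attention kernel above uses to stage the mask tile: `MM0::MaskLoader` instantiates it with the kQueriesPerBlock x kKeysPerBlock threadblock tile, and the `kAddMask` branch of `attention_kernel` drives it with a global-memory iterator over the mask and a shared-memory iterator over `shared_storage.after_mm0.mask`. A condensed sketch of that wiring follows; the template arguments are reconstructed from the MM0 definitions and the local names are illustrative.

```cpp
// Condensed from kernel_forward.h (MM0::MaskLoader and the kAddMask branch of
// attention_kernel).
using MaskLoader = TileSmemLoader<
    scalar_t,
    cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,  // tile staged in smem
    MM0::MmaCore::kThreads,                                 // participating threads
    kMaskIsAligned ? 128 / cutlass::sizeof_bits<scalar_t>::value : 1>;

// Global-memory iterator over the mask tile Bij for this (query block, key block).
typename MaskLoader::GmemTileIterator mask_iter(
    {cutlass::layout::RowMajor(p.mask_strideM)},
    p.mask_ptr + query_start * p.mask_strideM + iter_key_start,
    {mask_iter_m, problem_size_0_n},
    thread_id());

// Shared-memory destination, then one cooperative copy.
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> mask_tensor_ref(
    shared_storage.after_mm0.mask.data(),
    cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
typename MaskLoader::SmemTileIterator smem_tile_iter(mask_tensor_ref, thread_id());
MaskLoader::load(mask_iter, smem_tile_iter);
```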
+ +#include "cutlass/platform/platform.h" +namespace cutlass { +namespace platform { + +// template< class To, class From > +// constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; + +// template +// constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept +// { +// static_assert(sizeof(To) == sizeof(From), "sizes must match"); +// return reinterpret_cast(src); +// } + +// template <> +// struct numeric_limits { +// CUTLASS_HOST_DEVICE +// static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} static constexpr bool is_integer = false; static +// constexpr bool has_infinity = true; +// }; + +// template <> +// struct numeric_limits { +// CUTLASS_HOST_DEVICE +// static const cutlass::half_t infinity() noexcept { return +// bit_cast(0x7800);} static constexpr bool +// is_integer = false; static constexpr bool has_infinity = true; +// }; + +} // namespace platform +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention_variable.cu b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention_variable.cu new file mode 100644 index 0000000000000..1215f2c005de0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention_variable.cu @@ -0,0 +1,298 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
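The commented-out helpers above sketch how `numeric_limits`-style constants could be built from raw bit patterns via a `bit_cast`. For reference, a self-contained host-side version of that idea; the names here are illustrative and not part of this PR.

```cpp
#include <cstdint>
#include <cstring>

// Standalone illustration of the bit_cast idea in the disabled code above:
// reinterpret the bits of an integer as a float without undefined behaviour.
template <class To, class From>
To bit_cast_sketch(const From& src) noexcept {
  static_assert(sizeof(To) == sizeof(From), "sizes must match");
  To dst;
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

inline float float_infinity_from_bits() {
  return bit_cast_sketch<float>(std::uint32_t{0x7f800000});  // IEEE-754 +inf
}
```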
+ +#include "paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/autogen/cutlass_forward.h" +#include "paddle/phi/kernels/fusion/fused_multihead_attention_variable_kernel.h" + +namespace phi { +namespace fusion { + +template +void MultiHeadAttentionVariableForwardKernel( + const Context& ctx, + const DenseTensor& query, + const DenseTensor& key, + const DenseTensor& value, + const DenseTensor& seq_lens, + const paddle::optional& mask, + const float scale, + const bool causal, + DenseTensor* output) { + ctx.template Alloc(output); + Params params{}; + // [B, N, S, H] + params.seq_lens = seq_lens.data(); + + params.num_batches = query.dims()[0]; + params.num_heads = query.dims()[1]; + params.query_seq_len = query.dims()[2]; + params.head_size = query.dims()[3]; + params.key_value_seq_len = key.dims()[2]; + params.value_head_size = value.dims()[3]; + + params.datatype = query.dtype(); + params.query_ptr = query.data(); + params.key_ptr = key.data(); + params.value_ptr = value.data(); + + params.output_ptr = output->data(); + + params.ldq = params.head_size; + params.ldk = params.head_size; + params.ldv = params.value_head_size; + params.ldo = params.value_head_size; + + params.ElementQ = params.query_seq_len * params.head_size; + params.ElementK = params.key_value_seq_len * params.head_size; + params.ElementV = params.key_value_seq_len * params.value_head_size; + params.ElementO = params.query_seq_len * params.value_head_size; + + params.scale = scale; + params.causal = causal; + + if (mask) { + // [B, 1, S, D] + auto mask_tensor = mask.get(); + params.ldm = mask_tensor.dims()[3]; + params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3]; + params.mask_ptr = mask_tensor.data(); + params.mask_broadcast_row = false; + } + + bool kernel_launched = false; + + auto launchKernel = [&](auto k_, auto kernel_fn) { + using KernelType = decltype(k_); + if (kernel_launched) { + return; + } + if (mask && !KernelType::kAddMask) { + return; + } + if (!mask && KernelType::kAddMask) { + return; + } + if (KernelType::kMaskBroadcastRow) { + // not support mask_broad_cast + return; + } + if (params.mask_ptr && + reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) { + return; + } + if (params.mask_ptr && + !(reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0) && + KernelType::kMaskIsAligned) { + return; + } + if (KernelType::kSingleValueIteration && + KernelType::kKeysPerBlock < params.value_head_size) { + return; + } + if (KernelType::kKeysPerBlock == 64 && params.value_head_size > 64) { + return; + } + if (params.head_size % KernelType::MM0::kAlignmentA) { + return; + } + kernel_launched = true; + kernel_fn(k_, params, ctx); + }; + dispatch_cutlass_forward(ctx, launchKernel); + PADDLE_ENFORCE_EQ( + kernel_launched, + true, + phi::errors::InvalidArgument("the kernel should not be launched")); +} + +template +void MultiHeadAttentionVariableWrapper(const Context& ctx, + T* query, + T* key, + T* value, + const int* seq_lens, + const phi::DenseTensor* mask_tensor, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_heads, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + const int64_t value_head_size, + const int prompt_num, + T* output) { + Params params{}; + // [B, N, S, H] + params.seq_lens = seq_lens; + + params.num_batches = batch_size; + params.num_heads = num_heads; + params.query_seq_len = seq_len; + params.head_size = 
head_size; + params.key_value_seq_len = out_seq_len; + params.value_head_size = value_head_size; + + if (std::is_same::value) { + params.datatype = DataType::FLOAT16; + } else if (std::is_same::value) { + params.datatype = DataType::BFLOAT16; + } else { + params.datatype = DataType::FLOAT32; + } + params.query_ptr = query; + params.key_ptr = key; + params.value_ptr = value; + + params.prompt_num = prompt_num; + + params.output_ptr = output; + + params.ldq = params.head_size; + params.ldk = params.head_size; + params.ldv = params.value_head_size; + params.ldo = params.value_head_size; + + params.ElementQ = params.query_seq_len * params.head_size; + params.ElementK = params.key_value_seq_len * params.head_size; + params.ElementV = params.key_value_seq_len * params.value_head_size; + params.ElementO = params.query_seq_len * params.value_head_size; + + params.scale = scale; + params.causal = causal; + + if (mask_tensor) { + // [B, 1, S, D] + params.mask_broadcast_row = false; + params.mask_ptr = mask_tensor->data(); + params.ldm = mask_tensor->dims()[3]; + params.ElementM = mask_tensor->dims()[2] * mask_tensor->dims()[3]; + } else { + params.mask_broadcast_row = false; + params.mask_ptr = nullptr; + } + + bool kernel_launched = false; + + auto launchKernel = [&](auto k_, auto kernel_fn) { + using KernelType = decltype(k_); + if (kernel_launched) { + return; + } + if (mask_tensor && !KernelType::kAddMask) { + return; + } + if (!mask_tensor && KernelType::kAddMask) { + return; + } + if (KernelType::kMaskBroadcastRow) { + // not support mask_broad_cast + return; + } + if (params.mask_ptr && + reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) { + return; + } + if (params.mask_ptr && + !(reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0) && + KernelType::kMaskIsAligned) { + return; + } + if (KernelType::kSingleValueIteration && + KernelType::kKeysPerBlock < params.value_head_size) { + return; + } + if (KernelType::kKeysPerBlock == 64 && params.value_head_size > 64) { + return; + } + if (params.head_size % KernelType::MM0::kAlignmentA) { + return; + } + kernel_launched = true; + kernel_fn(k_, params, ctx); + }; + dispatch_cutlass_forward(ctx, launchKernel); + PADDLE_ENFORCE_EQ( + kernel_launched, + true, + phi::errors::InvalidArgument("the kernel should not be launched")); +} + +template void MultiHeadAttentionVariableWrapper(const phi::GPUContext& ctx, + phi::dtype::bfloat16* query, + phi::dtype::bfloat16* key, + phi::dtype::bfloat16* value, + const int* seq_lens, + const phi::DenseTensor* mask, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_heads, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + const int64_t value_head_size, + const int prompt_num, + phi::dtype::bfloat16* output); + +template void MultiHeadAttentionVariableWrapper(const phi::GPUContext& ctx, + phi::dtype::float16* query, + phi::dtype::float16* key, + phi::dtype::float16* value, + const int* seq_lens, + const phi::DenseTensor* mask, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_heads, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + const int64_t value_head_size, + const int prompt_num, + phi::dtype::float16* output); + +template void MultiHeadAttentionVariableWrapper(const phi::GPUContext& ctx, + float* query, + float* key, + float* value, + const int* seq_lens, + const 
phi::DenseTensor* mask, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_heads, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + const int64_t value_head_size, + const int prompt_num, + float* output); + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_multihead_attention_variable, + GPU, + ALL_LAYOUT, + phi::fusion::MultiHeadAttentionVariableForwardKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/fusion/cutlass/fused_mutlihead_attention_kernel.cu b/paddle/phi/kernels/fusion/cutlass/fused_mutlihead_attention_kernel.cu new file mode 100644 index 0000000000000..06b6e7ba0b926 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/fused_mutlihead_attention_kernel.cu @@ -0,0 +1,536 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/kernel_forward.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { + +struct LaunchParams { + // meta params + phi::DataType datatype; + + // Create Input tensors in BMHK format, where + // B = batch_size + // M = sequence length + // H = num_heads + // K = embedding size per head + + const void* query_ptr; + const void* key_ptr; + const void* value_ptr; + // Mask Tensor format is BHMK + // and it can be broadcasted in axis0, 1, 2. + const void* mask_ptr = nullptr; + + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + void* output_ptr; // [num_batches, query_seq_len, num_heads, head_size] + void* output_accum_ptr; // [num_batches, query_seq_len, num_heads, head_size] + void* logsumexp_ptr; // [num_batches, num_heads, num_queries] - can be null + + // Scale + float scale; + + // Dimensions/strides + int32_t num_batches; + int32_t num_heads; + int32_t query_seq_len; + int32_t key_value_seq_len; + int32_t head_size; + int32_t value_head_size; + bool causal; + bool mask_broadcast_row; + /* + We can understand the computation of Fused Multihead Attention in this way: + for Query matmul Key, we execute num_batches * num_heads times matmul, + each matmul problem is: (M, K) (K, N) -> (M, N). + Here M is: query_seq_len, K is: head_size, N is: key_value_seq_len. + The stride concept is equals to torch's, it means the offset to move to next + axis. For Q matrix(M, K), we need to move K(which equals to head_size) offset + to next row(in M axis). 
+ */ + int32_t query_strideM; + int32_t key_strideM; + int32_t value_strideM; + + // Since mask can be broadcasted, we need to assign each stride + int32_t mask_strideM; + int64_t mask_strideH; // stride for num_heads + int64_t mask_strideB; // stride for num_batches +}; + +template +void LaunchMultiHeadAttentionKernel(LaunchParams params, + const phi::GPUContext& ctx) { + using Attention = AttentionKernel; + + typename Attention::Params p; + { // set parameters + p.query_ptr = const_cast(reinterpret_cast(params.query_ptr)); + p.key_ptr = const_cast(reinterpret_cast(params.key_ptr)); + p.value_ptr = const_cast(reinterpret_cast(params.value_ptr)); + p.mask_ptr = const_cast(reinterpret_cast(params.mask_ptr)); + + // TODO(zhengzekang): Currently we only support inference, so here we set + // `logsumexp_ptr` as nullptr, which is used for backward. + p.logsumexp_ptr = nullptr; + + p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + const int64_t output_size = params.num_batches * params.num_heads * + params.query_seq_len * params.head_size; + paddle::memory::AllocationPtr tmp_output_accum_buffer_ptr{nullptr}; + tmp_output_accum_buffer_ptr = paddle::memory::Alloc( + ctx.GetPlace(), + output_size * sizeof(typename Attention::output_accum_t), + phi::Stream(reinterpret_cast(ctx.stream()))); + p.output_accum_ptr = + reinterpret_cast( + tmp_output_accum_buffer_ptr->ptr()); + } + + p.output_ptr = reinterpret_cast(params.output_ptr); + + // TODO(zhengzekang): support arbitrary seq lengths + // if (cu_seqlens_q.has_value()) { + // p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); + // p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); + // } + + p.num_batches = params.num_batches; + p.num_heads = params.num_heads; + p.num_queries = params.query_seq_len; + p.num_keys = params.key_value_seq_len; + p.head_dim = params.head_size; + p.head_dim_value = params.value_head_size; + + p.scale = params.scale; + p.causal = params.causal; + p.mask_broadcast_row = params.mask_broadcast_row; + + // TODO(zhengzekang): This might overflow for big tensors + p.q_strideM = params.query_strideM; + p.k_strideM = params.key_strideM; + p.mask_strideM = params.mask_strideM; + p.v_strideM = params.value_strideM; + + p.q_strideH = p.q_strideM * params.query_seq_len; + p.k_strideH = p.k_strideM * params.key_value_seq_len; + p.mask_strideH = params.mask_strideH; + p.v_strideH = p.v_strideM * params.key_value_seq_len; + p.o_strideH = params.value_head_size * params.query_seq_len; + + p.q_strideB = p.q_strideH * params.num_heads; + p.k_strideB = p.k_strideH * params.num_heads; + p.v_strideB = p.v_strideH * params.num_heads; + p.o_strideB = + params.value_head_size * params.query_seq_len * params.num_heads; + + p.mask_strideB = params.mask_strideB; + } + + constexpr auto kernel_fn = attention_kernel_batched_impl; + + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + if (!Attention::check_supported(p)) { + PADDLE_THROW( + phi::errors::Unimplemented("The Params is not supported by cutlass " + "fused multihead attention. 
")); + return; + } + kernel_fn<<>>(p); +} + +template +void DispatchFMHAMaskBroadcastRow(LaunchParams params, + const phi::GPUContext& ctx) { + if (params.mask_broadcast_row) { + LaunchMultiHeadAttentionKernel(params, ctx); + } else { + LaunchMultiHeadAttentionKernel(params, ctx); + } +} + +template +void DispatchFMHAAddMask(LaunchParams params, const phi::GPUContext& ctx) { + if (params.mask_ptr != nullptr) { + DispatchFMHAMaskBroadcastRow(params, ctx); + } else { + DispatchFMHAMaskBroadcastRow(params, ctx); + } +} + +template +void DispatchFMHASingleValueIteration(LaunchParams params, + const phi::GPUContext& ctx) { + if (params.value_head_size <= KeysPerBlock) { + DispatchFMHAAddMask(params, ctx); + } else { + DispatchFMHAAddMask(params, ctx); + } +} + +template +void DispatchFMHABlockSize(LaunchParams params, const phi::GPUContext& ctx) { + if (params.value_head_size > 64) { + DispatchFMHASingleValueIteration(params, ctx); + } else { + DispatchFMHASingleValueIteration(params, ctx); + } +} + +template +void DispatchFMHAMaskIsAligned(LaunchParams params, + const phi::GPUContext& ctx) { + if (reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.mask_strideM % (16 / sizeof(T)) == 0) { + DispatchFMHABlockSize(params, ctx); + } else { + DispatchFMHABlockSize(params, ctx); + } +} + +template +void DispatchFMHAIsAligned(LaunchParams params, const phi::GPUContext& ctx) { + if (reinterpret_cast(params.query_ptr) % 16 == 0 && + reinterpret_cast(params.key_ptr) % 16 == 0 && + reinterpret_cast(params.value_ptr) % 16 == 0 && + params.query_strideM % (16 / sizeof(T)) == 0 && + params.query_strideM % (16 / sizeof(T)) == 0 && + params.value_strideM % (16 / sizeof(T)) == 0) { + DispatchFMHAMaskIsAligned(params, ctx); + } else { + DispatchFMHAMaskIsAligned(params, ctx); + } +} + +template +void DispatchFMHAArchTag(LaunchParams params, const phi::GPUContext& ctx) { + const int compute_capability = ctx.GetComputeCapability(); + if (compute_capability == 80) { + DispatchFMHAIsAligned(params, ctx); + } else if (compute_capability == 75) { + // DispatchFMHAIsAligned(params, ctx); + } else if (compute_capability == 70) { + // DispatchFMHAIsAligned(params, ctx); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently cutlass fused multihead attention kernel " + "only support arch: SM80, SM75, SM70")); + return; + } +} + +void DispatchFusedMultiheadAttentionKernel(LaunchParams params, + const phi::GPUContext& ctx) { + if (params.datatype == DataType::FLOAT32) { + return DispatchFMHAArchTag(params, ctx); + } else if (params.datatype == DataType::FLOAT16) { + return DispatchFMHAArchTag(params, ctx); + } else if (params.datatype == DataType::BFLOAT16) { + return DispatchFMHAArchTag(params, ctx); + } else { + PADDLE_ENFORCE_EQ( + true, + false, + phi::errors::Unimplemented( + "Currently cutlass fused multihead attention kernel " + "only support datatype: float32, float16 and bfloat16. 
")); + return; + } +} + +template +void MultiHeadAttentionForwardWrapper(const Context& ctx, + T* query, + T* key, + T* value, + const phi::DenseTensor* mask_tensor, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_head, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + T* output) { + LaunchParams params{}; + + if (std::is_same::value) { + params.datatype = DataType::FLOAT16; + } else if (std::is_same::value) { + params.datatype = DataType::BFLOAT16; + } else { + params.datatype = DataType::FLOAT32; + } + params.query_ptr = query; + params.key_ptr = key; + + params.value_ptr = value; + params.output_ptr = output; + params.output_accum_ptr = nullptr; + + // TODO(zhengzekang): currently we only used in inference. Maybe add a bool + // flag to save it ? + params.logsumexp_ptr = nullptr; + + params.num_batches = batch_size; + params.num_heads = num_head; + params.query_seq_len = seq_len; + params.key_value_seq_len = out_seq_len; + params.head_size = head_size; + params.value_head_size = head_size; + + params.scale = scale; + params.causal = causal; + + params.query_strideM = head_size; + params.key_strideM = head_size; + params.value_strideM = head_size; + + if (mask_tensor != nullptr) { + params.mask_broadcast_row = false; + params.mask_ptr = mask_tensor->data(); + params.mask_strideB = mask_tensor->dims()[3] * mask_tensor->dims()[2]; + params.mask_strideH = 0; // Since head dim is broadcast. + params.mask_strideM = mask_tensor->dims()[3]; + } else { + params.mask_broadcast_row = false; + params.mask_ptr = nullptr; + params.mask_strideB = 0; + params.mask_strideH = 0; + params.mask_strideM = 0; + } + + DispatchFusedMultiheadAttentionKernel(params, ctx); +} + +template +void MultiHeadAttentionForwardKernel(const Context& ctx, + const DenseTensor& query, + const DenseTensor& key, + const DenseTensor& value, + const paddle::optional& mask, + const float scale, + const bool causal, + DenseTensor* output) { + ctx.template Alloc(output); + LaunchParams params{}; + + params.datatype = query.dtype(); + params.query_ptr = query.data(); + params.key_ptr = key.data(); + params.mask_ptr = nullptr; + + params.value_ptr = value.data(); + params.output_ptr = output->data(); + params.output_accum_ptr = nullptr; + + // TODO(zhengzekang): currently we only used in inference. Maybe add a bool + // flag to save it ? + params.logsumexp_ptr = nullptr; + + params.num_batches = query.dims()[0]; + params.num_heads = query.dims()[1]; + params.query_seq_len = query.dims()[2]; + params.key_value_seq_len = key.dims()[2]; + params.head_size = query.dims()[3]; + params.value_head_size = value.dims()[3]; + + params.scale = scale; + params.causal = causal; + + params.query_strideM = query.dims()[3]; + params.key_strideM = key.dims()[3]; + params.value_strideM = value.dims()[3]; + + if (mask) { + auto mask_tensor = mask.get(); + params.mask_ptr = mask_tensor.data(); + params.mask_strideM = + mask_tensor.dims()[2] == 1 ? 0 : mask_tensor.dims()[3]; + params.mask_strideH = mask_tensor.dims()[1] == 1 + ? 0 + : mask_tensor.dims()[2] * mask_tensor.dims()[3]; + params.mask_strideB = mask_tensor.dims()[0] == 1 + ? 
0 + : mask_tensor.dims()[1] * mask_tensor.dims()[2] * + mask_tensor.dims()[3]; + + params.mask_broadcast_row = false; + if (params.mask_strideM == 0) { + params.mask_broadcast_row = true; + } + } + DispatchFusedMultiheadAttentionKernel(params, ctx); +} + +template void MultiHeadAttentionForwardWrapper( + const phi::GPUContext& ctx, + phi::dtype::bfloat16* query, + phi::dtype::bfloat16* key, + phi::dtype::bfloat16* value, + const phi::DenseTensor* mask_tensor, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_head, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + phi::dtype::bfloat16* output); + +template void MultiHeadAttentionForwardWrapper( + const phi::GPUContext& ctx, + phi::dtype::float16* query, + phi::dtype::float16* key, + phi::dtype::float16* value, + const phi::DenseTensor* mask_tensor, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_head, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + phi::dtype::float16* output); + +template void MultiHeadAttentionForwardWrapper( + const phi::GPUContext& ctx, + float* query, + float* key, + float* value, + const phi::DenseTensor* mask_tensor, + const float scale, + const bool causal, + const int64_t batch_size, + const int64_t num_head, + const int64_t seq_len, + const int64_t out_seq_len, + const int64_t head_size, + float* output); +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL( + fused_multihead_attention, + GPU, + ALL_LAYOUT, + phi::fusion::cutlass_internal::MultiHeadAttentionForwardKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/fusion/fused_multihead_attention_kernel.h b/paddle/phi/kernels/fusion/fused_multihead_attention_kernel.h new file mode 100644 index 0000000000000..1a7d9ff114cff --- /dev/null +++ b/paddle/phi/kernels/fusion/fused_multihead_attention_kernel.h @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
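A note on the mask handling in MultiHeadAttentionForwardKernel above: the attention mask may arrive as [B or 1, H or 1, M or 1, K], and every size-1 axis is given a zero stride so the same mask data is reused across batches, heads, or query rows; mask_broadcast_row is set exactly when the row stride collapses to zero. A minimal Python sketch of that arithmetic (the function name is mine, not part of the patch):

def mask_strides(mask_shape):
    """Broadcast-aware strides for a [B, H, M, K] attention mask."""
    B, H, M, K = mask_shape
    stride_m = 0 if M == 1 else K            # mask_strideM
    stride_h = 0 if H == 1 else M * K        # mask_strideH
    stride_b = 0 if B == 1 else H * M * K    # mask_strideB
    broadcast_row = stride_m == 0            # mask_broadcast_row
    return stride_b, stride_h, stride_m, broadcast_row

mask_strides((8, 1, 1, 128))     # (128, 0, 0, True): padding mask shared by all heads and rows
mask_strides((8, 16, 128, 128))  # (262144, 16384, 128, False): full per-head mask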
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+namespace fusion {
+
+namespace cutlass_internal {
+
+template <typename T, typename Context>
+void MultiHeadAttentionForwardKernel(const Context& ctx,
+                                     const DenseTensor& query,
+                                     const DenseTensor& key,
+                                     const DenseTensor& value,
+                                     const paddle::optional<DenseTensor>& mask,
+                                     const float scale,
+                                     const bool causal,
+                                     DenseTensor* output);
+
+template <typename T, typename Context>
+void MultiHeadAttentionForwardWrapper(const Context& ctx,
+                                      T* query,
+                                      T* key,
+                                      T* value,
+                                      const phi::DenseTensor* mask_tensor,
+                                      const float scale,
+                                      const bool causal,
+                                      const int64_t batch_size,
+                                      const int64_t num_head,
+                                      const int64_t seq_len,
+                                      const int64_t out_seq_len,
+                                      const int64_t head_size,
+                                      T* output);
+
+}  // namespace cutlass_internal
+
+}  // namespace fusion
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/fusion/fused_multihead_attention_variable_kernel.h b/paddle/phi/kernels/fusion/fused_multihead_attention_variable_kernel.h
new file mode 100644
index 0000000000000..1b9f0d1acc892
--- /dev/null
+++ b/paddle/phi/kernels/fusion/fused_multihead_attention_variable_kernel.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void MultiHeadAttentionVariableForwardKernel(
+    const Context& ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const DenseTensor& seq_lens,
+    const paddle::optional<DenseTensor>& mask,
+    const float scale,
+    const bool causal,
+    DenseTensor* output);
+
+template <typename T, typename Context>
+void MultiHeadAttentionVariableWrapper(const Context& ctx,
+                                       T* query,
+                                       T* key,
+                                       T* value,
+                                       const int* seq_lens,
+                                       const phi::DenseTensor* mask_tensor,
+                                       const float scale,
+                                       const bool causal,
+                                       const int64_t batch_size,
+                                       const int64_t num_heads,
+                                       const int64_t seq_len,
+                                       const int64_t out_seq_len,
+                                       const int64_t head_size,
+                                       const int64_t value_head_size,
+                                       const int prompt_num,
+                                       T* output);
+
+}  // namespace fusion
+}  // namespace phi
diff --git a/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h b/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h
new file mode 100644
index 0000000000000..e804008f9123a
--- /dev/null
+++ b/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
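The chain of Dispatch* helpers in fused_mutlihead_attention_kernel.cu above folds every runtime property (compute capability, pointer and stride alignment, value head size, presence of a mask, mask broadcastability) into one compile-time template argument of AttentionKernel, so exactly one pre-instantiated kernel is launched. Below is an illustrative Python model of that decision tree, not part of the patch; the 32/128 block-size pair for large value head sizes is an assumption (the template arguments are elided in this listing), and only the SM80 branch of DispatchFMHAArchTag is currently enabled.

def select_fmha_specialization(params, compute_capability, elem_size):
    # Returns a tuple naming the kernel specialization the dispatch chain
    # would launch. elem_size is sizeof(T): 2 for fp16/bf16, 4 for float.
    if compute_capability != 80:
        # SM75/SM70 branches exist in DispatchFMHAArchTag but are commented out.
        raise NotImplementedError("only SM80 is wired up")
    arch = "cutlass::arch::Sm80"

    def aligned16(ptr, stride):
        return ptr % 16 == 0 and stride % (16 // elem_size) == 0

    # key_strideM is not separately tested in the listing above.
    is_aligned = (aligned16(params["query_ptr"], params["query_strideM"])
                  and params["key_ptr"] % 16 == 0
                  and aligned16(params["value_ptr"], params["value_strideM"]))
    mask_is_aligned = aligned16(params["mask_ptr"], params["mask_strideM"])

    # 64/64 queries/keys per block when the value head size fits in one block,
    # otherwise a wider keys-per-block tile (32/128 assumed here).
    if params["value_head_size"] > 64:
        queries_per_block, keys_per_block = 32, 128
    else:
        queries_per_block, keys_per_block = 64, 64

    single_value_iteration = params["value_head_size"] <= keys_per_block
    add_mask = params["mask_ptr"] != 0
    mask_broadcast_row = add_mask and params["mask_broadcast_row"]

    return (arch, is_aligned, queries_per_block, keys_per_block,
            single_value_iteration, add_mask, mask_broadcast_row,
            mask_is_aligned)

The variable-length kernel above (registered as fused_multihead_attention_variable) takes the mirror-image approach: dispatch_cutlass_forward enumerates the generated specializations and the launchKernel lambda rejects each one whose traits (kAddMask, kMaskIsAligned, kKeysPerBlock, kSingleValueIteration) disagree with these same runtime properties, launching the first match and enforcing afterwards that some kernel accepted the configuration.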
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace fusion { + +template +void FusedSoftmaxMaskKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out); +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 55189de6be6a6..83124414586ce 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -246,6 +246,7 @@ PD_REGISTER_KERNEL(flash_attn_unpadded, GPU, ALL_LAYOUT, phi::FlashAttnUnpaddedKernel, + float, phi::dtype::float16, phi::dtype::bfloat16) { kernel->InputAt(5).SetBackend( diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index d8793601f729a..07547242628e1 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -92,7 +92,7 @@ def _all_to_all_in_static_mode( data_feeder.check_variable_and_dtype( in_tensor, 'in_tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'all_to_all', ) helper.append_op( diff --git a/python/paddle/distributed/communication/stream/recv.py b/python/paddle/distributed/communication/stream/recv.py index 8fbbfbf098828..8f04c11c95bc1 100644 --- a/python/paddle/distributed/communication/stream/recv.py +++ b/python/paddle/distributed/communication/stream/recv.py @@ -43,7 +43,7 @@ def _recv_in_static_mode( data_feeder.check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'recv', ) ring_id = 0 if group is None else group.id diff --git a/python/paddle/distributed/communication/stream/send.py b/python/paddle/distributed/communication/stream/send.py index 0de989042b9e6..88a93002455ce 100644 --- a/python/paddle/distributed/communication/stream/send.py +++ b/python/paddle/distributed/communication/stream/send.py @@ -43,7 +43,7 @@ def _send_in_static_mode( data_feeder.check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'send', ) diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py index 26949b12e42f7..16f67783ce9c9 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -104,7 +104,7 @@ def _c_identity(tensor, group=None): check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) @@ -164,7 +164,7 @@ def _c_concat(tensor, group=None): check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_concat', ) @@ -230,7 +230,7 @@ def _c_split(tensor, group=None): check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_split', ) @@ -275,7 +275,7 @@ def _mp_allreduce( check_variable_and_dtype( tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 
'uint16'], op_type, ) diff --git a/python/paddle/distributed/utils/moe_utils.py b/python/paddle/distributed/utils/moe_utils.py index ae18938941817..979d844f9142e 100644 --- a/python/paddle/distributed/utils/moe_utils.py +++ b/python/paddle/distributed/utils/moe_utils.py @@ -117,7 +117,7 @@ def global_scatter( check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'global_scatter', ) check_variable_and_dtype( @@ -234,7 +234,7 @@ def global_gather( check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'global_gather', ) diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 92c95f8eac23a..090b46928554d 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -583,7 +583,7 @@ def partial_concat(input, start_index=0, length=-1): check_variable_and_dtype( x, 'input[' + str(id) + ']', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'partial_concat', ) check_type(start_index, 'start_index', (int), 'partial_concat') diff --git a/python/paddle/fluid/tests/unittests/test_empty_like_op.py b/python/paddle/fluid/tests/unittests/test_empty_like_op.py index 164275b1a7d83..fa0091f3020a6 100644 --- a/python/paddle/fluid/tests/unittests/test_empty_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_empty_like_op.py @@ -39,7 +39,14 @@ def __check_out__(self, out): f'shape should be {self.dst_shape}, but get {shape}', ) - if data_type in ['float16', 'float32', 'float64', 'int32', 'int64']: + if data_type in [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ]: max_value = np.nanmax(out) min_value = np.nanmin(out) always_non_full_zero = max_value >= min_value diff --git a/python/paddle/fluid/tests/unittests/test_empty_op.py b/python/paddle/fluid/tests/unittests/test_empty_op.py index bfd3184c4bb2b..14cbf1333c9f5 100644 --- a/python/paddle/fluid/tests/unittests/test_empty_op.py +++ b/python/paddle/fluid/tests/unittests/test_empty_op.py @@ -34,7 +34,14 @@ def test_check_output(self): def verify_output(self, outs): data_type = outs[0].dtype - if data_type in ['float16', 'float32', 'float64', 'int32', 'int64']: + if data_type in [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ]: max_value = np.nanmax(outs[0]) min_value = np.nanmin(outs[0]) diff --git a/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py index f4637b070cbf9..64dbea5f0e5fd 100644 --- a/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py @@ -14,15 +14,31 @@ import os import unittest +import re from test_dist_base import TestDistBase import paddle +from paddle.fluid import core paddle.enable_static() flag_name = os.path.splitext(__file__)[0] - +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + +@unittest.skipIf( + not 
core.is_compiled_with_cuda() or get_cuda_version() < 11030, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3", +) class TestStaticModelParallel(TestDistBase): def _setup_config(self): self._sync_mode = True diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 59a15750c7d88..11c70ba042cbd 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -885,7 +885,9 @@ def fused_multi_transformer( ffn2_biases, pre_layer_norm=True, epsilon=1e-05, + residual_alpha=1.0, cache_kvs=None, + beam_offset=None, pre_caches=None, seq_lens=None, rotary_embs=None, @@ -898,6 +900,8 @@ def fused_multi_transformer( mode='upscale_in_train', trans_qkvw=True, ring_id=-1, + norm_type="layernorm", + use_neox_rotary_style=False, name=None, ): r""" @@ -940,32 +944,55 @@ def fused_multi_transformer( out = ffn_layer_norm(out) Args: - x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`. - ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`. - ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`. - qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`. - qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`. - linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`. - linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`. - ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]` - ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]` - ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`. - ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`. - ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`. - ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d_model]`. - pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True. - epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. - cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. - pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, + the shape is `[batch\_size, sequence\_length, d\_model]`. 
+ ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, + the shape is `[d\_model]`. + ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. + the shape is `[d\_model]`. + qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. + The shape is `[3, num\_head, dim\_head, d\_model]`. + qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. + The shape is `[3, num\_head, dim\_head]`. + linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. + The shape is `[num\_head * dim\_head, d\_model]`. + linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. + The shape is `[d\_model]`. + ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, + the shape is `[d\_model]` + ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, + the shape is `[d\_model]` + ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, + the shape is `[d\_model, dim\_feedforward]`. + ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, + the shape is `[dim\_feedforward]`. + ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, + the shape is `[dim\_feedforward, d\_model]`. + ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, + the shape is `[d_model]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). + Default True. + epsilon (float, optional): Small float value added to denominator of the layer_norm + to avoid dividing by zero. Default is 1e-5. + cache_kvs (list(Tensor)|tuple(Tensor), optional): + The cache structure tensors for the generation model. + The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. + pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. + The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. - rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None. + rotary_embs (Tensor optional): The RoPE embs for rotary computation. + The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + time_step (Tensor, optional): The time step tensor for the generation model. + Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. + The shape is `[1]`, must be in CPUPlace. Default None. attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None. dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0. 
- rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, + and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, + 2 when rotary_embs and pos_extra_ids are both not None. Default 0. activation (str, optional): The activation. Default "gelu". training (bool, optional): A flag indicating whether it is in train phrase or not. Default False. mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] @@ -982,8 +1009,10 @@ def fused_multi_transformer( trans_qkvw (bool, optional): Whether to transpose for weights of qkv. If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed]. Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default True. - ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. + Default is -1, means not using mp. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor|tuple: If `cache_kvs` is None, return a tensor that has @@ -1056,6 +1085,7 @@ def fused_multi_transformer( cache_kvs, pre_caches, rotary_embs, + beam_offset, time_step, seq_lens, attn_mask, @@ -1072,6 +1102,8 @@ def fused_multi_transformer( pre_layer_norm, 'epsilon', epsilon, + 'residual_alpha', + residual_alpha, 'dropout_rate', dropout_rate, 'rotary_emb_dims', @@ -1086,6 +1118,10 @@ def fused_multi_transformer( trans_qkvw, 'ring_id', ring_id, + 'norm_type', + norm_type, + 'use_neox_rotary_style', + use_neox_rotary_style, ) if cache_kvs is not None: return final_out, cache_kv_out @@ -1095,18 +1131,23 @@ def fused_multi_transformer( dtype = x.dtype # check dtypes check_variable_and_dtype( - x, 'x', ['float16', 'float32'], 'fused_multi_transformer' + x, 'x', ['uint16', 'float16', 'float32'], 'fused_multi_transformer' ) check_dtype( - dtype, 'dtype', ['float16', 'float32'], 'fused_multi_transformer' + dtype, + 'dtype', + ['uint16', 'float16', 'float32'], + 'fused_multi_transformer', ) # set inputs inputs = {} inputs['X'] = [x] inputs['LnScale'] = ln_scales - inputs['LnBias'] = ln_biases inputs['QKVW'] = qkv_weights + + if ln_biases is not None: + inputs['LnBias'] = ln_biases if qkv_biases is not None: inputs['QKVBias'] = qkv_biases if cache_kvs is not None: @@ -1116,6 +1157,8 @@ def fused_multi_transformer( inputs['TimeStep'] = time_step if pre_caches is not None: inputs['PreCaches'] = pre_caches + if beam_offset is not None: + inputs['BeamCacheOffset'] = beam_offset if rotary_emb_dims > 0: inputs['RotaryPosEmb'] = rotary_embs inputs['SeqLengths'] = seq_lens @@ -1125,7 +1168,8 @@ def fused_multi_transformer( inputs['OutLinearBias'] = linear_biases inputs['FFNLnScale'] = ffn_ln_scales - inputs['FFNLnBias'] = ffn_ln_biases + if ffn_ln_biases is not None: + inputs['FFNLnBias'] = ffn_ln_biases inputs['FFN1Weight'] = ffn1_weights if ffn1_biases is not None: inputs['FFN1Bias'] = ffn1_biases @@ -1137,6 +1181,7 @@ def fused_multi_transformer( attrs = { 
'pre_layer_norm': pre_layer_norm, 'epsilon': epsilon, + 'residual_alpha': residual_alpha, 'dropout_rate': dropout_rate, 'rotary_emb_dims': rotary_emb_dims, 'is_test': not training, @@ -1144,6 +1189,8 @@ def fused_multi_transformer( 'act_method': activation, 'trans_qkvw': trans_qkvw, 'ring_id': ring_id, + 'norm_type': norm_type, + 'use_neox_rotary_style': use_neox_rotary_style, } outputs = {} diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 31ea1e8d663b1..1a345e3f1d42c 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1072,20 +1072,23 @@ class FusedMultiTransformer(Layer): Otherwise, no pre-process and post-precess includes dropout, residual connection, layer normalization. Default True ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property - for Attention layer_norm. For Attention layer_norm weight, if it is a list/tuple, `attrs[0]` + for Attention layer_norm. For Attention layer_norm weight, + if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr`. ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property - for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, `attrs[0]` + for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, + `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. The `False` value means the corresponding layer would not have trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in :code:`ParamAttr`. qkv_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property - for Attention qkv computation. For Attention qkv weight, if it is a list/tuple, `attrs[0]` + for Attention qkv computation. For Attention qkv weight, + if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight @@ -1097,14 +1100,17 @@ class FusedMultiTransformer(Layer): `attr` to create parameters. The `False` value means the corresponding layer would not have trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in :code:`ParamAttr`. - linear_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + linear_weight_attrs(ParamAttr|list|tuple, optional): + To specify the weight parameter property for Attention linear. For Attention linear weight, if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr`. 
- linear_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property - for Attention linear computation. For Attention linear bias, if it is a list/tuple, `attrs[0]` + linear_bias_attrs(ParamAttr|list|tuple|bool, optional): + To specify the bias parameter property + for Attention linear computation. For Attention linear bias, + if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. The `False` value means the corresponding layer would @@ -1116,7 +1122,8 @@ class FusedMultiTransformer(Layer): `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr`. - ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): + To specify the bias parameter property for FFN layer_norm. For FFN layer_norm bias, if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as @@ -1129,20 +1136,23 @@ class FusedMultiTransformer(Layer): `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr`. - ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): + To specify the bias parameter property for FFN first linear. For FFN first linear bias, if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. The `False` value means the corresponding layer would not have trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in :code:`ParamAttr`. - ffn2_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + ffn2_weight_attrs(ParamAttr|list|tuple, optional): + To specify the weight parameter property for FFN second linear. For FFN second linear weight, if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as `attr` to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr`. - ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): + To specify the bias parameter property for FFN second linear. For FFN second linear bias, if it is a list/tuple, `attrs[0]` would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as `attr` for transformer layer 1,etc. Otherwise, all layers both use it as @@ -1152,13 +1162,17 @@ class FusedMultiTransformer(Layer): epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default: 1e-05. num_layers (int, optional): The number of layers of the transformer. 
If `qkv_weight_attrs` - is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers + is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. + num_layers only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1. - nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp. + nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, + means not using mp. trans_qkvw (bool, optional): Whether to transpose for weights of qkv. If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed]. - Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default: True. - ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp. + Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. + Default: True. + ring_id (int, optional): For distributed tensor model parallel. + Default is -1, means not using mp. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -1199,10 +1213,13 @@ def __init__( ffn2_weight_attrs=None, ffn2_bias_attrs=None, epsilon=1e-5, + residual_alpha=1.0, num_layers=-1, nranks=1, trans_qkvw=True, ring_id=-1, + norm_type="layernorm", + use_neox_rotary_style=False, name=None, ): super().__init__() @@ -1225,8 +1242,14 @@ def __init__( self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() self._epsilon = epsilon + self._residual_alpha = residual_alpha self._trans_qkvw = trans_qkvw self._ring_id = ring_id + self._norm_type = norm_type + self._use_neox_rotary_style = use_neox_rotary_style + self._norm_weight_dtype = ( + "float32" if self._norm_type == "layernorm" else self._dtype + ) self.embed_dim = embed_dim self.num_heads = num_heads @@ -1254,6 +1277,10 @@ def __init__( self.ffn_ln_scales, self.ffn_ln_biases = [], [] self.ffn1_weights, self.ffn1_biases = [], [] self.ffn2_weights, self.ffn2_biases = [], [] + self.qkv_weights_scales = [] + self.linear_weights_scales = [] + self.ffn1_weights_scales = [] + self.ffn2_weights_scales = [] def get_attr(attrs, idx): if isinstance(attrs, (list, tuple)): @@ -1261,6 +1288,12 @@ def get_attr(attrs, idx): return attrs[idx] return attrs + def _add_parameter(param): + if param is None: + return + assert param.name not in self._parameters + self._parameters[param.name] = param + for i in range(num_layers): ln_scale_attr = get_attr(ln_scale_attrs, i) ln_bias_attr = get_attr(ln_bias_attrs, i) @@ -1280,10 +1313,17 @@ def get_attr(attrs, idx): attr=ln_scale_attr, shape=[embed_dim], default_initializer=Constant(value=1.0), + dtype=self._norm_weight_dtype, ) - ln_bias = self.create_parameter( - attr=ln_bias_attr, shape=[embed_dim], is_bias=True - ) + ln_bias = None + if ln_bias_attr: + ln_bias = self.create_parameter( + attr=ln_bias_attr, + shape=[embed_dim], + is_bias=True, + dtype=self._norm_weight_dtype, + ) + qkv_weight = self.create_parameter( shape=[3, num_heads, self.head_dim, embed_dim] if trans_qkvw @@ -1292,63 +1332,90 @@ def get_attr(attrs, idx): dtype=self._dtype, is_bias=False, ) - qkv_bias = self.create_parameter( - shape=[3, num_heads, self.head_dim], - attr=qkv_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + + qkv_bias = None + if qkv_bias_attr: + qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + 
dtype=self._dtype, + is_bias=True, + ) + linear_weight = self.create_parameter( shape=[num_heads * self.head_dim, embed_dim], attr=linear_weight_attr, dtype=self._dtype, is_bias=False, ) - linear_bias = self.create_parameter( - shape=[embed_dim], - attr=linear_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + + linear_bias = None + if linear_bias_attr: + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) ffn_ln_scale = self.create_parameter( shape=[embed_dim], attr=ffn_ln_scale_attr, is_bias=False, default_initializer=Constant(1.0), + dtype=self._norm_weight_dtype, ) - ffn_ln_bias = self.create_parameter( - shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True - ) + + ffn_ln_bias = None + if ffn_ln_bias_attr: + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_bias_attr, + is_bias=True, + dtype=self._norm_weight_dtype, + ) + ffn1_weight = self.create_parameter( - shape=[embed_dim, dim_feedforward], + shape=[embed_dim, dim_feedforward * 2] + if activation.endswith("glu") + else [embed_dim, dim_feedforward], attr=ffn1_weight_attr, dtype=self._dtype, is_bias=False, ) - ffn1_bias = self.create_parameter( - shape=[dim_feedforward], - attr=ffn1_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + + ffn1_bias = None + if ffn1_bias_attr: + ffn1_bias = self.create_parameter( + shape=[dim_feedforward * 2] + if activation.endswith("glu") + else [dim_feedforward], + attr=ffn1_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + ffn2_weight = self.create_parameter( shape=[dim_feedforward, embed_dim], attr=ffn2_weight_attr, dtype=self._dtype, is_bias=False, ) - ffn2_bias = self.create_parameter( - shape=[embed_dim], - attr=ffn2_bias_attr, - dtype=self._dtype, - is_bias=True, - ) + + ffn2_bias = None + if ffn2_bias_attr: + ffn2_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn2_bias_attr, + dtype=self._dtype, + is_bias=True, + ) # tensor model parallel if nranks > 1: # column parallel _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) _set_var_distributed(ffn1_weight) _set_var_distributed(ffn1_bias) @@ -1370,6 +1437,38 @@ def get_attr(attrs, idx): self.ffn2_weights.append(ffn2_weight) self.ffn2_biases.append(ffn2_bias) + _add_parameter(ln_scale) + _add_parameter(ln_bias) + _add_parameter(qkv_weight) + _add_parameter(qkv_bias) + _add_parameter(linear_weight) + _add_parameter(linear_bias) + + _add_parameter(ffn_ln_scale) + _add_parameter(ffn_ln_bias) + _add_parameter(ffn1_weight) + _add_parameter(ffn1_bias) + _add_parameter(ffn2_weight) + _add_parameter(ffn2_bias) + + if self.ln_biases[0] is None: + self.ln_biases = None + + if self.qkv_biases[0] is None: + self.qkv_biases = None + + if self.linear_biases[0] is None: + self.linear_biases = None + + if self.ffn_ln_biases[0] is None: + self.ffn_ln_biases = None + + if self.ffn1_biases[0] is None: + self.ffn1_biases = None + + if self.ffn2_biases[0] is None: + self.ffn2_biases = None + self.dropout_rate = dropout_rate self.activation = activation self.name = name @@ -1382,6 +1481,7 @@ def forward( pre_caches=None, rotary_embs=None, rotary_emb_dims=0, + beam_offset=None, seq_lens=None, time_step=None, ): @@ -1403,11 +1503,16 @@ def forward( inference and should be None for training. The shape is `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches - for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. 
- rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. - seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. + for the generation model. The shape is + `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + rotary_embs (Tensor optional): The RoPE embs for the rotary computation. + The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, + and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, + 2 when rotary_embs and pos_extra_ids are both not None. Default 0. + seq_lens (Tensor optional): The sequence lengths of this batch. + The shape is `[bsz]`. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be @@ -1439,7 +1544,9 @@ def forward( self.ffn2_biases, pre_layer_norm=self.normalize_before, epsilon=self._epsilon, + residual_alpha=self._residual_alpha, cache_kvs=caches, + beam_offset=beam_offset, pre_caches=pre_caches, rotary_embs=rotary_embs, time_step=time_step, @@ -1452,6 +1559,8 @@ def forward( mode='upscale_in_train', trans_qkvw=self._trans_qkvw, ring_id=self._ring_id, + norm_type=self._norm_type, + use_neox_rotary_style=self._use_neox_rotary_style, name=self.name, ) return out diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 8e92dfa3a5e60..5499b12e9c1b5 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1161,7 +1161,7 @@ def _check_attr(attr, message): check_dtype( dtype, 'dtype', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'eye', ) out = helper.create_variable_for_type_inference(dtype=dtype) @@ -1561,7 +1561,7 @@ def meshgrid(*args, **kwargs): check_dtype( input_.dtype, 'create data type', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'meshgrid', ) @@ -1674,7 +1674,7 @@ def diagflat(x, offset=0, name=None): check_dtype( x.dtype, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'diagflat', ) check_type(offset, 'offset', (int), 'diagflat') @@ -1792,7 +1792,7 @@ def diag(x, offset=0, padding_value=0, name=None): check_dtype( x.dtype, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'diag_v2', ) check_type(offset, 'offset', (int), 'diag_v2') diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 9bbc773e07f85..32fa60a43947c 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1310,7 +1310,7 @@ def t(input, name=None): check_variable_and_dtype( input, 'input', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'transpose', ) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 
09aaff08c3ca5..39919c85cad2a 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1714,6 +1714,7 @@ def roll(x, shifts, axis=None, name=None): 'float64', 'int32', 'int64', + 'uint16', 'complex64', 'complex128', ], @@ -2206,6 +2207,7 @@ def squeeze(x, axis=None, name=None): 'int8', 'int32', 'int64', + 'uint16', 'complex64', 'complex128', ], @@ -2603,6 +2605,7 @@ def unsqueeze(x, axis, name=None): 'int16', 'int32', 'int64', + 'uint16', 'complex64', 'complex128', ], @@ -4694,7 +4697,7 @@ def index_add(x, index, axis, value, name=None): check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'paddle.tensor.manipulation.index_add', ) check_variable_and_dtype( @@ -4706,7 +4709,7 @@ def index_add(x, index, axis, value, name=None): check_variable_and_dtype( value, 'add_value', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'paddle.tensor.manipulation.index_add', ) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 65d94c1138562..6a0916cf542ac 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1441,7 +1441,10 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): out6 = paddle.nansum(y, axis=[0, 1]) # [9, 18] """ check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'nansum' + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'nansum', ) check_type(axis, 'axis', (int, list, tuple, type(None)), 'nansum') @@ -3208,10 +3211,16 @@ def kron(x, y, name=None): else: helper = LayerHelper('kron', **locals()) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron' + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'kron', ) check_variable_and_dtype( - y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron' + y, + 'y', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + 'kron', ) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -3277,7 +3286,7 @@ def cumsum(x, axis=None, dtype=None, name=None): check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'cumsum', ) check_type(x, 'x', (Variable), 'cumsum')
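A note on the recurring 'uint16' additions in the Python diffs above: Paddle's static graph reports a bfloat16 variable's dtype as 'uint16' (bf16 values are stored in 16-bit words), so appending 'uint16' to these check_variable_and_dtype / check_dtype whitelists is what lets bf16 programs reach the new kernels. A minimal sketch, assuming a build where paddle.static.data accepts a bfloat16 dtype string:

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    # convert_dtype() reports this variable's dtype as 'uint16',
    # the string these whitelists now accept.
    x = paddle.static.data(name="x", shape=[4, 16], dtype="bfloat16")
    y = paddle.cumsum(x, axis=-1)  # previously rejected by the dtype check

And for the new FusedMultiTransformer arguments documented above, a hedged construction sketch: the argument values are placeholders, and "rmsnorm" as the non-default norm_type is an assumption inferred from the float32-vs-native weight dtype switch in __init__, not stated in the patch.

import paddle
from paddle.incubate.nn import FusedMultiTransformer

layer = FusedMultiTransformer(
    embed_dim=1024,
    num_heads=16,
    dim_feedforward=4096,
    num_layers=2,
    normalize_before=True,
    residual_alpha=1.0,            # new in this patch
    norm_type="layernorm",         # new; "rmsnorm" assumed as the alternative
    use_neox_rotary_style=False,   # new in this patch
    qkv_bias_attrs=False,          # bias parameters may now be omitted entirely
)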