diff --git a/CMakeLists.txt b/CMakeLists.txt index 0129f85123fb..5e36742dddb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -232,7 +232,6 @@ endif() set(VLLM_EXT_SRC "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu deleted file mode 100644 index c83d72751a55..000000000000 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ /dev/null @@ -1,656 +0,0 @@ -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu -// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu -#include -#include -#include - -#include "causal_conv1d.h" -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include -#include - -#ifdef USE_ROCM - namespace cub = hipcub; -#endif - -#include "static_switch.h" - - - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -void set_conv_params_fwd(ConvParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const at::Tensor x, - const at::Tensor weight, - const at::Tensor out, - const std::optional& bias, - bool silu_activation, - int64_t pad_slot_id, - const std::optional& query_start_loc = std::nullopt, - const std::optional& cache_indices = std::nullopt, - const std::optional& has_initial_state = std::nullopt) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - params.pad_slot_id = pad_slot_id; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = x.data_ptr(); - params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; - params.out_ptr = out.data_ptr(); - // All stride are in elements, not bytes. - params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; - params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; - params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; - const bool varlen = params.query_start_loc_ptr != nullptr; - params.x_batch_stride = x.stride(varlen ? 1 : 0); - params.x_c_stride = x.stride(varlen ? 0 : 1); - params.x_l_stride = x.stride(varlen ? 
1 : -1); - params.weight_c_stride = weight.stride(0); - params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(varlen ? 1 : 0); - params.out_c_stride = out.stride(varlen ? 0 : 1); - params.out_l_stride = out.stride(varlen ? 1 : -1); -} - - -void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const std::optional &bias_, - const std::optional &conv_states, - const std::optional &query_start_loc, - const std::optional &cache_indices, - const std::optional &has_initial_state, - bool silu_activation, - // used to identify padding entries if cache_indices provided - // in case of padding, the kernel will return early - int64_t pad_slot_id) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const bool varlen = query_start_loc.has_value() ? true : false; - const auto sizes = x.sizes(); - const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; - const int dim = varlen ? sizes[0] : sizes[1]; - const int seqlen = varlen ? sizes[1] : sizes[2]; - const int width = weight.size(-1); - if (varlen){ - CHECK_SHAPE(x, dim, seqlen); - } - else { - CHECK_SHAPE(x, batch_size, dim, seqlen); - } - CHECK_SHAPE(weight, dim, width); - - - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - - if (has_initial_state.has_value()) { - auto has_initial_state_ = has_initial_state.value(); - TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); - TORCH_CHECK(has_initial_state_.is_cuda()); - CHECK_SHAPE(has_initial_state_, batch_size); - } - - - if (query_start_loc.has_value()) { - auto query_start_loc_ = query_start_loc.value(); - TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(query_start_loc_.is_cuda()); - } - - - if (cache_indices.has_value()) { - auto cache_indices_ = cache_indices.value(); - TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cache_indices_.is_cuda()); - CHECK_SHAPE(cache_indices_, batch_size); - } - - at::Tensor out = x; - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_, - silu_activation, - pad_slot_id, - query_start_loc, - cache_indices, - has_initial_state - ); - - if (conv_states.has_value()) { - auto conv_states_ = conv_states.value(); - TORCH_CHECK(conv_states_.scalar_type() == input_type); - TORCH_CHECK(conv_states_.is_cuda()); - params.conv_states_ptr = conv_states_.data_ptr(); - params.conv_states_batch_stride = conv_states_.stride(0); - params.conv_states_c_stride = conv_states_.stride(1); - params.conv_states_l_stride = conv_states_.stride(2); - } else { - params.conv_states_ptr = nullptr; - } - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - causal_conv1d_fwd_cuda(params, stream); - }); -} - - -void causal_conv1d_update(const at::Tensor &x, - const at::Tensor &conv_state, - const at::Tensor 
&weight, - const std::optional &bias_, - bool silu_activation, - const std::optional &cache_seqlens_, - const std::optional &conv_state_indices_, - // used to identify padding entries if cache_indices provided - // in case of padding, the kernel will return early - int64_t pad_slot_id) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); - TORCH_CHECK(conv_state.scalar_type() == input_type); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(conv_state.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - const int conv_state_len = conv_state.size(2); - TORCH_CHECK(conv_state_len >= width - 1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor out = x; - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_, - silu_activation, - pad_slot_id); - params.conv_state_ptr = conv_state.data_ptr(); - params.conv_state_len = conv_state_len; - // All stride are in elements, not bytes. 
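The deleted host wrapper above only validates shapes/dtypes and packs element strides into `ConvParamsBase` before dispatching. A rough PyTorch model of the update semantics it configures (names and structure here are illustrative, not vLLM's API; it assumes the non-circular-buffer path and `state_len >= width - 1`):

```python
import torch
import torch.nn.functional as F


def conv1d_update_sketch(x, conv_state, weight, bias=None, activation=None):
    """Rough reference for the decode-time update path.

    x:          (batch, dim, seqlen) new tokens
    conv_state: (batch, dim, state_len) rolling cache, modified in place
    weight:     (dim, width) depthwise filter taps
    """
    batch, dim, seqlen = x.shape
    state_len = conv_state.shape[-1]
    # Append the new tokens and keep only the last `state_len` positions,
    # i.e. the left-shift the kernel performs on the cached context.
    x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)
    conv_state.copy_(x_new[:, :, -state_len:])
    # Depthwise causal conv (groups=dim); the cached context supplies the
    # left padding, so only the last `seqlen` outputs belong to `x`.
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, groups=dim)[:, :, -seqlen:]
    return (out if activation is None else F.silu(out)).to(x.dtype)
```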
- params.conv_state_batch_stride = conv_state.stride(0); - params.conv_state_c_stride = conv_state.stride(1); - params.conv_state_l_stride = conv_state.stride(2); - - if (cache_seqlens_.has_value()) { - auto cache_seqlens = cache_seqlens_.value(); - TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); - TORCH_CHECK(cache_seqlens.is_cuda()); - TORCH_CHECK(cache_seqlens.stride(-1) == 1); - CHECK_SHAPE(cache_seqlens, batch_size); - params.cache_seqlens = cache_seqlens.data_ptr(); - } else { - params.cache_seqlens = nullptr; - } - - if (conv_state_indices_.has_value()) { - auto conv_state_indices = conv_state_indices_.value(); - TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) - TORCH_CHECK(conv_state_indices.is_cuda()); - TORCH_CHECK(conv_state_indices.stride(0) == 1) - CHECK_SHAPE(conv_state_indices, batch_size); - - int conv_state_entries = conv_state.size(0); - CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); - - params.conv_state_indices_ptr = conv_state_indices.data_ptr(); - } else { - CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); - params.conv_state_indices_ptr = nullptr; - } - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda(params, stream); - }); -} - -template -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - static constexpr int kSmemIOSize = kIsVecLoad - ? 0 - : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_store_vec = reinterpret_cast(smem_); - vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const bool kVarlen = params.query_start_loc_ptr != nullptr; - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; - const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; - const int seqlen = kVarlen ? 
query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; - - input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride - + channel_id * params.x_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - bool has_initial_state = params.has_initial_state_ptr == nullptr ? false - : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; - - int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr - : reinterpret_cast(params.cache_indices_ptr); - int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - // cache_index == params.pad_slot_id is defined as padding, so we exit early - if (cache_index == params.pad_slot_id){ - return; - } - input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr - : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; - - // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. - if (tidx == 0) { - input_t initial_state[kNElts] = {0}; - if (has_initial_state) { - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } - } - smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - if constexpr(kIsVecLoad) { - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); - } - x += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - __syncthreads(); - reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current chunk. 
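With the chunk's new values and the carried-over tail in registers, each thread then evaluates the same scalar recurrence per channel. Stated outside the CUDA machinery (an illustrative helper, not part of the kernel; tokens before the sequence start are treated as zero, i.e. the no-initial-state case):

```python
import math


def causal_conv1d_scalar(x, weight, bias=0.0, silu=True):
    """Scalar model of what one thread computes for one channel.

    x:      one channel's sequence as a list of floats
    weight: the `width` filter taps for that channel
    """
    width = len(weight)
    out = []
    for i in range(len(x)):
        acc = bias
        for w in range(width):
            j = i - (width - 1 - w)          # tap w reads token i - (width-1-w)
            acc += weight[w] * (x[j] if j >= 0 else 0.0)
        # optional SiLU, matching out_vals[i] / (1 + expf(-out_vals[i]))
        out.append(acc / (1 + math.exp(-acc)) if silu else acc)
    return out
```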
- if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - - float x_vals[2 * kNElts]; - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - - float out_vals[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - - if (params.silu_activation) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } - } - - input_t out_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } - if constexpr(kIsVecLoad) { - typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); - } - out += kChunkSize; - - int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); - // in case the final state is separated between the last "smem_exchange" and - // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), - // (which occurs when `final_state_position` is a non-positive index) - // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it - if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ - input_t vals_load[kNElts] = {0}; - if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ - // chunk = n_chunks - 2, a segment of the final state sits in the last index - reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; - #pragma unroll - for (int w = 0; w < -final_state_position; ++w){ - conv_states[w] = vals_load[kNElts + final_state_position + w]; - } - } - if ((chunk == n_chunks - 1) && tidx == 0){ - // chunk = n_chunks - 1, the second segment of the final state first positions - reinterpret_cast(vals_load)[0] = smem_exchange[0]; - for (int w = -final_state_position; w < kWidth - 1; ++w){ - conv_states[w] = vals_load[w + final_state_position]; - } - return; - } - } - } - // Final state is stored in the smem_exchange last token slot, - // in case seqlen < kWidth, we would need to take the final state from the - // initial state which is stored in conv_states - // in case seqlen > kWidth, we would need to load the last kWidth - 1 data - // and load it into conv_state accordingly - int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; - if (conv_states != nullptr && tidx == last_thread) { - input_t x_vals_load[kNElts * 2] = {0}; - // in case we are on the first kWidth tokens - if (last_thread == 0 && seqlen < kWidth){ - // Need to take the initial state - reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; - const int offset = seqlen - (kWidth - 1); - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - // pad the existing state - if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } - else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } - } - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - if (offset + w >= 0) - conv_states[w] = x_vals_load[offset + w ]; - } - } - else { - // in case the final state is in between the threads data - const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); - if ((offset + kWidth - 2) >= 
kNElts && (last_thread + 1 < kNThreads)){ - // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a - // illegal access error on H100. - // Therefore, we access last_thread + 1, only if the final state data sits there - reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; - } - reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - conv_states[w] = x_vals_load[offset + w ]; - } - } - - } -} - - -template -void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - const bool kVarlen = params.query_start_loc_ptr != nullptr; - BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - - auto kernel = &causal_conv1d_fwd_kernel; - - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - } - kernel<<>>(params); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - - -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - - - - -template -struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_update_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - if (channel_id >= params.dim) return; - - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - - // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor - // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. - const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr - ? 
batch_id - : params.conv_state_indices_ptr[batch_id]; - // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early - if (conv_state_batch_coord == params.pad_slot_id){ - return; - } - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) - + conv_state_batch_coord * params.conv_state_batch_stride - + channel_id * params.conv_state_c_stride; - - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int state_len = params.conv_state_len; - int advance_len = params.seqlen; - int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; - int update_idx = cache_seqlen - (kWidth - 1); - update_idx = update_idx < 0 ? update_idx + state_len : update_idx; - - float weight_vals[kWidth] = {0}; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - float x_vals[kWidth] = {0}; - if constexpr (!kIsCircularBuffer) { - #pragma unroll 2 - for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { - conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; - } - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; - if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { - conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; - } - x_vals[i] = float(state_val); - } - } else { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { - input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; - x_vals[i] = float(state_val); - } - } - #pragma unroll 2 - for (int i = 0; i < params.seqlen; ++i) { - input_t x_val = x[i * params.x_l_stride]; - if constexpr (!kIsCircularBuffer) { - if (i < advance_len && state_len - advance_len + i >= 0) { - conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; - } - } else { - conv_state[update_idx * params.conv_state_l_stride] = x_val; - ++update_idx; - update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; - } - x_vals[kWidth - 1] = float(x_val); - float out_val = bias_val; - #pragma unroll - for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - out[i * params.out_l_stride] = input_t(out_val); - // Shift the input buffer by 1 - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } - } -} - -template -void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_update_kernel_traits; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = params.cache_seqlens == nullptr - ? 
&causal_conv1d_update_kernel - : &causal_conv1d_update_kernel; - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h deleted file mode 100644 index e26684a2b98b..000000000000 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ /dev/null @@ -1,159 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h -#pragma once - -#include -#include -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - int64_t pad_slot_id; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - int conv_state_len; - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void *__restrict__ x_ptr; - void *__restrict__ weight_ptr; - void *__restrict__ bias_ptr; - void *__restrict__ out_ptr; - - void *__restrict__ conv_state_ptr; - void *__restrict__ query_start_loc_ptr; - void *__restrict__ has_initial_state_ptr; - void *__restrict__ cache_indices_ptr; - int32_t *__restrict__ cache_seqlens; - - // For the continuous batching case. Makes it so that the mamba state for - // the current batch doesn't need to be a contiguous tensor. - int32_t *__restrict__ conv_state_indices_ptr; - - void *__restrict__ seq_idx_ptr; - - // No __restrict__ since initial_states could be the same as final_states. 
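The `conv_state_indices_ptr` field above is what decouples a request's position in the batch from where its rolling cache lives. A minimal sketch of the gather it implies, assuming a `(num_cache_lines, dim, state_len)` pool and the padding sentinel the tests import as `PAD_SLOT_ID` (the value used here is illustrative):

```python
import torch

PAD_SLOT_ID = -1  # illustrative; the real sentinel comes from vllm.attention.backends.utils


def gather_conv_states(conv_state, conv_state_indices):
    """Model of the batch-gather indirection for continuous batching.

    conv_state:         (num_cache_lines, dim, state_len) shared cache pool
    conv_state_indices: (batch,) int32 cache line per request, or PAD_SLOT_ID
    Returns the gathered states and a mask of real (non-padding) requests.
    """
    valid = conv_state_indices != PAD_SLOT_ID
    # Padding slots are mapped to an arbitrary valid line here; the kernel
    # itself just returns early for them, so their cache lines are untouched.
    safe_idx = torch.where(valid, conv_state_indices,
                           torch.zeros_like(conv_state_indices))
    return conv_state[safe_idx.long()], valid
```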
- void * initial_states_ptr; - index_t initial_states_batch_stride; - index_t initial_states_l_stride; - index_t initial_states_c_stride; - - void * final_states_ptr; - index_t final_states_batch_stride; - index_t final_states_l_stride; - index_t final_states_c_stride; - - void * conv_states_ptr; - index_t conv_states_batch_stride; - index_t conv_states_l_stride; - index_t conv_states_c_stride; -}; - - -#ifndef USE_ROCM - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor_sync(uint32_t(-1), val, offset); - } - - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } - - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } - -#else - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor(val, offset); - } - constexpr size_t custom_max(std::initializer_list ilist) - { - return *std::max_element(ilist.begin(), ilist.end()); - } - - template - constexpr T constexpr_min(T a, T b) { - return a < b ? a : b; - } -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h deleted file mode 100644 index ef74bf447f84..000000000000 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ /dev/null @@ -1,28 +0,0 @@ -// Inspired by -// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/csrc/ops.h b/csrc/ops.h index 52c264d64cca..7f3e6b6923a3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); -void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, - const at::Tensor& weight, - const std::optional& bias_, - bool silu_activation, - const std::optional& cache_seqlens_, - const std::optional& conv_state_indices_, - int64_t pad_slot_id); - -void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const std::optional& bias_, - const std::optional& conv_states, - const std::optional& query_start_loc, - const std::optional& cache_indices, - const std::optional& has_initial_state, - bool silu_activation, int64_t pad_slot_id); - using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9414e26196b2..1920bec42238 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); - ops.def( - "causal_conv1d_update(Tensor! x," - "Tensor! conv_state," - "Tensor! weight," - "Tensor? bias_," - "bool silu_activation," - "Tensor? cache_seqlens_," - "Tensor? conv_state_indices," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); - - ops.def( - "causal_conv1d_fwd(Tensor! x, Tensor! weight," - "Tensor? bias_," - "Tensor!? conv_states," - "Tensor? query_start_loc," - "Tensor? cache_indices," - "Tensor? 
has_initial_state," - "bool silu_activation," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); - #ifndef USE_ROCM // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel ops.def( diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index addb8bfcda13..411bd9e904b0 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -6,9 +6,8 @@ import pytest import torch import torch.nn.functional as F +from einops import rearrange -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, - (x, weight, bias, conv_states, cu_seq_len, cache_indices, - has_initial_state, activation in ["silu", "swish"], pad_slot_id)) - - -@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) -@pytest.mark.parametrize("silu_activation", [True]) -@pytest.mark.parametrize("has_bias", [True]) -@pytest.mark.parametrize("has_initial_state", [True, False]) -@pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) -@pytest.mark.parametrize('dim', [64]) -@pytest.mark.parametrize('batch', [1]) -def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, - has_initial_state, itype): - device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - # set seed - current_platform.seed_everything(0) - x = torch.randn(batch, dim, seqlen, device=device, - dtype=itype).contiguous() - - weight = torch.randn(dim, width, device=device, dtype=itype) - bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - if has_initial_state: - initial_states = torch.randn(batch, - dim, - width - 1, - device=device, - dtype=itype) - has_initial_state_tensor = torch.ones(batch, - dtype=torch.bool, - device=x.device) - else: - initial_states = None - has_initial_state_tensor = None - x_ref = x.clone() - weight_ref = weight.clone() - bias_ref = bias.clone() if bias is not None else None - initial_states_ref = initial_states.clone( - ) if initial_states is not None else None - activation = None if not silu_activation else "silu" - out = causal_conv1d_fn(x, - weight, - bias, - activation=activation, - conv_states=initial_states, - has_initial_state=has_initial_state_tensor) - out_ref, final_states_ref = causal_conv1d_ref( - x_ref, - weight_ref, - bias_ref, - initial_states=initial_states_ref, - return_final_states=True, - activation=activation) - if has_initial_state: - assert initial_states is not None and final_states_ref is not None - assert torch.allclose(initial_states, - final_states_ref, - rtol=rtol, - atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - causal_conv1d_opcheck_fn(x, - weight, - bias, - activation=activation, - conv_states=initial_states, - has_initial_state=has_initial_state_tensor) - @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, 
assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, - (x, conv_state, weight, bias, activation - in ["silu", "swish"], None, None, PAD_SLOT_ID)) - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) -@pytest.mark.parametrize("seqlen", [1, 4, 5]) -@pytest.mark.parametrize("width", [2, 3, 4]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("seqlen", [1, 3]) +@pytest.mark.parametrize("width", [3, 4]) +@pytest.mark.parametrize("dim", [2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) -def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, - seqlen, has_bias, +@pytest.mark.parametrize("batch_size", [3]) +def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, + width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, # set seed current_platform.seed_everything(0) - batch_size = 3 padding = 5 if with_padding else 0 padded_batch_size = batch_size + padding + # total_entries = number of cache line total_entries = 10 * batch_size - x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + # x will be (batch, dim, seqlen) with contiguous along dim-axis + x = torch.randn(padded_batch_size, seqlen, dim, device=device, + dtype=itype).transpose(1, 2) + x_ref = x.clone() conv_state_indices = torch.randperm(total_entries)[:batch_size].to( @@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) ], dim=0) + + # conv_state will be (cache_lines, dim, state_len) + # with contiguous along dim-axis conv_state = torch.randn(total_entries, - dim, width - 1, + dim, device=device, - dtype=itype) + dtype=itype).transpose(1, 2) + conv_state_for_padding_test = conv_state.clone() weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, conv_state, weight, @@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) assert torch.equal(conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]) - - opcheck(torch.ops._C.causal_conv1d_update, - (x, conv_state, weight, bias, activation - in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) +@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 
4096]) -# tests correctness in case subset of the sequences are padded @pytest.mark.parametrize('with_padding', [True, False]) -def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, - silu_activation, itype): +@pytest.mark.parametrize('batch', [4, 10]) +def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, + has_bias, silu_activation, itype): device = "cuda" torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, # set seed current_platform.seed_everything(0) seqlens = [] - batch_size = 4 - if seqlen < 10: - batch_size = 1 + batch_size = batch padding = 3 if with_padding else 0 padded_batch_size = batch_size + padding nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( torch.diff( torch.cat( @@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, - dtype=itype)[:, 4096:4096 + dim, :] + x = rearrange( + torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype), + "b s d -> b d s")[:, 4096:4096 + dim, :] + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None x_ref = x.clone() weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" final_states = torch.randn(total_entries, - dim, width - 1, + dim, device=x.device, - dtype=x.dtype) + dtype=x.dtype).transpose(1, 2) final_states_ref = final_states.clone() has_initial_states = torch.randint(0, 2, (cumsum.shape[0] - 1, ), @@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), ], dim=-1) + out = causal_conv1d_fn(x.squeeze(0), + weight, + bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID) - out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - padded_state_indices, has_initial_states, - final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] @@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref_tensor = torch.cat(out_ref, dim=0) - unpadded_out = out[:, :out_ref_tensor.shape[-1]] - assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) assert torch.allclose(final_states[state_indices], final_states_ref[state_indices], rtol=rtol, atol=atol) - - causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - padded_state_indices, has_initial_states, - final_states, activation) + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index ccf0ff6abd16..6a3f21ba543f 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -6,11 +6,11 @@ import torch.nn.functional as F from einops import 
rearrange, repeat -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _query_start_loc_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform +from vllm.v1.attention.backends.mamba_attn import ( + _query_start_loc_to_chunk_indices_offsets) # Added by the IBM Team, 2024 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 92db27f5b8dc..deedeef46b0c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -963,17 +963,17 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, out_dtype: torch.dtype, device: torch.device): """ - An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized input and expert weights. - a_/b_scales: The blockscales in FP8-E4M3 precision - - expert_offsets/sf_offsets: Indices that mark at which token index - each expert begins its computation. The number of tokens - computed with expert E is expert_offsets[E + 1] - - expert_offsets[E] And the sf_size per expert is + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is sf_offset[E+1] - sf_offset[E] - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. 
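The hunk above only trims trailing whitespace, but the offset convention the docstring describes is easy to pin down with a concrete (made-up) example:

```python
import torch

# Illustrative offsets for 3 experts over 10 tokens: per the docstring,
# expert E processes expert_offsets[E + 1] - expert_offsets[E] tokens,
# i.e. expert_offsets is an exclusive prefix sum of per-expert token counts.
expert_offsets = torch.tensor([0, 4, 4, 10])
tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]   # tensor([4, 0, 6])
```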
@@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int: # mamba -def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - silu_activation: bool, pad_slot_id: int): - torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation, - pad_slot_id) - - -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor], - pad_slot_id: int): - torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices, pad_slot_id) - - def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 88053faf9e52..0a836fd17533 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import dataclass +from typing import Optional, Union +import numpy as np import torch from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform +from vllm.v1.attention.backends.mamba_attn import ( + Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) @dataclass @@ -21,6 +25,29 @@ class Mamba2Metadata: seq_idx: torch.Tensor chunk_indices: torch.Tensor chunk_offsets: torch.Tensor + """ + With continuous batching layout of `x` in vLLM, to enable a Triton program + to handle a request in parallel, two supporting tensors are used + (batch_ptr, token_chunk_offset_ptr) + BLOCK_M = the # tokens to be handled by a Triton program + (can be customized for different hardware) + + nums_dict: + tracks the data associated with a given value of BLOCK_M + BLOCK_M = #tokens handled by a Triton program + cu_seqlen: total tokens per batch + (used as flag to update other data at each new input) + batch_ptr: tracks batch-id handled by the Triton program + token_chunk_offset_ptr: tracks token group_idx handled by the Triton program + (Triton implementation of causal_conv1d handles parallelism in 3-axes + - feature-axis + - batch-axis + - sequence-axis) + """ + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: @@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: f"Unsupported platform for Mamba2: {current_platform.device_type}") -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): - - cu_seqlens = query_start_loc[1:] # remove prepended 0 - - # outputs will have length expansion of chunks that do not divide 
- # chunk_size - N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, - dtype=torch.int, - device=query_start_loc.device) - chunk_offsets = torch.zeros((N, ), - dtype=torch.int, - device=query_start_loc.device) - - p = 0 # num of insertions - for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): - - # if does not divide chunk_size, then there is one chunk insertion - p += (s % chunk_size > 0) - - # get the dimensions - # - the + 1 for _e is to shift the boundary by one chunk - # - this shifting is not needed if chunk_size divides e - _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size - > 0) - - # adjust inidces and offsets - chunk_indices[_s:_e] -= p - chunk_offsets[_s] = s % chunk_size - - return chunk_indices, chunk_offsets - - def prepare_mamba2_metadata( chunk_size: int, attn_metadata: AttentionMetadata, + mamba2_metadata=None, ) -> Mamba2Metadata: # compute number of prefill and decode requests @@ -96,12 +88,12 @@ def prepare_mamba2_metadata( attn_metadata_instances = get_platform_metadata_classes() if (isinstance(attn_metadata, attn_metadata_instances) and attn_metadata.context_lens_tensor is not None): - has_initial_states = \ - attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] - # precompute flag to avoid device syncs in mamba2 layer forwards + # precompute flag to avoid device syncs later in mamba2 layer + # forwards # prep is only needed for mamba2 ssd prefill processing - prep_initial_states = torch.any(has_initial_states).item() - + has_initial_states = attn_metadata.context_lens_tensor > 0 + prep_initial_states = torch.any( + has_initial_states[:num_prefills]).item() query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] seq_idx = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, device=query_start_loc.device), @@ -117,9 +109,78 @@ def prepare_mamba2_metadata( _query_start_loc_to_chunk_indices_offsets( query_start_loc, chunk_size, num_prefill_tokens) + if mamba2_metadata is not None: + mamba2_metadata.has_initial_states = has_initial_states + mamba2_metadata.prep_initial_states = prep_initial_states + mamba2_metadata.chunk_size = chunk_size + mamba2_metadata.seq_idx = seq_idx + mamba2_metadata.chunk_indices = chunk_indices + mamba2_metadata.chunk_offsets = chunk_offsets + # We use 1 reset flag: + # * mamba2_metadata.cu_seqlen is None + # update config specific to (each input) + # (become available at first layer, e.g. 
conv_weights) + mamba2_metadata.cu_seqlen = None # suppose to be updated at each input + + return mamba2_metadata return Mamba2Metadata(has_initial_states=has_initial_states, prep_initial_states=prep_initial_states, chunk_size=chunk_size, seq_idx=seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets) + + +def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, + mamba2_metadata: Union[Mamba2Metadata, + Mamba2AttentionMetadata]): + """ + this is triggered upon handling a new input at the first layer + """ + dim, cu_seqlen = x.shape + mamba2_metadata.cu_seqlen = cu_seqlen + seqlens = np.diff(query_start_loc.to('cpu')) + nums_dict = {} # type: ignore + for BLOCK_M in [8]: # cover all BLOCK_M values + nums = -(-seqlens // BLOCK_M) + nums_dict[BLOCK_M] = {} + nums_dict[BLOCK_M]['nums'] = nums + nums_dict[BLOCK_M]['tot'] = nums.sum().item() + mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) + nums_dict[BLOCK_M]['mlist'] = mlist + mlist_len = len(nums_dict[BLOCK_M]['mlist']) + nums_dict[BLOCK_M]['mlist_len'] = mlist_len + MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 + offsetlist = [] # type: ignore + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) + offsetlist = torch.tensor(offsetlist, dtype=torch.int32) + nums_dict[BLOCK_M]['offsetlist'] = offsetlist + + if mamba2_metadata.batch_ptr is None: + # Update default value after class definition + #mamba2_metadata.MAX_NUM_PROGRAMS *= 2 + mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + mamba2_metadata.token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + else: + if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS: + mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_( + PAD_SLOT_ID) + mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore + MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + + mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist) + mamba2_metadata.token_chunk_offset_ptr[ # type: ignore + 0:mlist_len].copy_(offsetlist) + nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr + nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = ( + mamba2_metadata.token_chunk_offset_ptr) # type: ignore + mamba2_metadata.nums_dict = nums_dict + return mamba2_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 118bd8d55c1d..796c8d937572 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -159,7 +159,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, hidden_states = causal_conv1d_fn( hidden_states, conv_weights, - self.conv1d.bias, + bias=self.conv1d.bias, activation=self.activation, conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9dcbcb2e6f2b..2cc30e4d3f7e 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -17,7 +17,8 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata +from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, + update_metadata) from vllm.model_executor.layers.mamba.ops.causal_conv1d 
import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader( tp_size: int, tp_rank: int, ) -> LoaderFunction: - """Create a weight loader for mamba v2. This ensures that the projections - are correctly sharded so that they can be split into x, B, C. It also - ensures that all the groups corresponding to a head shard is placed + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures that all the groups corresponding to a head shard is placed together with it. """ @@ -458,9 +459,11 @@ def forward_cuda( if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] + mamba2_metadata = attn_metadata assert isinstance(attn_metadata, Mamba2AttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor has_initial_states_p = attn_metadata.has_initial_states @@ -531,6 +534,7 @@ def forward_cuda( # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension + # NOTE: V0 put prefill before decode, v1 puts decode before prefill if envs.VLLM_USE_V1: hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C, @@ -579,8 +583,13 @@ def forward_cuda( # 2. Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" + x = hidden_states_B_C_p.transpose( + 0, 1) # this is the form that causal-conv see + if mamba2_metadata.cu_seqlen is None: + mamba2_metadata = update_metadata( + x, attn_metadata.query_start_loc, mamba2_metadata) hidden_states_B_C_p = causal_conv1d_fn( - hidden_states_B_C_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, @@ -590,8 +599,6 @@ def forward_cuda( query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] - # TODO: Why is this needed? - hidden_states_B_C_p = hidden_states_B_C_p.contiguous() hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -715,9 +722,10 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: # - heads and n_groups are TP-ed conv_dim = (self.intermediate_size + 2 * n_groups * self.ssm_state_size) + # contiguous along 'dim' axis conv_state_shape = ( - divide(conv_dim, world_size), self.conv_kernel_size - 1, + divide(conv_dim, world_size), ) # These are not TP-ed as they depend on A, dt_bias, D diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index a10c5ab69787..c1641080ea1e 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -4,102 +4,943 @@ # Copyright (c) 2024, Tri Dao. 
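The mixer changes above hand the Triton kernel a transposed view of the conv cache rather than a new allocation; a small sketch of why the flipped `conv_state_shape` plus `.transpose(-1, -2)` leaves the channel axis contiguous (shapes are illustrative):

```python
import torch

# Cache is allocated as (cache_lines, width-1, conv_dim) per the new
# conv_state_shape, then only *viewed* as (cache_lines, conv_dim, width-1),
# so the 'dim' axis keeps stride 1 for the kernel's per-feature loads.
cache_lines, width, conv_dim = 8, 4, 256
kv_conv = torch.empty(cache_lines, width - 1, conv_dim)
conv_state = kv_conv.transpose(-1, -2)   # shape (8, 256, 3), no copy
assert conv_state.stride(-2) == 1        # channel axis is the contiguous one
```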
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -from typing import Optional +from typing import Optional, Union +import numpy as np import torch +import triton +import triton.language as tl -from vllm import _custom_ops as ops from vllm.attention.backends.utils import PAD_SLOT_ID -def causal_conv1d_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID): - """ - x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + batch_ptr, + token_chunk_offset_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + batch: tl.int32, # actually padded_batch + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl. + constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + DECODE_SEQLEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.load(batch_ptr + tl.program_id(0)) + chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + # base of the sequence + x_base = x_ptr + 
sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = (conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( + tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - + 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 5: # STRATEGY1 + col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange( + 0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = x_ptr + ( + (sequence_start_index + idx_tokens_last) * + stride_x_token)[:, None] + ( + idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + mask_x = ((idx_tokens_last >= 0)[:, None] & + (idx_tokens_last < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + 
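# (Illustrative, not part of the patch.) Eager-mode reference for the cache
# refresh performed here at chunk_offset == 0: the cache line ends up holding
# the last `state_len` tokens of (prior state ++ new tokens), with zeros
# standing in for the prior state when has_initial_state is False. The helper
# name and signature below are made up for illustration.
import torch

def refresh_conv_state(conv_state_line: torch.Tensor,  # (dim, state_len)
                       x_seq: torch.Tensor,            # (dim, seqlen), one sequence
                       has_initial_state: bool) -> torch.Tensor:
    state_len = conv_state_line.shape[-1]
    prior = conv_state_line if has_initial_state else torch.zeros_like(conv_state_line)
    # shift left and append the new tokens, keeping only the trailing state_len columns
    return torch.cat([prior, x_seq], dim=-1)[:, -state_len:]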
new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = conv_states_base[None, :] + ( + idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats + < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, + None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ( + (idx_tokens_conv - VAL) * + stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & + (idx_tokens_conv - VAL < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + tl.debug_barrier( + ) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = conv_states_base + ( + idx_tokens_conv * + stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv + < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ( + (idx_tokens_conv - VAL) * + stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & + (idx_tokens_conv - VAL < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = conv_states_base + ( + idx_tokens_conv * + stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv + < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = 
tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & ( + idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token + ) * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch sequences are concatenated 
from left to right for varlen weight: (dim, width) - bias: (dim,) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended by 0. - for example: query_start_loc = torch.Tensor([0,10,16,17]), + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 - indicates the corresponding state index, + indicates the corresponding state index, like so: conv_state = conv_states[cache_indices[batch_id]] has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial + indicates whether should the kernel take the current state as initial state for the calculations - conv_states: (...,dim,width - 1) itype - updated inplace if provided - activation: either None or "silu" or "swish" + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 - out: (batch, dim, seqlen) + out: same shape as `x` """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(-1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - - ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, - cache_indices, has_initial_state, activation - in ["silu", "swish"], pad_slot_id) - return x - - -def causal_conv1d_update(x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID): + if isinstance(activation, bool) and activation: + activation = "silu" + + args = None + out = torch.zeros_like(x) + if metadata is not None: + cu_seqlen = metadata.cu_seqlen + nums_dict = metadata.nums_dict + #x = metadata.x + args = nums_dict + batch_ptr = metadata.batch_ptr + token_chunk_offset_ptr = metadata.token_chunk_offset_ptr + else: + seqlens = np.diff(query_start_loc.to('cpu')) + args = seqlens + MAX_NUM_PROGRAMS = 1024 + + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=x.device + ) # tracking which seq-idx the Triton program is handling + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + 
device=x.device + ) # tracking BLOCK_M-based index in the sequence the Triton program is handling + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + padded_batch = query_start_loc.size(0) - 1 + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert (num_cache_lines, dim, width - 1) == conv_states.shape + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch, ) + assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + if metadata is None: + + def num_program(META, seqlens): + tot = 0 + + mlist = [] + offsetlist = [] # type: ignore + + nums = -(-seqlens // META["BLOCK_M"]) + + tot = nums.sum().item() + mlist = np.repeat(np.arange(len(nums)), nums) + for idx, num in enumerate(nums): + offsetlist.extend( + range(num) + ) # chunk-idx if a sequence is split into multiple chunks + + if META["batch_ptr"].nelement() < len(mlist): + newlen = len(mlist) + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_( + PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= len(mlist): + META["batch_ptr"][0:len(mlist)].copy_( + torch.from_numpy(np.array(mlist))) + META["token_chunk_offset_ptr"][0:len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist))) + + META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) + META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( + META["x_ptr"].device) + return tot + else: + + def num_program(META, nums_dict): + tot = nums_dict[META["BLOCK_M"]]['tot'] + + mlist = nums_dict[META["BLOCK_M"]]['mlist'] + mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len'] + + offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist'] + + if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: + META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[ + META["BLOCK_M"]]["token_chunk_offset_ptr"] + else: + if 
META["batch_ptr"].nelement() < mlist_len: + newlen = mlist_len + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_( + PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= mlist_len: + META["batch_ptr"][0:mlist_len].copy_(mlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_( + offsetlist) + return tot + + def grid(META): + return ( + num_program(META, args), + triton.cdiv(dim, META["BLOCK_N"]), + ) + + if batch_ptr.device != x.device: + batch_ptr = batch_ptr.to(x.device) + token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + batch_ptr, + token_chunk_offset_ptr, + out, + # Matrix dimensions + padded_batch, + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + DECODE_SEQLEN=1, + #launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + 
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ( + (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] & + (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = conv_state_base + ( + idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 
2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim + ) # token-index # feature-index + o_ptrs = o_ptr + ( + idx_seq) * stride_o_seq + idx_token * stride_o_token + ( + idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): """ x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) cache_seqlens: (batch,), dtype int32. If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state + The conv_state will be updated by copying x to the conv_state starting at the index @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, + If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. 
pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation_val = activation in ["silu", "swish"] + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] unsqueeze = x.dim() == 2 if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) - ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices, pad_slot_id) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert conv_state.stride( + -2 + ) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) if unsqueeze: - x = x.squeeze(-1) - return x + out = out.squeeze(-1) + return out diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 49ba974c69a5..27685c59a3ea 100644 --- 
a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -36,10 +36,12 @@ def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, # Initialize parent class super().__init__(max_batch_size) + # assume conv_state = (dim, state_len) + assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, + (conv_state_shape[1], conv_state_shape[0]), dtype=dtype, - device="cuda") + device="cuda").transpose(-1, -2) temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state_shape, dtype=dtype, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 74d619aadbdc..9dea08b65837 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _query_start_loc_to_chunk_indices_offsets) from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import MambaSpec @@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: return chunk_sizes.pop() +def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, + chunk_size: int, + total_seqlens: int): + + cu_seqlens = query_start_loc[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, + dtype=torch.int, + device=query_start_loc.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=query_start_loc.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) + + # adjust indices and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + class Mamba2AttentionBackend(AttentionBackend): @staticmethod @@ -53,6 +88,10 @@ class Mamba2AttentionMetadata: chunk_offsets: torch.Tensor state_indices_tensor: torch.Tensor # shape: [batch,] + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None class Mamba2AttentionMetadataBuilder(
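A hand-worked example of the relocated helper, assuming the module path used in this diff; the values below follow directly from the function body above and are not taken from a test in this patch:

import torch

from vllm.v1.attention.backends.mamba_attn import (
    _query_start_loc_to_chunk_indices_offsets)

# two sequences of lengths 6 and 4, concatenated, with chunk_size = 4
query_start_loc = torch.tensor([0, 6, 10], dtype=torch.int32)
chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
    query_start_loc, chunk_size=4, total_seqlens=10)
# chunk_indices -> tensor([0, 1, 1, 2], dtype=torch.int32)
# chunk_offsets -> tensor([0, 0, 2, 0], dtype=torch.int32)
# i.e. the second sequence starts 2 tokens into physical chunk 1, so its first
# logical chunk reuses chunk index 1 with an offset of 2.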