@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(
 
 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
   { \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
-        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()), \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
-        num_heads, head_size); \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
+        <<<grid, block, 0, stream>>>( \
+            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()), \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
+            num_heads, head_size); \
   }
 
 /* @brief Merges the attention states from prefix and suffix
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
  * @param output [n,h,d] The output tensor to store the merged attention states.
  * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
  * @param prefix_output [n,h,d] The prefix attention states.
- * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
  * states.
  * @param suffix_output [n,h,d] The suffix attention states.
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
  * states.
  */
 template <typename scalar_t>
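For reference, the shape fix above reflects that the log-sum-exp values are stored per attention head and per token, i.e. [h,n], not per head and head dimension. The kernel body is not shown in this diff, but the conventional rule for merging two partial attention results o_p, o_s with log-sum-exps lse_p, lse_s (which matches the behaviour described in the comment) is, as a sketch:

```math
p = \frac{e^{\mathrm{lse}_p}}{e^{\mathrm{lse}_p} + e^{\mathrm{lse}_s}}, \qquad
\mathrm{out} = p\,o_p + (1-p)\,o_s, \qquad
\mathrm{lse}_{\mathrm{out}} = \log\left(e^{\mathrm{lse}_p} + e^{\mathrm{lse}_s}\right)
```

applied independently for every (head, token) pair, which is why both lse tensors are indexed as [h,n].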
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
-  // process one pack elements per thread. float -> 4, half/bf16 -> 8
+  // Process one pack of elements per thread. For float the pack_size is 4;
+  // for half/bf16 the pack_size is 8.
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;
 
   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
 
+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
 
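Below is a minimal, self-contained sketch of the launch pattern the new macro and launcher adopt: guard on the device that owns the input tensor, then launch on PyTorch's current CUDA stream instead of the legacy default stream implied by the old `<<<grid, block>>>` form. Only `c10::cuda::OptionalCUDAGuard` and `at::cuda::getCurrentCUDAStream()` are taken from the diff above; the kernel, the launcher, and their names are made up for illustration.

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// Hypothetical elementwise kernel, used only to illustrate the launch pattern.
__global__ void scale_kernel(float* out, const float* in, float alpha, int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    out[idx] = alpha * in[idx];
  }
}

void scale_launcher(torch::Tensor& out, const torch::Tensor& in, float alpha) {
  const int n = static_cast<int>(in.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;

  // Run on the device that owns the input tensor ...
  const c10::cuda::OptionalCUDAGuard device_guard(in.device());
  // ... and on the stream PyTorch is currently using for this device.
  auto stream = at::cuda::getCurrentCUDAStream();

  scale_kernel<<<blocks, threads, 0, stream>>>(
      out.data_ptr<float>(), in.data_ptr<float>(), alpha, n);
}
```

Launching on the current stream keeps the kernel ordered with the other ops PyTorch has queued on that stream; the old form launched on the default stream, which is incorrect whenever PyTorch is running on a non-default stream (for example during CUDA graph capture).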