@@ -36,6 +36,7 @@ limitations under the License.
3636#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
3737void sm100_cutlass_mla_decode (
3838 torch::Tensor const & out,
39+ torch::Tensor const & lse,
3940 torch::Tensor const & q_nope,
4041 torch::Tensor const & q_pe,
4142 torch::Tensor const & kv_c_and_k_pe_cache,
@@ -99,6 +100,7 @@ struct MlaSm100 {
99100template <typename T>
100101typename T::Fmha::Arguments args_from_options (
101102 at::Tensor const & out,
103+ at::Tensor const & lse,
102104 at::Tensor const & q_nope,
103105 at::Tensor const & q_pe,
104106 at::Tensor const & kv_c_and_k_pe_cache,
@@ -162,7 +164,10 @@ typename T::Fmha::Arguments args_from_options(
162164 stride_PT,
163165 page_count_total,
164166 page_size},
165- {static_cast <ElementOut*>(out.data_ptr ()), stride_O, static_cast <ElementAcc*>(nullptr ), stride_LSE},
167+ {static_cast <ElementOut*>(out.data_ptr ()),
168+ stride_O,
169+ static_cast <ElementAcc*>(lse.defined () ? lse.data_ptr () : nullptr ),
170+ stride_LSE},
166171 hw_info,
167172 // TODO(trevor-m): Change split_kv back to -1 when
168173 // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
@@ -181,6 +186,7 @@ typename T::Fmha::Arguments args_from_options(
181186template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
182187void runMla (
183188 at::Tensor const & out,
189+ at::Tensor const & lse,
184190 at::Tensor const & q_nope,
185191 at::Tensor const & q_pe,
186192 at::Tensor const & kv_c_and_k_pe_cache,
@@ -192,7 +198,7 @@ void runMla(
192198 cudaStream_t stream) {
193199 using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
194200 typename MlaSm100Type::Fmha fmha;
195- auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
201+ auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
196202
197203 CUTLASS_CHECK (fmha.can_implement (arguments));
198204
@@ -214,6 +220,7 @@ void runMla(
214220
215221void sm100_cutlass_mla_decode (
216222 torch::Tensor const & out,
223+ torch::Tensor const & lse,
217224 torch::Tensor const & q_nope,
218225 torch::Tensor const & q_pe,
219226 torch::Tensor const & kv_c_and_k_pe_cache,
@@ -234,13 +241,13 @@ void sm100_cutlass_mla_decode(
234241 DISPATCH_BOOL (num_kv_splits <= 1 , NotManualSplitKV, [&] {
235242 if (in_dtype == at::ScalarType::Half) {
236243 runMla<cutlass::half_t , cutlass::half_t , IsPaged128, IsPersistent<NotManualSplitKV>>(
237- out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
244+ out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
238245 } else if (in_dtype == at::ScalarType::BFloat16) {
239246 runMla<cutlass::bfloat16_t , cutlass::bfloat16_t , IsPaged128, IsPersistent<NotManualSplitKV>>(
240- out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
247+ out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
241248 } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
242249 runMla<cutlass::float_e4m3_t , cutlass::bfloat16_t , IsPaged128, IsPersistent<NotManualSplitKV>>(
243- out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
250+ out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
244251 } else {
245252 TORCH_CHECK (false , " Unsupported input data type of MLA" );
246253 }
0 commit comments