@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
     constexpr bool kIsVariableB = true;
     constexpr bool kIsVariableC = true;
-    constexpr bool kHasZ = true;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.query_start_loc_ptr != nullptr, kVarlen, [&] {
-            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-            dim3 grid(params.batch, params.dim / kNRows);
-            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-            if (kSmemSize >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            }
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
+            BOOL_SWITCH(params.query_start_loc_ptr != nullptr, kVarlen, [&] {
+                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+                constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+                dim3 grid(params.batch, params.dim / kNRows);
+                auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+                if (kSmemSize >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+                }
+                kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
         });
     });
 }
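
This hunk replaces the compile-time `constexpr bool kHasZ = true` with a runtime `BOOL_SWITCH` on `params.z_ptr`, so the launcher now instantiates the kernel for both the with-z and without-z cases. A minimal sketch of the dispatch pattern, assuming the `BOOL_SWITCH` definition from the Mamba-style `static_switch.h` header (the stub below is illustrative, not the real selective-scan kernel):

```cpp
#include <cstdio>

// Lifts a runtime bool into a constexpr bool visible inside the lambda body,
// so each branch is compiled as a separate template instantiation.
#define BOOL_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                      \
    if (COND) {                              \
      constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()

// Stand-in for selective_scan_fwd_kernel<Ktraits>: kHasZ is a template
// parameter, so the z path costs nothing in the kHasZ = false instantiation.
template <bool kHasZ>
void launch_stub() {
  std::printf("instantiated with kHasZ = %s\n", kHasZ ? "true" : "false");
}

int main() {
  void* z_ptr = nullptr;  // stand-in for params.z_ptr
  BOOL_SWITCH(z_ptr != nullptr, kHasZ, [&] { launch_stub<kHasZ>(); });
  return 0;
}
```

Moving `kHasZ` into the switch doubles the number of kernel instantiations (hence the old comment about reducing binary size), but it is what lets `has_z = False` callers work again in the second hunk below.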
@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
 
     at::Tensor z, out_z;
     const bool has_z = z_.has_value();
-    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
-    z = z_.value();
-    TORCH_CHECK(z.scalar_type() == input_type);
-    TORCH_CHECK(z.is_cuda());
-    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
-    if (varlen){
-        CHECK_SHAPE(z, dim, seqlen);
-    } else {
-        CHECK_SHAPE(z, batch_size, dim, seqlen);
+    if (has_z) {
+        z = z_.value();
+        TORCH_CHECK(z.scalar_type() == input_type);
+        TORCH_CHECK(z.is_cuda());
+        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
+        if (varlen){
+            CHECK_SHAPE(z, dim, seqlen);
+        } else {
+            CHECK_SHAPE(z, batch_size, dim, seqlen);
+        }
+
+        out_z = z;
     }
 
-    out_z = z;
-
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
     TORCH_CHECK(ssm_states.scalar_type() == input_type);
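
On the host side, the hard `TORCH_CHECK(has_z, ...)` failure is replaced by an `if (has_z)` guard, so all z validation and the `out_z` aliasing are simply skipped when `z_` is empty. A hedged sketch of how that would reach the kernel-side switch above; the helper and the `SSMParamsStub` type are hypothetical (field names `z_ptr`/`out_z_ptr` follow the `SSMParamsBase` convention but this is not the file's actual params-setting code):

```cpp
#include <optional>
#include <torch/torch.h>

struct SSMParamsStub {      // minimal stand-in for SSMParamsBase
    void* z_ptr = nullptr;  // read by BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, ...)
    void* out_z_ptr = nullptr;
};

void set_z_params(SSMParamsStub& params,
                  const std::optional<torch::Tensor>& z_,
                  at::Tensor& out_z) {
    if (z_.has_value()) {
        params.z_ptr = z_.value().data_ptr();
        params.out_z_ptr = out_z.data_ptr();
    }
    // Otherwise both pointers stay null, and the launcher compiles in the
    // kHasZ = false instantiation.
}
```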
@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
         selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
     });
 }
-
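
One detail from the first hunk worth calling out: the launcher raises `cudaFuncAttributeMaxDynamicSharedMemorySize` before launching whenever `kSmemSize >= 48 * 1024`, because a kernel may not request more than 48 KB of dynamic shared memory per block without this opt-in. A standalone CUDA sketch of the same pattern (the demo kernel and sizes are illustrative, and a device of compute capability 7.0+ is assumed):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_demo_kernel(float* out) {
    extern __shared__ float smem[];  // dynamic shared memory
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
    const int kSmemSize = 64 * 1024;  // 64 KB > the 48 KB default cap
    if (kSmemSize >= 48 * 1024) {
        // Without this attribute, the launch below fails on sizes above 48 KB.
        cudaFuncSetAttribute(smem_demo_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             kSmemSize);
    }
    float* out = nullptr;
    cudaMalloc(&out, sizeof(float));
    smem_demo_kernel<<<1, 256, kSmemSize>>>(out);
    std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}
```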