@@ -22,7 +22,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
         GemmConfig gemm_config;
         LaunchArgs launch_args;
 
-        void *sfb, *grouped_layout;
+        void *sfb, *grouped_layout, *signal;
         CUtensorMap tensor_map_a;
         CUtensorMap tensor_map_b;
         CUtensorMap tensor_map_d;
@@ -44,7 +44,8 @@ static void __instantiate_kernel() {{
         {}, {},
         {}, {},
         {}, {},
-        {}, {}, {}
+        {}, {}, {},
+        {}
     >);
 }};
 )",
@@ -57,13 +58,14 @@ static void __instantiate_kernel() {{
         args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
         args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
         args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
-        get_default_epilogue_type(args.epilogue_type));
+        get_default_epilogue_type(args.epilogue_type),
+        args.gemm_config.enable_overlap);
     }
 
     static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
         // TODO: optimize `args` copy
         DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
-            args.sfb, args.grouped_layout,
+            args.sfb, args.grouped_layout, args.signal,
             args.m, args.n, args.k,
             args.tensor_map_a, args.tensor_map_b,
             args.tensor_map_d, args.tensor_map_sfa));
@@ -121,6 +123,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                        config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = nullptr,
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -181,6 +184,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
                                        config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = m_indices.data_ptr(),
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -191,14 +195,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
-static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
+static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                                const std::string& compiled_dims) {
+                                                const std::string& compiled_dims,
+                                                const int& max_block_n,
+                                                const bool& enable_overlap,
+                                                const c10::optional<torch::Tensor>& signal) {
     const auto& aligned_k = align(k, 128);
     DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -207,7 +214,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
         GemmType::MGroupedMasked, KernelType::Kernel1D2D,
         expected_m, n, k, num_groups, major_a, major_b,
         torch::kFloat8_e4m3fn, d.scalar_type(), false,
-        device_runtime->get_num_sms());
+        device_runtime->get_num_sms(), max_block_n, enable_overlap);
 
     // Requires no TMA splits
     DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -242,6 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
                                        config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = masked_m.data_ptr(),
+        .signal = enable_overlap ? signal.value().data_ptr() : nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -250,6 +258,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
+    return enable_overlap ?
+        std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
+        std::nullopt;
 }
 
 } // namespace deep_gemm