
Commit ffe2b6b

Merge pull request #14 from Sulfur6/sgl.sbo.public
[Feat] Single Batch Overlap (SBO): Overlapping of Down GEMM with Combine Send
2 parents f4adba8 + 5f99d8d commit ffe2b6b

File tree

11 files changed: +125 -36 lines

csrc/apis/gemm.hpp

Lines changed: 16 additions & 4 deletions
@@ -175,14 +175,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torc
                                              d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
 }
 
-static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
+static std::optional<std::pair<int, int>> m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                          const std::pair<torch::Tensor, torch::Tensor>& b,
                                          const torch::Tensor& d,
                                          const torch::Tensor& masked_m,
                                          const int& expected_m,
                                          std::optional<std::tuple<int, int, int>> recipe,
                                          const std::string& compiled_dims,
-                                         const bool& disable_ue8m0_cast) {
+                                         const bool& disable_ue8m0_cast,
+                                         const int& max_block_n,
+                                         const bool& enable_overlap,
+                                         const c10::optional<torch::Tensor>& signal) {
     // Shape must be `[G, M, K] @ [G, N, K].mT`
     const auto& major_a = get_major_type_ab(a.first);
     const auto& major_b = get_major_type_ab(b.first);
@@ -202,6 +205,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
     DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
 
+    if (enable_overlap) {
+        DG_HOST_ASSERT(signal.has_value());
+        DG_HOST_ASSERT(signal.value().is_contiguous());
+        DG_HOST_ASSERT(signal.value().scalar_type() == torch::kInt32);
+    }
+
     // D must be N-major
     check_major_type_cd(d);
 
@@ -213,9 +222,11 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
 
     // Dispatch implementation
     const auto& arch_major = device_runtime->get_arch_major();
+    std::optional<std::pair<int, int>> result = std::nullopt;
     if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
-        sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
-                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
+        result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
+                                                     num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims,
+                                                     max_block_n, enable_overlap, signal);
     } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
         sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                              num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
@@ -225,6 +236,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
     } else {
         DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
+    return result;
 }
 
 static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
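With `enable_overlap` set, the masked-GEMM API gains a contract on top of the usual launch: the caller provides a zero-initialized int32 `signal` tensor and receives the `(block_m, signal_threshold)` pair chosen by the config heuristics, which tells it how M is partitioned and when each partition's output is complete. A minimal sketch of that contract from the caller's side; the sizing choice and helper names here are illustrative assumptions, not part of the repo:

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# One int32 counter per (expert group, m-block), zeroed before launch; the
# kernel only ever increments it. The masked heuristic picks block_m of 64 or
# 128, so sizing with block_m = 64 over-allocates safely before block_m is known.
num_groups, max_m = 8, 1024
signal = torch.zeros(num_groups * ceil_div(max_m, 64), dtype=torch.int32, device='cuda')

# After the call returns (block_m, signal_threshold), the slot
#   signal[g * ceil_div(max_m, block_m) + i]
# equals signal_threshold once every N-tile of m-block i in group g has been stored.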

csrc/jit_kernels/heuristics/common.hpp

Lines changed: 9 additions & 3 deletions
@@ -63,6 +63,7 @@ struct GemmConfig {
     cute::UMMA::Major major_b;
     bool with_accumulation;
     int block_m, block_n, block_k;
+    int signal_threshold;
     int num_stages, num_last_stages;
 
     // Templated device configs
@@ -73,6 +74,8 @@ struct GemmConfig {
     MulticastConfig multicast_config;
     SharedMemoryConfig smem_config;
     ThreadConfig thread_config;
+
+    bool enable_overlap;
 };
 
 static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
@@ -151,7 +154,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
                                   const int& m, const int& n, const int& k, const int& num_groups,
                                   const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                   const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
-                                  const bool& with_accumulation, const int& num_sms) {
+                                  const bool& with_accumulation, const int& num_sms,
+                                  const int& max_block_n = 256, const bool& enable_overlap = false) {
     DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
     DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
 
@@ -161,7 +165,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
     if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
         block_ms = std::vector{64, 128};
-    const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
+    const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype, max_block_n);
 
     // K block size is selected in a fixed manner
     const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
@@ -271,14 +275,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         .block_m = best_block_m,
         .block_n = best_block_n,
         .block_k = block_k,
+        .signal_threshold = ceil_div(n, best_block_n),
         .num_stages = best_num_stages,
         .num_last_stages = ceil_div(k, block_k) % best_num_stages,
         .num_sms = num_min_sms,
         .tc_util = device_runtime->get_tc_util(),
         .multicast_config = best_multicast_config,
         // ReSharper disable once CppLocalVariableMightNotBeInitialized
         .smem_config = best_smem_config,
-        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
+        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n),
+        .enable_overlap = enable_overlap
     };
 
     // Only SM100 BF16 kernels support tensor core control
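The new `signal_threshold = ceil_div(n, best_block_n)` is simply the number of N-tiles each output row-block is split into, i.e. the count an overlap signal slot must reach before that row-block has been fully written. A quick worked example with illustrative values:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# e.g. an N = 4096 output with the heuristic's chosen block_n = 128:
n, best_block_n = 4096, 128
signal_threshold = ceil_div(n, best_block_n)   # 32 increments per m-block slot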

csrc/jit_kernels/heuristics/sm100.hpp

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ namespace deep_gemm {
 struct SM100ArchSpec {
     static constexpr int smem_capacity = 232448;
 
-    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
+    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
         // 16 is for better SM usage
         // Stride 32 is due to low-performance swizzle-16/32B
         std::vector<int> candidates = {16};

csrc/jit_kernels/heuristics/sm90.hpp

Lines changed: 2 additions & 2 deletions
@@ -11,11 +11,11 @@ namespace deep_gemm {
 struct SM90ArchSpec {
     static constexpr int smem_capacity = 232448;
 
-    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
+    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
         // Avoid bank conflicts for FP32 output
         const auto& start = cd_dtype == torch::kFloat ? 8 : 16;
         std::vector<int> candidates;
-        for (int i = start; i <= 256; i += 16)
+        for (int i = start; i <= max_block_n; i += 16)
             candidates.push_back(i);
         return candidates;
     }
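Both arch specs now take `max_block_n`, so a caller using overlap can cap how wide a BLOCK_N the heuristic may pick; the default of 256 reproduces the previous fixed upper bound. A small Python mirror of the SM90 candidate loop above, for illustration only:

def sm90_block_n_candidates(cd_is_fp32: bool, max_block_n: int) -> list:
    start = 8 if cd_is_fp32 else 16            # avoid bank conflicts for FP32 output
    return list(range(start, max_block_n + 1, 16))

print(sm90_block_n_candidates(False, 128))     # [16, 32, 48, 64, 80, 96, 112, 128]
print(sm90_block_n_candidates(False, 256))     # previous behaviour (fixed cap of 256)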

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 18 additions & 7 deletions
@@ -22,7 +22,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
     GemmConfig gemm_config;
     LaunchArgs launch_args;
 
-    void *sfb, *grouped_layout;
+    void *sfb, *grouped_layout, *signal;
     CUtensorMap tensor_map_a;
     CUtensorMap tensor_map_b;
     CUtensorMap tensor_map_d;
@@ -44,7 +44,8 @@ static void __instantiate_kernel() {{
     {}, {},
     {}, {},
     {}, {},
-    {}, {}, {}
+    {}, {}, {},
+    {}
 >);
 }};
 )",
@@ -57,13 +58,14 @@ static void __instantiate_kernel() {{
         args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
         args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
         args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
-        get_default_epilogue_type(args.epilogue_type));
+        get_default_epilogue_type(args.epilogue_type),
+        args.gemm_config.enable_overlap);
 }
 
 static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
     // TODO: optimize `args` copy
     DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
-        args.sfb, args.grouped_layout,
+        args.sfb, args.grouped_layout, args.signal,
         args.m, args.n, args.k,
         args.tensor_map_a, args.tensor_map_b,
         args.tensor_map_d, args.tensor_map_sfa));
@@ -121,6 +123,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = nullptr,
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -181,6 +184,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
                                config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = m_indices.data_ptr(),
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -191,14 +195,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
-static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
+static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                                const std::string& compiled_dims) {
+                                                const std::string& compiled_dims,
+                                                const int& max_block_n,
+                                                const bool& enable_overlap,
+                                                const c10::optional<torch::Tensor>& signal) {
     const auto& aligned_k = align(k, 128);
     DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -207,7 +214,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
         GemmType::MGroupedMasked, KernelType::Kernel1D2D,
         expected_m, n, k, num_groups, major_a, major_b,
         torch::kFloat8_e4m3fn, d.scalar_type(), false,
-        device_runtime->get_num_sms());
+        device_runtime->get_num_sms(), max_block_n, enable_overlap);
 
     // Requires no TMA splits
     DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -242,6 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
                                config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = masked_m.data_ptr(),
+        .signal = enable_overlap ? signal.value().data_ptr() : nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -250,6 +258,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
+    return enable_overlap ?
+        std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
+        std::nullopt;
 }
 
 } // namespace deep_gemm

csrc/python_api.cpp

Lines changed: 16 additions & 5 deletions
@@ -137,8 +137,16 @@ void m_grouped_fp8_gemm_nn_contiguous_wrapper(const torch::Tensor& a_val, const
     deep_gemm::gemm::m_grouped_fp8_gemm_nn_contiguous({a_val, a_scale}, {b_val, b_scale}, d, m_indices, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
 }
 
-void m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<c10::IntArrayRef>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast) {
-    deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
+std::tuple<c10::optional<int64_t>, c10::optional<int64_t>> m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<c10::IntArrayRef>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast, int64_t max_block_n, bool enable_overlap, const c10::optional<torch::Tensor>& signal) {
+    auto result = deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);
+
+    if (!result) {
+        return std::make_tuple(c10::nullopt, c10::nullopt);
+    }
+    return std::make_tuple(
+        c10::optional<int64_t>(result->first),
+        c10::optional<int64_t>(result->second)
+    );
 }
 
 void k_grouped_fp8_gemm_nt_contiguous_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, c10::List<int64_t> ks, const torch::Tensor& ks_tensor, const c10::optional<torch::Tensor>& c, c10::IntArrayRef recipe, const std::string& compiled_dims) {
@@ -342,17 +350,20 @@ TORCH_LIBRARY(deep_gemm, m) {
         deep_gemm_wrappers::m_grouped_fp8_gemm_nn_contiguous_wrapper(a_val, a_scale, b_val, b_scale, d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
     });
 
-    m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False) -> ())");
+    m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False, int max_block_n=256, bool enable_overlap=False, Tensor? signal=None) -> (int?, int?))");
     m.impl("m_grouped_fp8_gemm_nt_masked", torch::kCUDA, [](const c10::IValue& a_input, const c10::IValue& b_input,
                                                             const torch::Tensor& d,
                                                             const torch::Tensor& masked_m,
                                                             int64_t expected_m,
                                                             const c10::optional<c10::IntArrayRef>& recipe,
                                                             const std::string& compiled_dims,
-                                                            bool disable_ue8m0_cast) {
+                                                            bool disable_ue8m0_cast,
+                                                            int64_t max_block_n,
+                                                            bool enable_overlap,
+                                                            const c10::optional<torch::Tensor>& signal) {
         auto [a_val, a_scale] = parse_tensor_or_tuple(a_input);
         auto [b_val, b_scale] = parse_tensor_or_tuple(b_input);
-        deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast);
+        return deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);
     });
 
     m.def(R"(k_grouped_fp8_gemm_nt_contiguous(Any a, Any b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c=None, int[] recipe=[1, 1, 128], str compiled_dims="mn") -> ())");

deep_gemm/include/deep_gemm/common/utils.cuh

Lines changed: 10 additions & 0 deletions
@@ -158,6 +158,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
     asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
 }
 
+__device__ __forceinline__ void store_wait() {
+    asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
+}
+
+__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
+    int ret;
+    asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value));
+    return ret;
+}
+
 template <uint32_t kNumBytes>
 struct Vectorized {
     static auto zeros() {

deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh

Lines changed: 14 additions & 2 deletions
@@ -38,9 +38,9 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           uint32_t kNumSMs, GemmType kGemmType,
-          typename epilogue_type_t>
+          typename epilogue_type_t, bool kEnableOverlap>
 __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
-sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
+sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal,
                         uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -395,6 +395,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
             cute::tma_store_arrive();
         }
         __syncwarp();
+
+        if constexpr (kEnableOverlap) {
+            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
+                store_wait();
+            }
+
+            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+            if (threadIdx.x == 0) {
+                atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1);
+            }
+        }
     }
 }
 #else
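This epilogue block is the producer half of SBO: the threads that issued the tile's TMA stores wait for them to complete (`store_wait`), all math threads then synchronize on a named barrier, and thread 0 performs a release-scoped atomic increment of that m-block's slot in `signal`. The consumer half (e.g. the combine-send side) is outside this PR; the sketch below only illustrates how such a consumer could interpret the counters, with every name hypothetical:

def rows_ready(signal, group: int, m_block: int, num_m_blocks: int, threshold: int) -> bool:
    # Slot layout mirrors the kernel index:
    #   group * ceil_div(shape_m, BLOCK_M) + m_block
    # The kernel's increment is release-scoped, so once the counter reaches
    # `threshold` (= ceil_div(n, BLOCK_N)) the BLOCK_M rows of that group's
    # output tile are fully stored and can be sent.
    return int(signal[group * num_m_blocks + m_block]) >= threshold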

deep_gemm/testing/numeric.py

Lines changed: 15 additions & 0 deletions
@@ -17,3 +17,18 @@ def count_bytes(*tensors):
         elif t is not None:
             total += t.numel() * t.element_size()
     return total
+
+def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m):
+    ceil_div = lambda a, b: (a + b - 1) // b
+
+    expert_len = max_m // block_m
+    for expert in range(num_local_expert):
+        mask = masked_m[expert]
+        start = expert * expert_len
+        end = expert * expert_len + expert_len
+        valid_len = ceil_div(mask, block_m)
+        for i in range(start, end):
+            if i < start + valid_len:
+                assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}'
+            else:
+                assert signal[i] == 0, f'{i=}, {signal[i]=}'
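A hedged sketch of how a test could drive `check_signal` right after an `enable_overlap=True` launch, using the `block_m` and `signal_threshold` returned by the op; the wrapper name is illustrative and it assumes `max_m` is a multiple of `block_m`, as `check_signal` itself does:

import torch

def verify_overlap_signals(signal, masked_m, max_m, block_m, threshold):
    # Call after the overlapped GEMM's stream has been synchronized; `signal`
    # must have been zero-initialized before the launch.
    torch.cuda.synchronize()
    check_signal(masked_m.numel(), max_m, block_m, threshold,
                 signal.cpu(), masked_m.cpu())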
