diff --git a/scattermoe/__init__.py b/scattermoe/__init__.py
index ca2a509..f495c5a 100644
--- a/scattermoe/__init__.py
+++ b/scattermoe/__init__.py
@@ -1,5 +1,5 @@
 from . import kernels
 from . import parallel_experts
 from . import mlp
-from .triton_implementation import padded_block_indices
+from .triton_implementation import expert_boundaries
 from .parallel_experts import ParallelExperts
diff --git a/scattermoe/kernels/triton.py b/scattermoe/kernels/triton.py
index 89274ce..19f9b50 100644
--- a/scattermoe/kernels/triton.py
+++ b/scattermoe/kernels/triton.py
@@ -3,7 +3,8 @@
 @triton.autotune(
-    configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)],
+    configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32},
+                           num_stages=4, num_warps=4)],
     key=["N", "K"],
 )
 @triton.jit
@@ -43,7 +44,8 @@ def scatter2scatter_triton_kernel(
     block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
     M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
-    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
+    E_idxs = tl.load(expert_idxs_ptr + M_block,
+                     mask=M_block < (FAN_OUT * M), other=E)
     E_idx = tl.min(E_idxs)
     E_mask = E_idxs == E_idx
     M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
@@ -63,8 +65,10 @@ def scatter2scatter_triton_kernel(
     N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
     N_mask = N_block < N
-    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
-    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
+    X_blk_ptrs = X_ptr + M_in_idx[:, None] * \
+        stride_xm + K_block[None, :] * stride_xk
+    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + \
+        N_block[None, :] * stride_wn + E_idx * stride_we
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
     iters = tl.cdiv(K, BLOCK_K)
@@ -89,16 +93,20 @@ def scatter2scatter_triton_kernel(
         W_blk_ptrs += BLOCK_K * stride_wk
         acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
-    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
+    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] *
+                          stride_ym + N_block[None, :] * stride_yn)
     tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
 
 
 @triton.autotune(
     configs=[
         # different block M and reducing stages
-        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4),
-        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4),
-        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4),
+        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
+                       "BLOCK_M": 32}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
+                       "BLOCK_M": 128}, num_stages=1, num_warps=4),
+        triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
+                       "BLOCK_M": 64}, num_stages=2, num_warps=4),
        # keep 4 stages and keep two 64 block sizes
        # - NOTE: these can get good performances for low M, but for large M the variation
        # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4),
@@ -151,15 +159,19 @@ def groupXtY_triton_kernel(
         K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
         K_mask = K_block < K
-        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
+        K_block = tl.max_contiguous(
+            tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
         N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
         N_mask = N_block < N
-        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
+        N_block = tl.max_contiguous(
+            tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
 
         M_idxs = M_block
-        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
-        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
+        xt_blk_ptrs = X_ptr + K_block[:, None] * \
+            stride_xn + M_idxs[None, :] * stride_xm
+        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * \
+            stride_dym + N_block[None, :] * stride_dyk
 
         acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
         iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
@@ -173,18 +185,21 @@ def groupXtY_triton_kernel(
             if no_k_mask:
                 xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
             else:
-                xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
+                xt = tl.load(
+                    xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
 
             if no_n_mask:
                 dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
             else:
-                dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
+                dy = tl.load(
+                    dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
 
             xt_blk_ptrs += BLOCK_M * stride_xm
             dy_blk_ptrs += BLOCK_M * stride_dym
             acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
 
-        DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
+        DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + \
+            K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
         acc = acc.to(DW_blk_ptrs.dtype.element_ty)
         tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
@@ -216,8 +231,10 @@ def group_triton_kernel(
     N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
     K_blk = tl.arange(0, BLOCK_K)
-    src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
-    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
+    src_blk_ptrs = src_ptr + \
+        (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
+    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * \
+        stride_tn + K_blk[None, :] * stride_ti
 
     if has_coeff:
         c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py
index b7c984e..47c3786 100644
--- a/scattermoe/mlp.py
+++ b/scattermoe/mlp.py
@@ -2,9 +2,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-# from . import kernels
-# from .parallel_experts import ParallelExperts
-from .triton_implementation import ParallelExperts, padded_block_indices
+from .triton_implementation import ParallelExperts, expert_boundaries
 
 class GLUMLP(nn.Module):
     def __init__(
@@ -33,18 +31,22 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te
         x = x.view(-1, x_shape[-1])
         with torch.no_grad():
             sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
-            padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts)
-
+            expert_offsets = expert_boundaries(sorted_expert_idxs, self.num_experts)
         h, gates = self.experts(
-            x, self.top_k,
-            sorted_expert_idxs, sorted_scattered_idxs,
-            padded_block_idxs, expert_offsets,
+            inputs=x,
+            k=self.top_k,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            expert_offsets=expert_offsets,
             grouped_out=True
         ).chunk(2, dim=-1)
         h = self.activation(gates) * h
         y = self.output_experts(
-            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            padded_block_idxs, expert_offsets,
+            inputs=h,
+            k=1,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            expert_offsets=expert_offsets,
             grouped_in=True,
             gates=expert_p,
         )
@@ -79,19 +81,20 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te
         x = x.view(-1, x_shape[-1])
         with torch.no_grad():
             sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
-            padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts)
+            expert_offsets = expert_boundaries(sorted_expert_idxs, self.num_experts)
 
         h = self.experts(
             x, self.top_k,
             sorted_expert_idxs, sorted_scattered_idxs,
-            padded_block_idxs, expert_offsets,
-            grouped_out=True
+            expert_offsets,
+            grouped_out=True, grouped_in=False,
+            gates=None
         )
         h = self.activation(h)
         y = self.output_experts(
             h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            padded_block_idxs, expert_offsets,
-            grouped_in=True,
+            expert_offsets,
+            grouped_out=False, grouped_in=True,
             gates=expert_p,
         )
         y = y.view(*x_shape[:-1], y.size(-1))
diff --git a/scattermoe/triton_implementation/__init__.py b/scattermoe/triton_implementation/__init__.py
index 99085b7..640d92b 100644
--- a/scattermoe/triton_implementation/__init__.py
+++ b/scattermoe/triton_implementation/__init__.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 
-from .ops import padded_block_indices, scattered_experts
+from .ops import expert_boundaries, scattered_experts
 
 
 class ParallelExperts(nn.Module):
@@ -28,7 +28,6 @@ def forward(
         k,
         sorted_expert_idxs,
         sorted_scattered_idxs,
-        padded_block_idxs,
         expert_offsets,
         gates=None,
         grouped_in=False,
@@ -40,7 +39,6 @@
             k,
             sorted_expert_idxs,
             sorted_scattered_idxs,
-            padded_block_idxs,
             expert_offsets,
             gates,
             grouped_in,
diff --git a/scattermoe/triton_implementation/kernels.py b/scattermoe/triton_implementation/kernels.py
index 293a5e0..ff75c4c 100644
--- a/scattermoe/triton_implementation/kernels.py
+++ b/scattermoe/triton_implementation/kernels.py
@@ -16,7 +16,7 @@ def scatter2scatter_triton_kernel(
     Y_ptr, stride_ym, stride_yn,
     grouped_idx_ptr,
     expert_idxs_ptr,
-    block_start_idx_ptr,
+    # block_start_idx_ptr,
     FAN_OUT,
     M,
     K: tl.constexpr,
@@ -36,46 +36,55 @@
     M_block_id = pid // N_BLOCK_COUNT
     N_block_id = pid % N_BLOCK_COUNT
     M_range = tl.arange(0, BLOCK_M)
-    block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
+    N_range = tl.arange(0, BLOCK_N)
+    # block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
 
-    M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
-
-    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    # M_block = tl.max_contiguous(M_block_id * BLOCK_M + M_range, BLOCK_M)
+    M_block = M_block_id * BLOCK_M + M_range
+    N_block = N_block_id * BLOCK_N + N_range
     N_mask = N_block < N
-
-    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
+    M_boundary_mask = M_block < (FAN_OUT * M)
+    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
 
     no_k_mask = K % BLOCK_K == 0
     no_n_mask = N % BLOCK_N == 0
 
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-    E_idx = tl.min(E_idxs)
-    E_mask = E_idxs == E_idx
-    M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
-    if x_grouped:
-        M_in_idx = M_block
-    else:
-        M_in_idx = M_idx // FAN_OUT
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    E_first_idx = tl.min(E_idxs)
+    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
+    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask, other=0).to(tl.int32)
+
+    # iters = E_last_idx - E_first_idx + 1
+    # for i in range(iters):
+    #     E_idx = i + E_first_idx
+    for E_idx in range(E_first_idx, E_last_idx + 1):
+        E_mask = E_idxs == E_idx
+        # E_M_idx = tl.where(E_mask, M_idx, 0)
+        E_M_idx = M_idx
+        if x_grouped:
+            M_in_idx = M_block
+        else:
+            M_in_idx = E_M_idx // FAN_OUT
+
+        acc = compute_expert_block(
+            E_idx, E_mask,
+            M_in_idx,
+            N_block, N_mask,
+            X_ptr, stride_xm, stride_xk,
+            W_ptr, stride_we, stride_wk, stride_wn,
+            K,
+            acc,
+            allow_tf32,
+            no_k_mask, no_n_mask,
+            ACC_TYPE, BLOCK_K
+        )
 
     if y_grouped:
         M_out_idx = M_block
     else:
         M_out_idx = M_idx
 
-    acc = compute_expert_block(
-        E_idx, E_mask,
-        M_in_idx,
-        N_block, N_mask,
-        X_ptr, stride_xm, stride_xk,
-        W_ptr, stride_we, stride_wk, stride_wn,
-        K,
-        acc,
-        allow_tf32,
-        no_k_mask, no_n_mask,
-        ACC_TYPE, BLOCK_K
-    )
-
     Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
-    tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
+    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
 
 
 @triton.jit
 def compute_expert_block(
@@ -91,14 +100,13 @@ def compute_expert_block(
         ACC_TYPE, BLOCK_K):
 
     K_block = tl.arange(0, BLOCK_K)
-
     X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
     W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
     iters = tl.cdiv(K, BLOCK_K)
+
     for K_block_id in range(0, iters):
         if no_k_mask:
             x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
-
         if no_n_mask or K_block_id < (iters - 1):
             w = tl.load(W_blk_ptrs)
         else:
@@ -154,7 +162,8 @@ def groupXtY_triton_kernel(
     pid1 = tl.program_id(axis=1)
     num0 = tl.num_programs(0)
     num1 = tl.num_programs(1)
-    pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
+    # pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
+    pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
     K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
     E_idx = pid0 // K_BLOCK_COUNT
diff --git a/scattermoe/triton_implementation/ops/__init__.py b/scattermoe/triton_implementation/ops/__init__.py
index 861d4e8..a6859dc 100644
--- a/scattermoe/triton_implementation/ops/__init__.py
+++ b/scattermoe/triton_implementation/ops/__init__.py
@@ -7,28 +7,14 @@
 torch._dynamo.config.capture_scalar_outputs = True
 
 
-def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M):
+def expert_boundaries(sorted_experts_idxs: torch.Tensor, k: int):
     # there is an overhead of launching a custom op so we only use the custom op when compiling
     if torch.compiler.is_compiling():
         expert_counts = compileable_bincount(sorted_experts_idxs, k)
     else:
         expert_counts = sorted_experts_idxs.bincount(minlength=k)
-
-    padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
-    padded_expert_block_end = padded_block_counts.cumsum(-1)
     expert_boundaries_end = expert_counts.cumsum(-1)
-    expert_boundaries_start = expert_boundaries_end - expert_counts
-    padded_expert_block_start = padded_expert_block_end - padded_block_counts
-
-    block_idxs = torch.arange(
-        padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device
-    ).unsqueeze(1)
-
-    block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end)
-    expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start
-    expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
-
-    return expanded_block_idxs, expert_boundaries_end
+    return expert_boundaries_end
 
 
 class _ScatteredExperts(torch.autograd.Function):
@@ -40,7 +26,6 @@ def forward(
         k,
         sorted_expert_idxs,
         sorted_scattered_idxs,
-        padded_block_idxs,
         expert_offsets,
         gates=None,
         grouped_in=False,
@@ -53,7 +38,6 @@ def forward(
             W=expert_weights,
             sorted_expert_idxs=sorted_expert_idxs,
             sorted_scattered_idxs=sorted_scattered_idxs,
-            padded_block_idxs=padded_block_idxs,
             out=output,
             FAN_OUT=k,
             x_grouped=grouped_in,
@@ -71,7 +55,6 @@ def forward(
             expert_weights,
             sorted_expert_idxs,
             sorted_scattered_idxs,
-            padded_block_idxs,
             expert_offsets,
             gates,
             output_expanded,
@@ -90,7 +73,6 @@ def backward(ctx, grad_out):
             expert_weights,
             sorted_expert_idxs,
             sorted_scattered_idxs,
-            padded_block_idxs,
             expert_offsets,
             gates,
             output_expanded,
@@ -139,14 +121,14 @@ def backward(ctx, grad_out):
             d_expanded_input = grouped_x
 
-        d_weights = torch.zeros(
-            expert_weights.size(0),
-            grouped_grad_out.size(-1),
-            grouped_x.size(-1),
-            device=grouped_grad_out.device,
-            dtype=grouped_grad_out.dtype,
-        ).permute(0, 2, 1)
-
+        # d_weights = torch.zeros(
+        #     expert_weights.size(0),
+        #     grouped_grad_out.size(-1),
+        #     grouped_x.size(-1),
+        #     device=grouped_grad_out.device,
+        #     dtype=grouped_grad_out.dtype,
+        # ).permute(0, 2, 1)
+        d_weights = torch.zeros_like(expert_weights)
         group_bwd_W(
             DY=grouped_grad_out,
             X=grouped_x,
@@ -160,7 +142,6 @@ def backward(ctx, grad_out):
            W=expert_weights.permute(0, 2, 1),
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
-           padded_block_idxs=padded_block_idxs,
            out=d_expanded_input,
            FAN_OUT=1,
            x_grouped=True,
@@ -181,8 +162,7 @@ def backward(ctx, grad_out):
            # sorted_expert_idxs, sorted_scattered_idxs,
            None,
            None,
-           # padded_block_idxs, expert_offsets,
-           None,
+           # expert_offsets,
            None,
            # gates
            d_gates,
@@ -197,19 +177,27 @@ def scattered_experts(
     k,
     sorted_expert_idxs,
     sorted_scattered_idxs,
-    padded_block_idxs,
     expert_offsets,
     gates=None,
     grouped_in=False,
     grouped_out=False,
 ):
+    x = inputs
+    expert_weights = expert_weights
+    k = k
+    sorted_expert_idxs = sorted_expert_idxs
+    sorted_scattered_idxs = sorted_scattered_idxs
+    expert_offsets = expert_offsets
+    gates = gates
+    grouped_in = grouped_in
+    grouped_out = grouped_out
+
     return _ScatteredExperts.apply(
-        inputs,
+        x,
         expert_weights,
         k,
         sorted_expert_idxs,
         sorted_scattered_idxs,
-        padded_block_idxs,
         expert_offsets,
         gates,
         grouped_in,
diff --git a/scattermoe/triton_implementation/ops/compileable_ops.py b/scattermoe/triton_implementation/ops/compileable_ops.py
index 23f547b..9d3c2d3 100644
--- a/scattermoe/triton_implementation/ops/compileable_ops.py
+++ b/scattermoe/triton_implementation/ops/compileable_ops.py
@@ -8,6 +8,7 @@
 LIBRARY_NAME = "scattermoe"
 BLOCK_M = 128
+ALLOW_TF32 = False
 
 torch._dynamo.config.capture_scalar_outputs = True
@@ -27,7 +28,6 @@ def _scatter2scatter(
     W: torch.Tensor,
     sorted_expert_idxs: torch.Tensor,
     sorted_scattered_idxs: torch.Tensor,
-    padded_block_idxs: torch.Tensor,
     out: torch.Tensor,
     FAN_OUT: int,
     x_grouped: bool = False,
@@ -38,25 +38,20 @@
     assert out.size(0) == sorted_expert_idxs.size(0)
     assert out.size(1) == W.size(-1)
 
-    grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),)
+    grid = lambda meta: (
+        triton.cdiv(sorted_expert_idxs.size(0), meta["BLOCK_M"]) *
+        triton.cdiv(meta["N"], meta["BLOCK_N"]),
+    )
 
     scatter2scatter_triton_kernel[grid](
         # X_ptr, stride_xm, stride_xk,
-        X,
-        X.stride(0),
-        X.stride(1),
+        X, X.stride(0), X.stride(1),
         # W_ptr, stride_we, stride_wk, stride_wn,
-        W,
-        W.stride(0),
-        W.stride(1),
-        W.stride(2),
+        W, W.stride(0), W.stride(1), W.stride(2),
         # Y_ptr, stride_ym, stride_yn,
-        out,
-        out.stride(0),
-        out.stride(1),
+        out, out.stride(0), out.stride(1),
         grouped_idx_ptr=sorted_scattered_idxs,
         expert_idxs_ptr=sorted_expert_idxs,
-        block_start_idx_ptr=padded_block_idxs,
         FAN_OUT=FAN_OUT,
         M=X.size(0),
         K=X.size(1),
@@ -64,7 +59,7 @@
         E=W.size(0),
         BLOCK_M=BLOCK_M,
         ACC_TYPE=tl.float32,
-        allow_tf32=torch.backends.cudnn.allow_tf32,
+        allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32,
         x_grouped=x_grouped,
         y_grouped=y_grouped,
     )
@@ -77,7 +72,6 @@ def _scatter2scatter_compileable(
     W: torch.Tensor,
     sorted_expert_idxs: torch.Tensor,
     sorted_scattered_idxs: torch.Tensor,
-    padded_block_idxs: torch.Tensor,
     out: torch.Tensor,
     FAN_OUT: int,
     x_grouped: bool = False,
@@ -88,7 +82,6 @@
         W=W,
         sorted_expert_idxs=sorted_expert_idxs,
         sorted_scattered_idxs=sorted_scattered_idxs,
-        padded_block_idxs=padded_block_idxs,
         out=out,
         FAN_OUT=FAN_OUT,
         x_grouped=x_grouped,
@@ -101,19 +94,17 @@ def scatter2scatter(
     W: torch.Tensor,
     sorted_expert_idxs: torch.Tensor,
     sorted_scattered_idxs: torch.Tensor,
-    padded_block_idxs: torch.Tensor,
     out: torch.Tensor,
     FAN_OUT: int,
     x_grouped: bool = False,
     y_grouped: bool = False,
 ) -> None:
-    if torch.compiler.is_compiling():
+    if False: # torch.compiler.is_compiling():
         _scatter2scatter_compileable(
             X=X,
             W=W,
             sorted_expert_idxs=sorted_expert_idxs,
             sorted_scattered_idxs=sorted_scattered_idxs,
-            padded_block_idxs=padded_block_idxs,
             out=out,
             FAN_OUT=FAN_OUT,
             x_grouped=x_grouped,
@@ -125,7 +116,6 @@
             W=W,
             sorted_expert_idxs=sorted_expert_idxs,
             sorted_scattered_idxs=sorted_scattered_idxs,
-            padded_block_idxs=padded_block_idxs,
             out=out,
             FAN_OUT=FAN_OUT,
             x_grouped=x_grouped,
@@ -157,7 +147,7 @@ def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor
         K=X.size(-1),
         # ACC_TYPE: tl.constexpr,
         ACC_TYPE=tl.float32,
-        allow_tf32=torch.backends.cudnn.allow_tf32,
+        allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32,
     )
@@ -170,7 +160,7 @@ def _group_bwd_W_compileable(
 
 
 def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None:
-    if torch.compiler.is_compiling():
+    if False: # torch.compiler.is_compiling():
         _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E)
     else:
         _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E)
@@ -228,7 +218,7 @@ def group(
     coeff: torch.Tensor | None = None,
     fan_out: int = 1,
 ) -> None:
-    if torch.compiler.is_compiling():
+    if False: # torch.compiler.is_compiling():
         _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out)
     else:
         _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out)
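
For reference, a minimal caller-side sketch of the routing setup after this change. The expert count, tensor sizes, and the commented-out `experts` module below are illustrative placeholders, not part of the patch; the calls to `torch.sort` and `expert_boundaries` mirror what `scattermoe/mlp.py` does after this diff.

# Illustrative sketch only: expert_boundaries() replaces padded_block_indices(),
# and padded_block_idxs is no longer computed or passed anywhere.
import torch
from scattermoe import expert_boundaries

num_experts, top_k = 8, 2                                  # assumed example values
x = torch.randn(16, 64)                                    # flattened tokens
expert_idxs = torch.randint(0, num_experts, (16, top_k))   # router choices per token

with torch.no_grad():
    # sort token->expert assignments by expert id
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
    # cumulative per-expert counts, i.e. the end offset of each expert's
    # segment in the sorted assignment list (bincount(...).cumsum(-1))
    expert_offsets = expert_boundaries(sorted_expert_idxs, num_experts)

# A ParallelExperts forward call then takes only these tensors, e.g.:
# h = experts(
#     inputs=x, k=top_k,
#     sorted_expert_idxs=sorted_expert_idxs,
#     sorted_scattered_idxs=sorted_scattered_idxs,
#     expert_offsets=expert_offsets,
#     grouped_out=True,
# )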