Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Padded indices free. #14

Open
wants to merge 1 commit into
base: khd
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scattermoe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import kernels
from . import parallel_experts
from . import mlp
from .triton_implementation import padded_block_indices
from .triton_implementation import expert_boundaries
from .parallel_experts import ParallelExperts
51 changes: 34 additions & 17 deletions scattermoe/kernels/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@


@triton.autotune(
configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)],
configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32},
num_stages=4, num_warps=4)],
key=["N", "K"],
)
@triton.jit
Expand Down Expand Up @@ -43,7 +44,8 @@ def scatter2scatter_triton_kernel(
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)

M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
E_idxs = tl.load(expert_idxs_ptr + M_block,
mask=M_block < (FAN_OUT * M), other=E)
E_idx = tl.min(E_idxs)
E_mask = E_idxs == E_idx
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
Expand All @@ -63,8 +65,10 @@ def scatter2scatter_triton_kernel(
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N

X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
X_blk_ptrs = X_ptr + M_in_idx[:, None] * \
stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + \
N_block[None, :] * stride_wn + E_idx * stride_we

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(K, BLOCK_K)
Expand All @@ -89,16 +93,20 @@ def scatter2scatter_triton_kernel(
W_blk_ptrs += BLOCK_K * stride_wk
acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)

Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] *
stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])


@triton.autotune(
configs=[
# different block M and reducing stages
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
"BLOCK_M": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
"BLOCK_M": 128}, num_stages=1, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128,
"BLOCK_M": 64}, num_stages=2, num_warps=4),
# keep 4 stages and keep two 64 block sizes
# - NOTE: these can get good performances for low M, but for large M the variation
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4),
Expand Down Expand Up @@ -151,15 +159,19 @@ def groupXtY_triton_kernel(

K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
K_block = tl.max_contiguous(
tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)

N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
N_block = tl.max_contiguous(
tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)

M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
xt_blk_ptrs = X_ptr + K_block[:, None] * \
stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * \
stride_dym + N_block[None, :] * stride_dyk

acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
Expand All @@ -173,18 +185,21 @@ def groupXtY_triton_kernel(
if no_k_mask:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
xt = tl.load(
xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])

if no_n_mask:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
dy = tl.load(
dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])

xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)

DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + \
K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])

Expand Down Expand Up @@ -216,8 +231,10 @@ def group_triton_kernel(
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)

K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
src_blk_ptrs = src_ptr + \
(N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * \
stride_tn + K_blk[None, :] * stride_ti

if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
Expand Down
33 changes: 18 additions & 15 deletions scattermoe/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from torch import nn
from torch.nn import functional as F

# from . import kernels
# from .parallel_experts import ParallelExperts
from .triton_implementation import ParallelExperts, padded_block_indices
from .triton_implementation import ParallelExperts, expert_boundaries

class GLUMLP(nn.Module):
def __init__(
Expand Down Expand Up @@ -33,18 +31,22 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te
x = x.view(-1, x_shape[-1])
with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts)

expert_offsets = expert_boundaries(sorted_expert_idxs, self.num_experts)
h, gates = self.experts(
x, self.top_k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
inputs=x,
k=self.top_k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
expert_offsets=expert_offsets,
grouped_out=True
).chunk(2, dim=-1)
h = self.activation(gates) * h
y = self.output_experts(
h, 1, sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
inputs=h,
k=1,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
expert_offsets=expert_offsets,
grouped_in=True,
gates=expert_p,
)
Expand Down Expand Up @@ -79,19 +81,20 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te
x = x.view(-1, x_shape[-1])
with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts)
expert_offsets = expert_boundaries(sorted_expert_idxs, self.num_experts)

h = self.experts(
x, self.top_k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
grouped_out=True
expert_offsets,
grouped_out=True, grouped_in=False,
gates=None
)
h = self.activation(h)
y = self.output_experts(
h, 1, sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
grouped_in=True,
expert_offsets,
grouped_out=False, grouped_in=True,
gates=expert_p,
)
y = y.view(*x_shape[:-1], y.size(-1))
Expand Down
4 changes: 1 addition & 3 deletions scattermoe/triton_implementation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn as nn

from .ops import padded_block_indices, scattered_experts
from .ops import expert_boundaries, scattered_experts


class ParallelExperts(nn.Module):
Expand All @@ -28,7 +28,6 @@ def forward(
k,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
gates=None,
grouped_in=False,
Expand All @@ -40,7 +39,6 @@ def forward(
k,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
gates,
grouped_in,
Expand Down
73 changes: 41 additions & 32 deletions scattermoe/triton_implementation/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def scatter2scatter_triton_kernel(
Y_ptr, stride_ym, stride_yn,
grouped_idx_ptr,
expert_idxs_ptr,
block_start_idx_ptr,
# block_start_idx_ptr,
FAN_OUT,
M,
K: tl.constexpr,
Expand All @@ -36,46 +36,55 @@ def scatter2scatter_triton_kernel(
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_range = tl.arange(0, BLOCK_M)
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
N_range = tl.arange(0, BLOCK_N)
# block_start_idx = tl.load(block_start_idx_ptr + M_block_id)

M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)

N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
# M_block = tl.max_contiguous(M_block_id * BLOCK_M + M_range, BLOCK_M)
M_block = M_block_id * BLOCK_M + M_range
N_block = N_block_id * BLOCK_N + N_range
N_mask = N_block < N

E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)

no_k_mask = K % BLOCK_K == 0
no_n_mask = N % BLOCK_N == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

E_idx = tl.min(E_idxs)
E_mask = E_idxs == E_idx
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = M_idx // FAN_OUT
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask, other=0).to(tl.int32)

# iters = E_last_idx - E_first_idx + 1
# for i in range(iters):
# E_idx = i + E_first_idx
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
# E_M_idx = tl.where(E_mask, M_idx, 0)
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT

acc = compute_expert_block(
E_idx, E_mask,
M_in_idx,
N_block, N_mask,
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
K,
acc,
allow_tf32,
no_k_mask, no_n_mask,
ACC_TYPE, BLOCK_K
)
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx

acc = compute_expert_block(
E_idx, E_mask,
M_in_idx,
N_block, N_mask,
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
K,
acc,
allow_tf32,
no_k_mask, no_n_mask,
ACC_TYPE, BLOCK_K
)

Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])

@triton.jit
def compute_expert_block(
Expand All @@ -91,14 +100,13 @@ def compute_expert_block(
ACC_TYPE, BLOCK_K):

K_block = tl.arange(0, BLOCK_K)

X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
iters = tl.cdiv(K, BLOCK_K)

for K_block_id in range(0, iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])

if no_n_mask or K_block_id < (iters - 1):
w = tl.load(W_blk_ptrs)
else:
Expand Down Expand Up @@ -154,7 +162,8 @@ def groupXtY_triton_kernel(
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
# pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)

K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
Expand Down
Loading