Merged

43 commits
95d0bec
add fused concat
k50112113 Jun 24, 2025
3d551f7
add fused elementwise and pytest
k50112113 Jun 24, 2025
38d2366
clean up
k50112113 Jun 24, 2025
ae8b17e
add fused quant
k50112113 Jun 24, 2025
2e7c081
clean up
k50112113 Jun 24, 2025
ea161bc
add fused mxfp4 quant and pytest
k50112113 Jun 24, 2025
9e29cca
add gemm_a16w16_atomic and related tests
cagrikymk Jun 24, 2025
4a868f9
extend tests and add config files for bf16 GEMM with atomic add
cagrikymk Jun 25, 2025
26f468a
fused quant. code cleanup
cagrikymk Jun 25, 2025
d028a7e
formatting changes
cagrikymk Jun 25, 2025
aa7fe7c
add rms norm quant tests and changes
cagrikymk Jun 25, 2025
a128d5b
update/add DS configs for GEMMs
cagrikymk Jun 25, 2025
8bec97a
add prequant config
k50112113 Jun 25, 2025
2fd002e
fix bf16 gemm
k50112113 Jun 25, 2025
7e1ff03
fix fp4 gemm atomic add
k50112113 Jun 26, 2025
f14db5a
black reformatting
k50112113 Jun 26, 2025
5534a9a
tune fp4 prequant gemm atomic
k50112113 Jun 26, 2025
d0e8c37
fix batched fp4 prequant and add new DS configs
cagrikymk Jun 27, 2025
8d7fda5
optimize shapes for deepseek
lburzawa Jun 27, 2025
2ccc45b
black reformatting
k50112113 Jul 2, 2025
302b717
update bf16 atomic GEMM
cagrikymk Jul 8, 2025
c0e7ff6
Merge branch 'main' into shaoclee/ds_fused_custom_ops
cagrikymk Jul 8, 2025
0c484c3
Merge branch 'main' into shaoclee/ds_fused_custom_ops
k50112113 Jul 8, 2025
c394d99
add documentation for fused_concat
k50112113 Jul 8, 2025
aff2092
update rope qk cat
k50112113 Jul 8, 2025
e20e181
rename fused elementwise
k50112113 Jul 8, 2025
adeab4e
add doc for fused quant
k50112113 Jul 9, 2025
f333f65
rename fused_quant into fused_mxfp4_quant
k50112113 Jul 9, 2025
4583988
fix typo
k50112113 Jul 9, 2025
9691bc5
fix a minor bug
cagrikymk Jul 9, 2025
fd51c92
black formatting
k50112113 Jul 9, 2025
86d891c
update comments and drop unused arg.
cagrikymk Jul 9, 2025
ce608d1
fix pre-quant GEMM tests based on func. sign. changes
cagrikymk Jul 9, 2025
06f2c26
fix pre-quant GEMM tests based on func. sign. changes
cagrikymk Jul 9, 2025
76e0bf7
doc on fused_mul_add
k50112113 Jul 9, 2025
b780674
Merge branch 'main' into shaoclee/ds_fused_custom_ops
k50112113 Jul 14, 2025
75d14bb
corner case fix
k50112113 Jul 14, 2025
78aeb7b
corner case fix
k50112113 Jul 14, 2025
0035df6
test
k50112113 Jul 15, 2025
1046898
Fix pytest errors with LRU cache inplace mod - big test case error re…
willzhou-amd Jul 15, 2025
10ffb64
Black linting change
willzhou-amd Jul 15, 2025
7a47179
Fix linting error
willzhou-amd Jul 15, 2025
a4c9b0a
Move inplace config operations to _get_config
willzhou-amd Jul 15, 2025
127 changes: 77 additions & 50 deletions aiter/ops/triton/batched_gemm_afp4wfp4.py
@@ -41,22 +41,22 @@ def _batched_gemm_afp4_wfp4_kernel(
     M,
     N,
     K,
-    stride_ab,
-    stride_am,
-    stride_ak,
-    stride_bb,
-    stride_bk,
-    stride_bn,
-    stride_cb,
-    stride_ck,
-    stride_cm,
-    stride_cn,
-    stride_asb,
-    stride_asm,
-    stride_ask,
-    stride_bsb,
-    stride_bsn,
-    stride_bsk,
+    stride_in_ab,
+    stride_in_am,
+    stride_in_ak,
+    stride_in_bb,
+    stride_in_bk,
+    stride_in_bn,
+    stride_in_cb,
+    stride_in_ck,
+    stride_in_cm,
+    stride_in_cn,
+    stride_in_asb,
+    stride_in_asm,
+    stride_in_ask,
+    stride_in_bsb,
+    stride_in_bsn,
+    stride_in_bsk,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -74,21 +74,21 @@ def _batched_gemm_afp4_wfp4_kernel(
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """

-    tl.assume(stride_ab > 0)
-    tl.assume(stride_am > 0)
-    tl.assume(stride_ak > 0)
-    tl.assume(stride_bb > 0)
-    tl.assume(stride_bk > 0)
-    tl.assume(stride_bn > 0)
-    tl.assume(stride_cb > 0)
-    tl.assume(stride_cm > 0)
-    tl.assume(stride_cn > 0)
-    tl.assume(stride_asb > 0)
-    tl.assume(stride_asm > 0)
-    tl.assume(stride_ask > 0)
-    tl.assume(stride_bsb > 0)
-    tl.assume(stride_bsk > 0)
-    tl.assume(stride_bsn > 0)
+    tl.assume(stride_in_ab > 0)
+    tl.assume(stride_in_am > 0)
+    tl.assume(stride_in_ak > 0)
+    tl.assume(stride_in_bb > 0)
+    tl.assume(stride_in_bk > 0)
+    tl.assume(stride_in_bn > 0)
+    tl.assume(stride_in_cb > 0)
+    tl.assume(stride_in_cm > 0)
+    tl.assume(stride_in_cn > 0)
+    tl.assume(stride_in_asb > 0)
+    tl.assume(stride_in_asm > 0)
+    tl.assume(stride_in_ask > 0)
+    tl.assume(stride_in_bsb > 0)
+    tl.assume(stride_in_bsk > 0)
+    tl.assume(stride_in_bsn > 0)

     # -----------------------------------------------------------
     # Map program ids `pid` to the block of C it should compute.
@@ -103,10 +103,27 @@ def _batched_gemm_afp4_wfp4_kernel(
     # Cast batch id and batch dimension strides to int64 to avoid int32 overflow during offset calculation
     # Note: If you're attempting to cast strides to int64 to prevent integer overflow, use `tl.cast` instead of `.to()`.
     # See https://github.com/ROCm/aiter/pull/597 for rationale
-    stride_ab = tl.cast(stride_ab, tl.int64)
-    stride_bb = tl.cast(stride_bb, tl.int64)
-    stride_cb = tl.cast(stride_cb, tl.int64)
-    pid_batch = tl.cast(pid_batch, tl.int64)
+    # stride_ab = tl.cast(stride_ab, tl.int64)
+    # stride_bb = tl.cast(stride_bb, tl.int64)
+    # stride_cb = tl.cast(stride_cb, tl.int64)
+    # pid_batch = tl.cast(pid_batch, tl.int64)
+
+    stride_ab = tl.cast(stride_in_ab, tl.int64)
+    stride_am = tl.cast(stride_in_am, tl.int64)
+    stride_ak = tl.cast(stride_in_ak, tl.int64)
+    stride_bb = tl.cast(stride_in_bb, tl.int64)
+    stride_bk = tl.cast(stride_in_bk, tl.int64)
+    stride_bn = tl.cast(stride_in_bn, tl.int64)
+    stride_cb = tl.cast(stride_in_cb, tl.int64)
+    stride_ck = tl.cast(stride_in_ck, tl.int64)
+    stride_cm = tl.cast(stride_in_cm, tl.int64)
+    stride_cn = tl.cast(stride_in_cn, tl.int64)
+    stride_asb = tl.cast(stride_in_asb, tl.int64)
+    stride_asm = tl.cast(stride_in_asm, tl.int64)
+    stride_ask = tl.cast(stride_in_ask, tl.int64)
+    stride_bsb = tl.cast(stride_in_bsb, tl.int64)
+    stride_bsk = tl.cast(stride_in_bsk, tl.int64)
+    stride_bsn = tl.cast(stride_in_bsn, tl.int64)

     if NUM_KSPLIT == 1:
         remap_xcd(pid, GRID_MN)
@@ -316,19 +333,38 @@ def _get_config(
     else:
         key = "default"  # fall back to default config
     if M < 32:
-        return _get_config._config_dict[key]["small"]
+        config = _get_config._config_dict[key]["small"]
     elif M <= 128:
         BLK_M = triton.next_power_of_2(M)
         if BLK_M == 32:
-            return _get_config._config_dict[key]["medium_M32"]
+            config = _get_config._config_dict[key]["medium_M32"]
         elif BLK_M == 64:
-            return _get_config._config_dict[key]["medium_M64"]
+            config = _get_config._config_dict[key]["medium_M64"]
         elif BLK_M == 128:
-            return _get_config._config_dict[key]["medium_M128"]
+            config = _get_config._config_dict[key]["medium_M128"]
     elif M <= 256:
-        return _get_config._config_dict[key]["large"]
+        config = _get_config._config_dict[key]["large"]
     else:
-        return _get_config._config_dict[key]["xlarge"]
+        config = _get_config._config_dict[key]["xlarge"]
+
+    config = config.copy()  # Avoid modifying the original config
+
+    if config["NUM_KSPLIT"] > 1:
+        SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
+            K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"]
+        )
+
+        config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE
+        config["BLOCK_SIZE_K"] = BLOCK_SIZE_K
+        config["NUM_KSPLIT"] = NUM_KSPLIT
+    else:
+        config["SPLITK_BLOCK_SIZE"] = 2 * K
+
+    if config["BLOCK_SIZE_K"] >= 2 * K:
+        config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K)
+        config["SPLITK_BLOCK_SIZE"] = 2 * K
+
+    return config


 def batched_gemm_afp4wfp4(
@@ -370,14 +406,6 @@ def batched_gemm_afp4wfp4(
     config = _get_config(M, N, K)

     if config["NUM_KSPLIT"] > 1:
-        SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
-            K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"]
-        )
-
-        config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE
-        config["BLOCK_SIZE_K"] = BLOCK_SIZE_K
-        config["NUM_KSPLIT"] = NUM_KSPLIT
-
         if _USE_GEMM_SPLITK_BF16:
             y_pp = torch.empty(
                 (Batch, config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device
@@ -389,7 +417,6 @@ def batched_gemm_afp4wfp4(
                 device=y.device,
             )
     else:
-        config["SPLITK_BLOCK_SIZE"] = 2 * K
         y_pp = None

     grid = lambda META: ( # noqa: E731
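The int64 comment in the hunk above is the heart of the stride rename: the raw `stride_in_*` arguments are widened with `tl.cast(..., tl.int64)` before any offset math. A minimal sketch of the failure mode, in plain Python with hypothetical shape values rather than the kernel itself:

```python
# Hypothetical per-batch extent of A; the batch stride alone is 2**29 elements.
M, K = 32768, 16384
stride_ab = M * K
pid_batch = 5

full = pid_batch * stride_ab              # 2684354560: needs more than 31 bits
wrapped = (full + 2**31) % 2**32 - 2**31  # what an int32 multiply would yield
print(full)     # 2684354560
print(wrapped)  # -1610612736 -> an out-of-bounds pointer offset in the kernel
```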
11 changes: 7 additions & 4 deletions aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py
@@ -322,18 +322,16 @@ def _get_config(
 def batched_gemm_afp4wfp4_pre_quant(
     x,
     w,
-    x_scales,
     w_scales,
     dtype: Optional[float] = torch.bfloat16,
     y: Optional[torch.Tensor] = None,
     config: Optional[dict] = None,
 ):
     """
     Computes the matmul Y = X x W
-    X and W are e2m1 fp4 tensors.
-    x_scales and w_scales are e8m0 tensors.
+    W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor.
     Every 32 elements in the K dimension share one e8m0 scale.

+    X gets quantized to the microscale fp4 (mxfp4) format before the GEMM.
+
     Key parameters:
     - X: Matrix X with shape (B, M, K).
@@ -378,6 +376,10 @@ def batched_gemm_afp4wfp4_pre_quant(
         config["SPLITK_BLOCK_SIZE"] = 2 * K
         y_pp = None

+    if config["BLOCK_SIZE_K"] >= 2 * K:
+        config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K)
+        config["SPLITK_BLOCK_SIZE"] = 2 * K
+
     grid = lambda META: ( # noqa: E731
         Batch,
         (
@@ -440,3 +442,4 @@ def batched_gemm_afp4wfp4_pre_quant(
         ACTUAL_KSPLIT,
         config["NUM_KSPLIT"],
     )
+    return y
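For orientation, a minimal call sketch against the new signature (`x_scales` dropped since X is quantized inside, and the result is now returned). Shapes for X follow the docstring; the import path, the uint8 packing of W (two e2m1 values per byte), and the e8m0 scale layout are assumptions, not verified against the library:

```python
import torch
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant

B, M, N, K = 4, 64, 256, 1024
# X stays high precision; the kernel quantizes it to mxfp4 before the GEMM.
x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
# Assumed layout: two e2m1 values packed per byte, one e8m0 scale per 32 K-elements.
w = torch.randint(0, 256, (B, N, K // 2), dtype=torch.uint8, device="cuda")
w_scales = torch.randint(118, 130, (B, N, K // 32), dtype=torch.uint8, device="cuda")

y = batched_gemm_afp4wfp4_pre_quant(x, w, w_scales, dtype=torch.bfloat16)  # (B, M, N)
```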
15 changes: 15 additions & 0 deletions aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16-ATOMIC.json
@@ -0,0 +1,15 @@
{
"any": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"NUM_KSPLIT":1,
"cache_modifier": null,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"kpack": 1
}
}
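Config dicts like the one above are no longer used verbatim: `_get_config` now normalizes the split-K fields before launch. A standalone sketch of the small-K clamp from the diff, with hypothetical values (`get_splitk` elided):

```python
import triton

# Hypothetical small-K case: the selected entry's BLOCK_SIZE_K already covers
# all of 2 * K, so the clamp shrinks it and pins SPLITK_BLOCK_SIZE.
K = 48
config = {"BLOCK_SIZE_K": 512, "NUM_KSPLIT": 1}
config["SPLITK_BLOCK_SIZE"] = 2 * K  # NUM_KSPLIT == 1 branch
if config["BLOCK_SIZE_K"] >= 2 * K:
    config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K)  # 512 -> 128
    config["SPLITK_BLOCK_SIZE"] = 2 * K                     # 96
print(config)  # {'BLOCK_SIZE_K': 128, 'NUM_KSPLIT': 1, 'SPLITK_BLOCK_SIZE': 96}
```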
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"small": {
"BLOCK_SIZE_M": 4,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M32": {
"BLOCK_SIZE_M": 4,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M64": {
"BLOCK_SIZE_M": 4,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M128": {
"BLOCK_SIZE_M": 4,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"large": {
"BLOCK_SIZE_M": 8,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"xlarge": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
}

}
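The bucket keys in the file above (`small`, `medium_M32` through `medium_M128`, `large`, `xlarge`) are selected by the M-based branching shown in the `_get_config` diff earlier. A self-contained sketch of that selection:

```python
import triton

def pick_bucket(M: int) -> str:
    # Mirrors the M -> config-key mapping from the _get_config diff above.
    if M < 32:
        return "small"
    elif M <= 128:
        blk_m = triton.next_power_of_2(M)  # 32, 64, or 128
        return {32: "medium_M32", 64: "medium_M64", 128: "medium_M128"}[blk_m]
    elif M <= 256:
        return "large"
    return "xlarge"

assert pick_bucket(16) == "small"
assert pick_bucket(33) == "medium_M64"   # next power of two of 33 is 64
assert pick_bucket(200) == "large"
assert pick_bucket(4096) == "xlarge"
```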
@@ -0,0 +1,74 @@
{
"small": {
"BLOCK_SIZE_M": 4,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"medium_M128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"large": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
},
"xlarge": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 1,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 1
}
}
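Finally, the `config = config.copy()` line added to `_get_config` ties back to the commit "Fix pytest errors with LRU cache inplace mod": mutating a config dict that a cached lookup hands out corrupts every later caller. A hypothetical minimal repro of that pitfall, not the library's actual lookup:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def load_config(key: str) -> dict:
    # Stand-in for the JSON-backed config lookup (hypothetical values).
    return {"BLOCK_SIZE_K": 512, "NUM_KSPLIT": 4}

cfg = load_config("small")
cfg["NUM_KSPLIT"] = 1        # in-place edit leaks into the cache...
print(load_config("small"))  # {'BLOCK_SIZE_K': 512, 'NUM_KSPLIT': 1}

safe = load_config("small").copy()  # ...so copy first, as _get_config now does
```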