Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion op_tests/test_layernorm2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter import dtypes
import argparse


@perftest()
Expand Down Expand Up @@ -105,11 +106,48 @@ def test_layernorm2d_fuseAdd(dtype, m, n):
checkAllclose(res_a, res_c, atol=0.01, msg="asm res")


# Command-line driver for the fused-add layernorm2d check: choose the data
# type(s) and the (m, n) input shape, then run the test once per dtype.
l_dtype = ["bf16"]
parser = argparse.ArgumentParser(
    description="Test layernorm2d performance and correctness",
)
parser.add_argument(
    "-d",
    "--dtype",
    type=str,
    choices=l_dtype,
    nargs="?",
    const=None,
    default=None,
    help="""Data type.
    e.g.: -d bf16""",
)
# nargs="?" lets -m / -n appear without a value. argparse then stores `const`,
# which defaults to None; setting it to the documented default keeps a bare
# flag from handing None to the test as a tensor dimension.
parser.add_argument(
    "-m",
    type=int,
    nargs="?",
    const=128,
    default=128,
    help="""Number of rows in the input tensor.
    e.g.: -m 128""",
)
parser.add_argument(
    "-n",
    type=int,
    nargs="?",
    const=8192,
    default=8192,
    help="""Number of columns in the input tensor.
    e.g.: -n 8192""",
)
args = parser.parse_args()

# Map the dtype name(s) to the objects in dtypes.d_dtypes: every supported
# name when -d is omitted (or given without a value), otherwise just the
# requested one.
if args.dtype is None:
    l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
else:
    l_dtype = [dtypes.d_dtypes[args.dtype]]

# for dtype in [dtypes.fp16, dtypes.bf16]:
#     for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
#         for n in [4096, 8192, 16384, 32768, 65536]:
#             test_layernorm2d(dtype, m, n)

# The former hard-coded test_layernorm2d_fuseAdd(dtypes.bf16, 128, 8192) call
# is dropped: with the default arguments the loop below runs exactly that
# case, so keeping it would execute the test twice and ignore the CLI.
for dtype in l_dtype:
    test_layernorm2d_fuseAdd(dtype, args.m, args.n)


# print('\nstart fuse add test')
Expand Down
10 changes: 5 additions & 5 deletions op_tests/test_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,12 @@ def test_flash_attn_output(
e.g.: -k 1024""",
)
parser.add_argument(
"-d",
"--d",
"-qk",
"--d_qk",
type=int,
default=128,
help="""Dimension of query and key. Default is 128.
e.g.: -d 256""",
e.g.: -qk 256""",
)
parser.add_argument(
"-v",
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_flash_attn_output(
e.g.: -m mha""",
)
parser.add_argument(
"-dtype",
"-d",
"--dtype",
type=str,
default="bf16",
Expand All @@ -414,7 +414,7 @@ def test_flash_attn_output(
args.nheads,
args.seqlen_q,
args.seqlen_k,
args.d,
args.d_qk,
args.d_v,
args.dropout_p,
args.causal,
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_absorb_prefill():
nargs="?",
const=None,
default=None,
help="""Number of heads.
help="""Number of nhead and mtp.
e.g.: -n 16,1""",
)

Expand Down
10 changes: 9 additions & 1 deletion op_tests/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,15 @@ def calculateTensorsSize(*args):
choices=l_test,
default=None,
help="""Select test to run.
e.g.: -t test_fmoe_16_bit""",
e.g.: -t test_fmoe_16_bit
or -t test_fmoe_16_bit
or -t g1u1_no_quant
or -t g1u1_int8quant
or -t g1u1_fp8quant
or -t g1u0_int8smoothquant
or -t g1u1_int8smoothquant
or -t g1u1_fp8smoothquant
or -t g1u1_int4""",
)
parser.add_argument(
"-d",
Expand Down
Loading