maint/gemm_v2/latency_mha_fwd_bhsd.py — 2 changes: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 parser.add_argument('--heads', type=int, default=16, help='heads')
 parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
 parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
-parser.add_argument('--dim', type=int, default=512, help='dim')
+parser.add_argument('--dim', type=int, default=256, help='dim')
 parser.add_argument('--is_causal', action='store_true', help='causal')
 parser.add_argument('--tune', action='store_true', help='tune configs')
 parser.add_argument("--use_v2", action="store_true")
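
Note: this hunk only lowers the benchmark's default head dimension from 512 to 256; the old configuration stays reachable by passing --dim 512 explicitly. A quick standalone check of the argparse behavior (illustrative snippet, not part of the repository):

# Standalone check of the new default (hypothetical test, not repo code)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dim', type=int, default=256, help='dim')

print(parser.parse_args([]).dim)                 # 256 — the new default
print(parser.parse_args(['--dim', '512']).dim)   # 512 — restores the old value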
tilelang/language/v2/dtypes.py — 23 changes: 17 additions & 6 deletions

@@ -10,7 +10,8 @@
 # Python 3.9 compatibility: avoid PEP 604 unions at runtime
 AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]

-_dtype_cvt = [
+# Base dtype conversion list
+_dtype_cvt_base = [
     (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
     (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
     (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
@@ -36,14 +37,24 @@
     (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
     (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
     (None, 'float8_e4m3', None, None, 'Float8E4M3'),
-    (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'),
-    (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'),
-    (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'),
-    (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'),
-    (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'),
     (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
 ]

+# Dynamically add fp8-related types if they exist in torch
+_fp8_dtype_mappings = [
+    ('float8_e4m3fn', 'Float8E4M3FN'),
+    ('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
+    ('float8_e5m2', 'Float8E5M2'),
+    ('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
+    ('float8_e8m0fnu', 'Float8E8M0FNU'),
+]
+
+_dtype_cvt = list(_dtype_cvt_base)
+for torch_attr_name, tvm_name in _fp8_dtype_mappings:
+    if hasattr(torch, torch_attr_name):
+        torch_dtype = getattr(torch, torch_attr_name)
+        _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
+
+
 def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
     return {
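
The point of the hasattr guard: the old table referenced attributes such as torch.float8_e8m0fnu directly at module import, which raises AttributeError on torch builds that predate those dtypes; the new code registers each fp8 entry only when the installed torch actually defines it. A minimal standalone sketch of the same pattern (names mirror the diff, but this snippet is illustrative, not the repository's code):

# Sketch: register optional torch dtypes only if this torch build has them
import torch

_fp8_dtype_mappings = [
    ('float8_e4m3fn', 'Float8E4M3FN'),
    ('float8_e5m2', 'Float8E5M2'),
    ('float8_e8m0fnu', 'Float8E8M0FNU'),  # absent from older torch releases
]

available = []
for torch_attr_name, tvm_name in _fp8_dtype_mappings:
    if hasattr(torch, torch_attr_name):   # skip dtypes this build lacks
        available.append((getattr(torch, torch_attr_name), tvm_name))

print(available)  # only the fp8 dtypes the installed torch provides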