diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py
index cbe93bf69..4126bb9d3 100644
--- a/maint/gemm_v2/latency_mha_fwd_bhsd.py
+++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py
@@ -12,7 +12,7 @@
 parser.add_argument('--heads', type=int, default=16, help='heads')
 parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
 parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
-parser.add_argument('--dim', type=int, default=512, help='dim')
+parser.add_argument('--dim', type=int, default=256, help='dim')
 parser.add_argument('--is_causal', action='store_true', help='causal')
 parser.add_argument('--tune', action='store_true', help='tune configs')
 parser.add_argument("--use_v2", action="store_true")
diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py
index 57ef60787..2161e3770 100644
--- a/tilelang/language/v2/dtypes.py
+++ b/tilelang/language/v2/dtypes.py
@@ -10,7 +10,8 @@
 # Python 3.9 compatibility: avoid PEP 604 unions at runtime
 AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
 
-_dtype_cvt = [
+# Base dtype conversion list
+_dtype_cvt_base = [
     (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
     (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
     (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
@@ -36,14 +37,24 @@
     (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
     (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
     (None, 'float8_e4m3', None, None, 'Float8E4M3'),
-    (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'),
-    (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'),
-    (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'),
-    (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'),
-    (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'),
     (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
 ]
 
+# Dynamically add fp8-related types if they exist in torch
+_fp8_dtype_mappings = [
+    ('float8_e4m3fn', 'Float8E4M3FN'),
+    ('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
+    ('float8_e5m2', 'Float8E5M2'),
+    ('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
+    ('float8_e8m0fnu', 'Float8E8M0FNU'),
+]
+
+_dtype_cvt = list(_dtype_cvt_base)
+for torch_attr_name, tvm_name in _fp8_dtype_mappings:
+    if hasattr(torch, torch_attr_name):
+        torch_dtype = getattr(torch, torch_attr_name)
+        _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
+
 
 def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
     return {
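
Reviewer note: the sketch below is not part of the diff; it isolates the guarded-registration pattern the new dtypes.py code relies on, and the `probe_fp8_support` helper is hypothetical. The point of the `hasattr(torch, ...)` check is that fp8 dtypes are only exposed by sufficiently new PyTorch builds, so the previous static table, which referenced attributes like `torch.float8_e8m0fnu` unconditionally, raised `AttributeError` at import time on older installs; the loop now simply skips entries the running build lacks.

```python
# Standalone sketch of the guarded fp8-dtype registration pattern.
# probe_fp8_support() is a hypothetical helper, not part of tilelang.
import torch

_FP8_NAMES = [
    'float8_e4m3fn',
    'float8_e4m3fnuz',
    'float8_e5m2',
    'float8_e5m2fnuz',
    'float8_e8m0fnu',
]


def probe_fp8_support():
    """Report which fp8 dtypes the running torch build exposes."""
    # getattr is only reached when hasattr succeeds, mirroring the
    # registration loop in dtypes.py; missing dtypes are skipped
    # instead of raising AttributeError at import time.
    return {name: getattr(torch, name) for name in _FP8_NAMES if hasattr(torch, name)}


if __name__ == '__main__':
    supported = probe_fp8_support()
    print(f'{len(supported)}/{len(_FP8_NAMES)} fp8 dtypes available on this build:')
    for name, dt in supported.items():
        print(f'  torch.{name} -> {dt}')
```

Building `_dtype_cvt` from `_dtype_cvt_base` plus this loop keeps the conversion table identical to the old static list on a torch build that has all five fp8 dtypes, while degrading gracefully on builds that predate some of them.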