diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 72afcd28..a702c8ee 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -146,6 +146,10 @@ using namespace flashinfer; #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ case at::ScalarType::Float8_e4m3fn: { \ using c_type = __nv_fp8_e4m3; \ return __VA_ARGS__(); \