diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 6b09f0823e..3499fe67ff 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 6b09f0823e64df0bcd035443c9ce4599a838de02 +Subproject commit 3499fe67ff24f9e3610b208d14589fac645e0ea7 diff --git a/aiter/ops/rmsnorm.py b/aiter/ops/rmsnorm.py index 56dc04873d..988d8c2be6 100644 --- a/aiter/ops/rmsnorm.py +++ b/aiter/ops/rmsnorm.py @@ -39,6 +39,7 @@ def rms_norm( input: Tensor, weight: Tensor, epsilon: float, + use_model_sensitive_rmsnorm: int, ): """ CK version of rmsnorm @@ -48,7 +49,10 @@ def rms_norm( @compile_ops("module_rmsnorm") def rmsnorm2d_fwd( - input: torch.Tensor, weight: torch.Tensor, epsilon: float + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int, ) -> torch.Tensor: ... @@ -60,6 +64,7 @@ def rmsnorm2d_fwd_with_add( residual_out: Tensor, weight: Tensor, epsilon: float, + use_model_sensitive_rmsnorm: int, ): ... @@ -71,6 +76,7 @@ def rmsnorm2d_fwd_with_smoothquant( yscale: Tensor, weight: Tensor, epsilon: float, + use_model_sensitive_rmsnorm: int, ): ... @@ -84,12 +90,18 @@ def rmsnorm2d_fwd_with_add_smoothquant( yscale: Tensor, weight: Tensor, epsilon: float, + use_model_sensitive_rmsnorm: int, ): ... @compile_ops("module_rmsnorm") def rmsnorm2d_fwd_with_dynamicquant( - out: Tensor, input: Tensor, yscale: Tensor, weight: Tensor, epsilon: float + out: Tensor, + input: Tensor, + yscale: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int, ): ... @@ -102,4 +114,5 @@ def rmsnorm2d_fwd_with_add_dynamicquant( yscale: Tensor, weight: Tensor, epsilon: float, + use_model_sensitive_rmsnorm: int, ): ... diff --git a/csrc/include/rmsnorm.h b/csrc/include/rmsnorm.h index b1b3822610..966de4c5db 100644 --- a/csrc/include/rmsnorm.h +++ b/csrc/include/rmsnorm.h @@ -17,50 +17,64 @@ */ #include -void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, - double epsilon); +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, double epsilon); +void fused_add_rms_norm(torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + double epsilon); // ck -torch::Tensor rmsnorm2d(torch::Tensor &input, torch::Tensor &weight, - double epsilon); +torch::Tensor +rmsnorm2d(torch::Tensor& input, + torch::Tensor& weight, + double epsilon, + int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation -void rmsnorm2d_with_add(torch::Tensor &out, // [m ,n] - torch::Tensor &input, // [m ,n] - torch::Tensor &residual_in, // [m ,n] - torch::Tensor &residual_out, // [m ,n] - torch::Tensor &weight, // [1 ,n] - double epsilon); +void rmsnorm2d_with_add( + torch::Tensor& out, // [m ,n] + torch::Tensor& input, // [m ,n] + torch::Tensor& residual_in, // [m ,n] + torch::Tensor& residual_out, // [m ,n] + torch::Tensor& weight, // [1 ,n] + double epsilon, + int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation -void rmsnorm2d_with_smoothquant(torch::Tensor &out, // [m ,n] - torch::Tensor &input, // [m ,n] - torch::Tensor &xscale, // [1 ,n] - torch::Tensor &yscale, // [m ,1] - torch::Tensor &weight, // [1 ,n] - double epsilon); +void rmsnorm2d_with_smoothquant( + torch::Tensor& out, // [m ,n] + torch::Tensor& input, // [m ,n] + torch::Tensor& xscale, // [1 ,n] + torch::Tensor& yscale, // [m ,1] 
+    torch::Tensor& weight, // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation
 
-void rmsnorm2d_with_add_smoothquant(torch::Tensor &out,          // [m ,n]
-                                    torch::Tensor &input,        // [m ,n]
-                                    torch::Tensor &residual_in,  // [m ,n]
-                                    torch::Tensor &residual_out, // [m ,n]
-                                    torch::Tensor &xscale,       // [1 ,n]
-                                    torch::Tensor &yscale,       // [m ,1]
-                                    torch::Tensor &weight,       // [1 ,n]
-                                    double epsilon,
-                                    std::optional<torch::Tensor> out_before_quant);
+void rmsnorm2d_with_add_smoothquant(
+    torch::Tensor& out,          // [m ,n]
+    torch::Tensor& input,        // [m ,n]
+    torch::Tensor& residual_in,  // [m ,n]
+    torch::Tensor& residual_out, // [m ,n]
+    torch::Tensor& xscale,       // [1 ,n]
+    torch::Tensor& yscale,       // [m ,1]
+    torch::Tensor& weight,       // [1 ,n]
+    double epsilon,
+    std::optional<torch::Tensor> out_before_quant,
+    int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation
 
-void rmsnorm2d_with_dynamicquant(torch::Tensor &out,    // [m ,n]
-                                 torch::Tensor &input,  // [m ,n]
-                                 torch::Tensor &yscale, // [m ,1]
-                                 torch::Tensor &weight, // [1 ,n]
-                                 double epsilon);
+void rmsnorm2d_with_dynamicquant(
+    torch::Tensor& out,    // [m ,n]
+    torch::Tensor& input,  // [m ,n]
+    torch::Tensor& yscale, // [m ,1]
+    torch::Tensor& weight, // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation
 
-void rmsnorm2d_with_add_dynamicquant(torch::Tensor &out,          // [m ,n]
-                                     torch::Tensor &input,        // [m ,n]
-                                     torch::Tensor &residual_in,  // [m ,n]
-                                     torch::Tensor &residual_out, // [m ,n]
-                                     torch::Tensor &yscale,       // [m ,1]
-                                     torch::Tensor &weight,       // [1 ,n]
-                                     double epsilon);
+void rmsnorm2d_with_add_dynamicquant(
+    torch::Tensor& out,          // [m ,n]
+    torch::Tensor& input,        // [m ,n]
+    torch::Tensor& residual_in,  // [m ,n]
+    torch::Tensor& residual_out, // [m ,n]
+    torch::Tensor& yscale,       // [m ,1]
+    torch::Tensor& weight,       // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 2552cdb956..2201c32e6f 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -18,8 +18,8 @@
       m.def("mul_", &aiter_mul_, "apply for mul_ with transpose and broadcast."); \
       m.def("sub_", &aiter_sub_, "apply for sub_ with transpose and broadcast."); \
       m.def("div_", &aiter_div_, "apply for div_ with transpose and broadcast.");
-#define AITER_UNARY_PYBIND \
-      m.def("sigmoid", &aiter_sigmoid, "apply for sigmoid."); \
+#define AITER_UNARY_PYBIND                                      \
+      m.def("sigmoid", &aiter_sigmoid, "apply for sigmoid.");   \
       m.def("tanh", &aiter_tanh, "apply for tanh.");
 
 #define ATTENTION_ASM_MLA_PYBIND \
@@ -325,7 +325,7 @@
           py::arg("x_scale"), \
           py::arg("w_scale"), \
           py::arg("Out"),     \
-          py::arg("splitK") = 0);
+          py::arg("splitK") = 0);
 
 #define GEMM_A8W8_BLOCKSCALE_PYBIND \
       m.def("gemm_a8w8_blockscale", \
@@ -853,7 +853,12 @@
           "Apply Root Mean Square (RMS) Normalization to the input tensor.");                      \
     m.def(                                                                                          \
         "fused_add_rms_norm_cu", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); \
-    m.def("rmsnorm2d_fwd", &rmsnorm2d, py::arg("input"), py::arg("weight"), py::arg("epsilon"));    \
+    m.def("rmsnorm2d_fwd",                                                                          \
+          &rmsnorm2d,                                                                               \
+          py::arg("input"),                                                                         \
+          py::arg("weight"),                                                                        \
+          py::arg("epsilon"),                                                                       \
+          py::arg("use_model_sensitive_rmsnorm") = 0);                                              \
     m.def("rmsnorm2d_fwd_with_add",                                                                 \
           &rmsnorm2d_with_add,                                                                      \
           py::arg("out"),                                                                           \
@@ -861,7 +866,8 @@
           py::arg("residual_in"),                                                                   \
           py::arg("residual_out"),                                                                  \
           py::arg("weight"),                                                                        \
-          py::arg("epsilon"));                                                                      \
+          py::arg("epsilon"),                                                                       \
+          py::arg("use_model_sensitive_rmsnorm") = 0);                                              \
     m.def("rmsnorm2d_fwd_with_smoothquant", &rmsnorm2d_with_smoothquant);                           \
     m.def("rmsnorm2d_fwd_with_add_smoothquant",                                                     \
           &rmsnorm2d_with_add_smoothquant,                                                          \
           py::arg("out"),                                                                           \
@@ -873,7 +879,8 @@
           py::arg("yscale"),                                                                        \
           py::arg("weight"),                                                                        \
           py::arg("epsilon"),                                                                       \
-          py::arg("out_before_quant") = std::nullopt);                                              \
+          py::arg("out_before_quant") = std::nullopt,                                               \
+          py::arg("use_model_sensitive_rmsnorm") = 0);                                              \
     m.def("rmsnorm2d_fwd_with_dynamicquant", &rmsnorm2d_with_dynamicquant);                         \
     m.def("rmsnorm2d_fwd_with_add_dynamicquant", &rmsnorm2d_with_add_dynamicquant);
diff --git a/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu b/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu
index 6f236ae881..5701d70e84 100644
--- a/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu
+++ b/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu
@@ -1,96 +1,107 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include
+#include "py_itfs_common.h"
 #include
 #include
-#include "py_itfs_common.h"
+#include
 
 #include "rmsnorm2d_fwd.hpp"
 
-void rmsnorm2d(torch::Tensor &out,    // [m, n]
-               torch::Tensor &input,  // [m, n]
-               torch::Tensor &weight, // [1, n]
-               double epsilon)
+void rmsnorm2d(
+    torch::Tensor& out,    // [m, n]
+    torch::Tensor& input,  // [m, n]
+    torch::Tensor& weight, // [1, n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
     std::string dtype_str = torchDTypeToStr(dtype);
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = -1;
-    int y_stride = out.stride(0);
-    int yr_stride = -1;
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = -1;
+    int y_stride  = out.stride(0);
+    int yr_stride = -1;
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str, // input precision
-                      dtype_str, // output precision
-                      dtype_str, // x-scale, used for [1*N] input smooth quant
-                      dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      false, // save_unquant
-                      0,     // fused_add
-                      0      // fused_quant
-                  },
+    rmsnorm2d_fwd({dtype_str, // input precision
+                   dtype_str, // output precision
+                   dtype_str, // x-scale, used for [1*N] input smooth quant
+                   dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   false, // save_unquant
+                   0,     // fused_add
+                   0,     // fused_quant
+                   use_model_sensitive_rmsnorm},
                   {input.data_ptr(),
                    nullptr, // p_x_residual
                    nullptr, // p_x_scale
-                   weight.data_ptr(), out.data_ptr(),
+                   weight.data_ptr(),
+                   out.data_ptr(),
                    nullptr, // p_y_residual
                    nullptr, // p_y_scale
                    nullptr, // p_invRms
                    nullptr, // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
 
-torch::Tensor rmsnorm2d(torch::Tensor &input,  // [m, n]
-                        torch::Tensor &weight, // [1, n]
-                        double epsilon)
+torch::Tensor rmsnorm2d(
+    torch::Tensor& input,  // [m, n]
+    torch::Tensor& weight, // [1, n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     torch::Tensor out = torch::empty_like(input);
-    rmsnorm2d(out, input, weight, epsilon);
+    rmsnorm2d(out, input, weight, epsilon, use_model_sensitive_rmsnorm);
 
     return out;
 }
 
-void rmsnorm2d_with_add(torch::Tensor &out,          // [m ,n]
-                        torch::Tensor &input,        // [m ,n]
-                        torch::Tensor &residual_in,  // [m ,n]
-                        torch::Tensor &residual_out, // [m ,n]
-                        torch::Tensor &weight,       // [1 ,n]
-                        double epsilon)
+void rmsnorm2d_with_add(
+    torch::Tensor& out,          // [m ,n]
+    torch::Tensor& input,        // [m ,n]
+    torch::Tensor& residual_in,  // [m ,n]
+    torch::Tensor& residual_out, // [m ,n]
+    torch::Tensor& weight,       // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
     std::string dtype_str = torchDTypeToStr(input.dtype());
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = residual_in.stride(0);
-    int y_stride = out.stride(0);
-    int yr_stride = residual_out.stride(0);
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = residual_in.stride(0);
+    int y_stride  = out.stride(0);
+    int yr_stride = residual_out.stride(0);
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str, // input precision
-                      dtype_str, // output precision
-                      dtype_str, // x-scale, used for [1*N] input smooth quant
-                      dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      false, // save_unquant
-                      1,     // fused_add
-                      0      // fused_quant
-                  },
+    rmsnorm2d_fwd({dtype_str, // input precision
+                   dtype_str, // output precision
+                   dtype_str, // x-scale, used for [1*N] input smooth quant
+                   dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   false, // save_unquant
+                   1,     // fused_add
+                   0,     // fused_quant
+                   use_model_sensitive_rmsnorm},
                   {input.data_ptr(),       // p_x
                    residual_in.data_ptr(), // p_x_residual
                    nullptr,                // p_x_scale
@@ -100,45 +111,52 @@ void rmsnorm2d_with_add(torch::Tensor &out, // [m ,n]
                    nullptr, // p_y_scale
                    nullptr, // p_invRms
                    nullptr, // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
 
-void rmsnorm2d_with_smoothquant(torch::Tensor &out,    // [m ,n]
-                                torch::Tensor &input,  // [m ,n]
-                                torch::Tensor &xscale, // [1 ,n]
-                                torch::Tensor &yscale, // [m ,1]
-                                torch::Tensor &weight, // [1 ,n]
-                                double epsilon)
+void rmsnorm2d_with_smoothquant(
+    torch::Tensor& out,    // [m ,n]
+    torch::Tensor& input,  // [m ,n]
+    torch::Tensor& xscale, // [1 ,n]
+    torch::Tensor& yscale, // [m ,1]
+    torch::Tensor& weight, // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
-    std::string dtype_str = torchDTypeToStr(input.dtype());
-    std::string out_dtype_str = torchDTypeToStr(out.dtype());
+    std::string dtype_str        = torchDTypeToStr(input.dtype());
+    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
     std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype());
     std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());
 
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = -1;
-    int y_stride = out.stride(0);
-    int yr_stride = -1;
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = -1;
+    int y_stride  = out.stride(0);
+    int yr_stride = -1;
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str,        // input precision
-                      out_dtype_str,    // output precision
-                      xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
-                      yscale_dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      false, // save_unquant
-                      0,     // fused_add
-                      1      // fused_quant
-                  },
+    rmsnorm2d_fwd({dtype_str,        // input precision
+                   out_dtype_str,    // output precision
+                   xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
+                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   false, // save_unquant
+                   0,     // fused_add
+                   1,     // fused_quant
+                   use_model_sensitive_rmsnorm},
                   {input.data_ptr(),  // p_x
                    nullptr,           // p_x_residual
                    xscale.data_ptr(), // p_x_scale
@@ -148,94 +166,109 @@ void rmsnorm2d_with_smoothquant(torch::Tensor &out, // [m ,n]
                    yscale.data_ptr(), // p_y_scale
                    nullptr,           // p_invRms
                    nullptr,           // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
 
-void rmsnorm2d_with_add_smoothquant(torch::Tensor &out,          // [m ,n]
-                                    torch::Tensor &input,        // [m ,n]
-                                    torch::Tensor &residual_in,  // [m ,n]
-                                    torch::Tensor &residual_out, // [m ,n]
-                                    torch::Tensor &xscale,       // [1 ,n]
-                                    torch::Tensor &yscale,       // [m ,1]
-                                    torch::Tensor &weight,       // [1 ,n]
-                                    double epsilon,
-                                    std::optional<torch::Tensor> out_before_quant)
+void rmsnorm2d_with_add_smoothquant(
+    torch::Tensor& out,          // [m ,n]
+    torch::Tensor& input,        // [m ,n]
+    torch::Tensor& residual_in,  // [m ,n]
+    torch::Tensor& residual_out, // [m ,n]
+    torch::Tensor& xscale,       // [1 ,n]
+    torch::Tensor& yscale,       // [m ,1]
+    torch::Tensor& weight,       // [1 ,n]
+    double epsilon,
+    std::optional<torch::Tensor> out_before_quant,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
-    std::string dtype_str = torchDTypeToStr(input.dtype());
-    std::string out_dtype_str = torchDTypeToStr(out.dtype());
+    std::string dtype_str        = torchDTypeToStr(input.dtype());
+    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
    std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype());
     std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());
 
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = residual_in.stride(0);
-    int y_stride = out.stride(0);
-    int yr_stride = residual_out.stride(0);
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = residual_in.stride(0);
+    int y_stride  = out.stride(0);
+    int yr_stride = residual_out.stride(0);
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str,        // input precision
-                      out_dtype_str,    // output precision
-                      xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
-                      yscale_dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      out_before_quant.has_value(), // save_unquant
-                      1, // fused_add
-                      1  // fused_quant
-                  },
-                  {input.data_ptr(),        // p_x
-                   residual_in.data_ptr(),  // p_x_residual
-                   xscale.data_ptr(),       // p_x_scale
-                   weight.data_ptr(),       // p_gamma
-                   out.data_ptr(),          // p_y
-                   residual_out.data_ptr(), // p_y_residual
-                   yscale.data_ptr(),       // p_y_scale
-                   nullptr,                 // p_invRms
-                   out_before_quant.has_value() ? out_before_quant.value().data_ptr() : nullptr, // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+    rmsnorm2d_fwd({dtype_str,        // input precision
+                   out_dtype_str,    // output precision
+                   xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
+                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   out_before_quant.has_value(), // save_unquant
+                   1, // fused_add
+                   1, // fused_quant
+                   use_model_sensitive_rmsnorm},
+                  {input.data_ptr(),        // p_x
+                   residual_in.data_ptr(),  // p_x_residual
+                   xscale.data_ptr(),       // p_x_scale
+                   weight.data_ptr(),       // p_gamma
+                   out.data_ptr(),          // p_y
+                   residual_out.data_ptr(), // p_y_residual
+                   yscale.data_ptr(),       // p_y_scale
+                   nullptr,                 // p_invRms
+                   out_before_quant.has_value() ? out_before_quant.value().data_ptr()
+                                                : nullptr, // p_y_unquant
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
 
-void rmsnorm2d_with_dynamicquant(torch::Tensor &out,    // [m ,n]
-                                 torch::Tensor &input,  // [m ,n]
-                                 torch::Tensor &yscale, // [m ,1]
-                                 torch::Tensor &weight, // [1 ,n]
-                                 double epsilon)
+void rmsnorm2d_with_dynamicquant(
+    torch::Tensor& out,    // [m ,n]
+    torch::Tensor& input,  // [m ,n]
+    torch::Tensor& yscale, // [m ,1]
+    torch::Tensor& weight, // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
-    std::string dtype_str = torchDTypeToStr(input.dtype());
-    std::string out_dtype_str = torchDTypeToStr(out.dtype());
+    std::string dtype_str        = torchDTypeToStr(input.dtype());
+    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
     std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());
 
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = -1;
-    int y_stride = out.stride(0);
-    int yr_stride = -1;
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = -1;
+    int y_stride  = out.stride(0);
+    int yr_stride = -1;
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str,        // input precision
-                      out_dtype_str,    // output precision
-                      dtype_str,        // x-scale, used for [1*N] input smooth quant
-                      yscale_dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      false, // save_unquant
-                      0,     // fused_add
-                      2      // fused_quant
-                  },
+    rmsnorm2d_fwd({dtype_str,        // input precision
+                   out_dtype_str,    // output precision
+                   dtype_str,        // x-scale, used for [1*N] input smooth quant
+                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   false, // save_unquant
+                   0,     // fused_add
+                   2,     // fused_quant
+                   use_model_sensitive_rmsnorm},
                   {input.data_ptr(),  // p_x
                    nullptr,           // p_x_residual
                    nullptr,           // p_x_scale
@@ -245,45 +278,52 @@ void rmsnorm2d_with_dynamicquant(torch::Tensor &out, // [m ,n]
                    yscale.data_ptr(), // p_y_scale
                   nullptr,           // p_invRms
                    nullptr,           // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
 
-void rmsnorm2d_with_add_dynamicquant(torch::Tensor &out,          // [m ,n]
-                                     torch::Tensor &input,        // [m ,n]
-                                     torch::Tensor &residual_in,  // [m ,n]
-                                     torch::Tensor &residual_out, // [m ,n]
-                                     torch::Tensor &yscale,       // [m ,1]
-                                     torch::Tensor &weight,       // [1 ,n]
-                                     double epsilon)
+void rmsnorm2d_with_add_dynamicquant(
+    torch::Tensor& out,          // [m ,n]
+    torch::Tensor& input,        // [m ,n]
+    torch::Tensor& residual_in,  // [m ,n]
+    torch::Tensor& residual_out, // [m ,n]
+    torch::Tensor& yscale,       // [m ,1]
+    torch::Tensor& weight,       // [1 ,n]
+    double epsilon,
+    int use_model_sensitive_rmsnorm = 0) // 0: Use default RMSNorm; 1: Use T5-like implementation
 {
     auto dtype = input.dtype();
     TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                 "ck rmsnorm2d only support fp16 and bf16 data type");
 
-    std::string dtype_str = torchDTypeToStr(input.dtype());
-    std::string out_dtype_str = torchDTypeToStr(out.dtype());
+    std::string dtype_str        = torchDTypeToStr(input.dtype());
+    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
     std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());
 
-    int n = input.size(-1);
-    int m = input.numel() / n;
-    int stride = input.stride(0);
-    int xr_stride = residual_in.stride(0);
-    int y_stride = out.stride(0);
-    int yr_stride = residual_out.stride(0);
-    bool SaveRms = false;
+    int n         = input.size(-1);
+    int m         = input.numel() / n;
+    int stride    = input.stride(0);
+    int xr_stride = residual_in.stride(0);
+    int y_stride  = out.stride(0);
+    int yr_stride = residual_out.stride(0);
+    bool SaveRms  = false;
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    rmsnorm2d_fwd({
-                      dtype_str,        // input precision
-                      out_dtype_str,    // output precision
-                      dtype_str,        // x-scale, used for [1*N] input smooth quant
-                      yscale_dtype_str, // y-scale, used for [M*1] output for next layer
-                      SaveRms,
-                      false, // save_unquant
-                      1,     // fused_add
-                      2      // fused_quant
-                  },
+    rmsnorm2d_fwd({dtype_str,        // input precision
+                   out_dtype_str,    // output precision
+                   dtype_str,        // x-scale, used for [1*N] input smooth quant
+                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
+                   SaveRms,
+                   false, // save_unquant
+                   1,     // fused_add
+                   2,     // fused_quant
+                   use_model_sensitive_rmsnorm},
                   {input.data_ptr(),       // p_x
                    residual_in.data_ptr(), // p_x_residual
                    nullptr,                // p_x_scale
@@ -293,6 +333,12 @@ void rmsnorm2d_with_add_dynamicquant(torch::Tensor &out, // [m ,n]
                    yscale.data_ptr(), // p_y_scale
                    nullptr,           // p_invRms
                    nullptr,           // p_y_unquant
-                   static_cast<float>(epsilon), m, n, stride, xr_stride, y_stride, yr_stride},
+                   static_cast<float>(epsilon),
+                   m,
+                   n,
+                   stride,
+                   xr_stride,
+                   y_stride,
+                   yr_stride},
                   {stream});
 }
\ No newline at end of file
diff --git a/op_tests/test_rmsnorm2d.py b/op_tests/test_rmsnorm2d.py
index 9431c7d1ee..57260caf68 100644
--- a/op_tests/test_rmsnorm2d.py
+++ b/op_tests/test_rmsnorm2d.py
@@ -28,14 +28,22 @@ def run_torch(input, weight, eps, residual=None):
 
 
 @perftest()
-def run_ck(input, weight, eps, residual=None):
+def run_ck(input, weight, eps, residual=None, use_model_sensitive_rmsnorm=0):
     if residual is None:
         residual_out = None
-        output = aiter.rms_norm(input, weight, eps)
+        output = aiter.rms_norm(input, weight, eps, use_model_sensitive_rmsnorm)
     else:
         residual_out = torch.empty_like(input)
         output = torch.empty_like(input)
-        aiter.rmsnorm2d_fwd_with_add(output, input, residual, residual_out, weight, eps)
+        aiter.rmsnorm2d_fwd_with_add(
+            output,
+            input,
+            residual,
+            residual_out,
+            weight,
+            eps,
+            use_model_sensitive_rmsnorm,
+        )
     return output, residual_out
 
 
@@ -75,13 +83,19 @@ def test_rmsnorm2d_fuseAdd(dtype, m, n):
     # input = k
     (a, res_a, *_), avg_a = run_torch(input, weight, 1e-5, residual=res)
     (b, res_b, *_), avg_b = run_ck(input, weight, 1e-5, residual=res)
-    (c, res_c, *_), avg_c = run_cu(input, weight, 1e-5, residual=res)
+    (c, res_c, *_), avg_c = run_ck(
+        input, weight, 1e-5, residual=res, use_model_sensitive_rmsnorm=1
+    )
+    (d, res_d, *_), avg_d = run_cu(input, weight, 1e-5, residual=res)
 
-    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, cu avg: {avg_c:<8.2f} us,uplift: {avg_a/avg_b-1:<5.1%}"
+    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, ck(T5) avg: {avg_c:<8.2f} us, cu avg: {avg_d:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
     checkAllclose(a, b, atol=0.03, msg=msg)
-    checkAllclose(res_a, res_b, msg="ck res check")
-    # checkAllclose(a, c, atol=0.03, msg='cu')
-    # checkAllclose(res_a, res_c, atol=0.01, msg='cu res check')
+    checkAllclose(res_a, res_b, msg="ck res check (NO_SPECIFIC_MODEL)")
+
+    checkAllclose(a, c, atol=0.03, msg=msg)
+    checkAllclose(res_a, res_c, msg="ck res check (T5_MODEL_LIKE)")
+    # checkAllclose(a, d, atol=0.03, msg='cu')
+    # checkAllclose(res_a, res_d, atol=0.01, msg='cu res check')
 
 
 # for dtype in [dtypes.fp16, dtypes.bf16]:
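Reviewer note: below is a minimal, hypothetical usage sketch of the new `use_model_sensitive_rmsnorm` flag from the Python side, assuming an `aiter` build that includes this patch. The call signatures mirror `aiter/ops/rmsnorm.py` and `op_tests/test_rmsnorm2d.py` above; the tensor shapes, dtype, and epsilon are illustrative only and are not taken from the patch.

```python
import torch
import aiter

# Illustrative [m, n] activations and [n] weight; the CK kernels accept only fp16/bf16.
x = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
w = torch.randn(8192, dtype=torch.bfloat16, device="cuda")
res = torch.randn_like(x)

# 0 selects the default RMSNorm path, 1 the T5-like implementation.
y_default = aiter.rms_norm(x, w, 1e-5, 0)
y_t5 = aiter.rms_norm(x, w, 1e-5, 1)

# Fused residual-add variant writes into preallocated outputs,
# mirroring run_ck() in op_tests/test_rmsnorm2d.py.
out = torch.empty_like(x)
res_out = torch.empty_like(x)
aiter.rmsnorm2d_fwd_with_add(out, x, res, res_out, w, 1e-5, 1)
```

Because the pybind registrations default `use_model_sensitive_rmsnorm` to 0, existing C++ and `rmsnorm2d_fwd*` callers keep their current behavior; the Python `rms_norm` wrapper in this patch takes the flag as an explicit positional argument, as the updated test shows.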