Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions aiter/ops/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def rms_norm(
input: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
):
"""
CK version of rmsnorm
Expand All @@ -48,7 +49,10 @@ def rms_norm(

@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd(
input: torch.Tensor, weight: torch.Tensor, epsilon: float
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
) -> torch.Tensor: ...


Expand All @@ -60,6 +64,7 @@ def rmsnorm2d_fwd_with_add(
residual_out: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
): ...


Expand All @@ -71,6 +76,7 @@ def rmsnorm2d_fwd_with_smoothquant(
yscale: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
): ...


Expand All @@ -84,12 +90,18 @@ def rmsnorm2d_fwd_with_add_smoothquant(
yscale: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
): ...


@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_dynamicquant(
out: Tensor, input: Tensor, yscale: Tensor, weight: Tensor, epsilon: float
out: Tensor,
input: Tensor,
yscale: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
): ...


Expand All @@ -102,4 +114,5 @@ def rmsnorm2d_fwd_with_add_dynamicquant(
yscale: Tensor,
weight: Tensor,
epsilon: float,
use_model_sensitive_rmsnorm: int,
): ...
92 changes: 53 additions & 39 deletions csrc/include/rmsnorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,50 +17,64 @@
*/
#include <torch/extension.h>

void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight,
double epsilon);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon);

void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
torch::Tensor &weight, double epsilon);
void fused_add_rms_norm(torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
double epsilon);

// ck
torch::Tensor rmsnorm2d(torch::Tensor &input, torch::Tensor &weight,
double epsilon);
torch::Tensor
rmsnorm2d(torch::Tensor& input,
torch::Tensor& weight,
double epsilon,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation

void rmsnorm2d_with_add(torch::Tensor &out, // [m ,n]
torch::Tensor &input, // [m ,n]
torch::Tensor &residual_in, // [m ,n]
torch::Tensor &residual_out, // [m ,n]
torch::Tensor &weight, // [1 ,n]
double epsilon);
void rmsnorm2d_with_add(
torch::Tensor& out, // [m ,n]
torch::Tensor& input, // [m ,n]
torch::Tensor& residual_in, // [m ,n]
torch::Tensor& residual_out, // [m ,n]
torch::Tensor& weight, // [1 ,n]
double epsilon,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation

void rmsnorm2d_with_smoothquant(torch::Tensor &out, // [m ,n]
torch::Tensor &input, // [m ,n]
torch::Tensor &xscale, // [1 ,n]
torch::Tensor &yscale, // [m ,1]
torch::Tensor &weight, // [1 ,n]
double epsilon);
void rmsnorm2d_with_smoothquant(
torch::Tensor& out, // [m ,n]
torch::Tensor& input, // [m ,n]
torch::Tensor& xscale, // [1 ,n]
torch::Tensor& yscale, // [m ,1]
torch::Tensor& weight, // [1 ,n]
double epsilon,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation

void rmsnorm2d_with_add_smoothquant(torch::Tensor &out, // [m ,n]
torch::Tensor &input, // [m ,n]
torch::Tensor &residual_in, // [m ,n]
torch::Tensor &residual_out, // [m ,n]
torch::Tensor &xscale, // [1 ,n]
torch::Tensor &yscale, // [m ,1]
torch::Tensor &weight, // [1 ,n]
double epsilon,
std::optional<torch::Tensor> out_before_quant);
void rmsnorm2d_with_add_smoothquant(
torch::Tensor& out, // [m ,n]
torch::Tensor& input, // [m ,n]
torch::Tensor& residual_in, // [m ,n]
torch::Tensor& residual_out, // [m ,n]
torch::Tensor& xscale, // [1 ,n]
torch::Tensor& yscale, // [m ,1]
torch::Tensor& weight, // [1 ,n]
double epsilon,
std::optional<torch::Tensor> out_before_quant,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation

void rmsnorm2d_with_dynamicquant(torch::Tensor &out, // [m ,n]
torch::Tensor &input, // [m ,n]
torch::Tensor &yscale, // [m ,1]
torch::Tensor &weight, // [1 ,n]
double epsilon);
void rmsnorm2d_with_dynamicquant(
torch::Tensor& out, // [m ,n]
torch::Tensor& input, // [m ,n]
torch::Tensor& yscale, // [m ,1]
torch::Tensor& weight, // [1 ,n]
double epsilon,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation

void rmsnorm2d_with_add_dynamicquant(torch::Tensor &out, // [m ,n]
torch::Tensor &input, // [m ,n]
torch::Tensor &residual_in, // [m ,n]
torch::Tensor &residual_out, // [m ,n]
torch::Tensor &yscale, // [m ,1]
torch::Tensor &weight, // [1 ,n]
double epsilon);
void rmsnorm2d_with_add_dynamicquant(
torch::Tensor& out, // [m ,n]
torch::Tensor& input, // [m ,n]
torch::Tensor& residual_in, // [m ,n]
torch::Tensor& residual_out, // [m ,n]
torch::Tensor& yscale, // [m ,1]
torch::Tensor& weight, // [1 ,n]
double epsilon,
int use_model_sensitive_rmsnorm); // 0: Use default RMSNorm; 1: Use T5-like implementation
19 changes: 13 additions & 6 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
m.def("mul_", &aiter_mul_, "apply for mul_ with transpose and broadcast."); \
m.def("sub_", &aiter_sub_, "apply for sub_ with transpose and broadcast."); \
m.def("div_", &aiter_div_, "apply for div_ with transpose and broadcast.");
#define AITER_UNARY_PYBIND \
m.def("sigmoid", &aiter_sigmoid, "apply for sigmoid."); \
#define AITER_UNARY_PYBIND \
m.def("sigmoid", &aiter_sigmoid, "apply for sigmoid."); \
m.def("tanh", &aiter_tanh, "apply for tanh.");

#define ATTENTION_ASM_MLA_PYBIND \
Expand Down Expand Up @@ -325,7 +325,7 @@
py::arg("x_scale"), \
py::arg("w_scale"), \
py::arg("Out"), \
py::arg("splitK") = 0);
py::arg("splitK") = 0);

#define GEMM_A8W8_BLOCKSCALE_PYBIND \
m.def("gemm_a8w8_blockscale", \
Expand Down Expand Up @@ -853,15 +853,21 @@
"Apply Root Mean Square (RMS) Normalization to the input tensor."); \
m.def( \
"fused_add_rms_norm_cu", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); \
m.def("rmsnorm2d_fwd", &rmsnorm2d, py::arg("input"), py::arg("weight"), py::arg("epsilon")); \
m.def("rmsnorm2d_fwd", \
&rmsnorm2d, \
py::arg("input"), \
py::arg("weight"), \
py::arg("epsilon"), \
py::arg("use_model_sensitive_rmsnorm") = 0); \
m.def("rmsnorm2d_fwd_with_add", \
&rmsnorm2d_with_add, \
py::arg("out"), \
py::arg("input"), \
py::arg("residual_in"), \
py::arg("residual_out"), \
py::arg("weight"), \
py::arg("epsilon")); \
py::arg("epsilon"), \
py::arg("use_model_sensitive_rmsnorm") = 0); \
m.def("rmsnorm2d_fwd_with_smoothquant", &rmsnorm2d_with_smoothquant); \
m.def("rmsnorm2d_fwd_with_add_smoothquant", \
&rmsnorm2d_with_add_smoothquant, \
Expand All @@ -873,7 +879,8 @@
py::arg("yscale"), \
py::arg("weight"), \
py::arg("epsilon"), \
py::arg("out_before_quant") = std::nullopt); \
py::arg("out_before_quant") = std::nullopt, \
py::arg("use_model_sensitive_rmsnorm") = 0); \
m.def("rmsnorm2d_fwd_with_dynamicquant", &rmsnorm2d_with_dynamicquant); \
m.def("rmsnorm2d_fwd_with_add_dynamicquant", &rmsnorm2d_with_add_dynamicquant);

Expand Down
Loading