Add kernel for GeGLU with approximate GELU (#3337)
WoosukKwon authored Mar 13, 2024
1 parent 49a3c86 commit 602358f
Showing 5 changed files with 49 additions and 7 deletions.
22 changes: 21 additions & 1 deletion csrc/activation_kernels.cu
@@ -33,12 +33,25 @@ template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x;
constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

template<typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
}

} // namespace vllm

// Launch activation and gating kernel.
@@ -73,6 +86,13 @@ void gelu_and_mul(
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

void gelu_tanh_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}

namespace vllm {

// Element-wise activation kernel template.
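For reference, here is a minimal PyTorch sketch of the math the two kernels above implement (exact erf-based GELU vs. the tanh approximation), checked against `torch.nn.functional.gelu`. This is illustrative only, not the CUDA code itself, and the tolerance is an assumption:

```python
import math

import torch
import torch.nn.functional as F


def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Mirrors gelu_kernel: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + torch.erf(x * math.sqrt(0.5)))


def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # Mirrors gelu_tanh_kernel:
    #   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    # where BETA = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi).
    beta = math.sqrt(2.0 / math.pi)
    kappa = 0.044715
    return 0.5 * x * (1.0 + torch.tanh(beta * (x + kappa * x.pow(3))))


x = torch.randn(1024, dtype=torch.float64)
assert torch.allclose(gelu_exact(x), F.gelu(x, approximate="none"), atol=1e-6)
assert torch.allclose(gelu_tanh(x), F.gelu(x, approximate="tanh"), atol=1e-6)
```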
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -61,6 +61,10 @@ void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

void gelu_tanh_and_mul(
torch::Tensor& out,
torch::Tensor& input);

void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
6 changes: 5 additions & 1 deletion csrc/pybind.cpp
@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
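Once the extension is compiled, the new binding can be exercised directly on CUDA tensors. A rough sketch, assuming the compiled ops are importable as `vllm._C` (the exact module path is an assumption here) and that a CUDA device is available:

```python
import torch

from vllm._C import ops  # assumed import path for the compiled extension

d = 128
x = torch.randn(16, 2 * d, dtype=torch.float16, device="cuda")  # [..., 2 * d]
out = torch.empty(16, d, dtype=x.dtype, device=x.device)        # [..., d]

# GeGLU with tanh-approximated GELU: out = gelu_tanh(x[..., :d]) * x[..., d:]
ops.gelu_tanh_and_mul(out, x)
```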
11 changes: 8 additions & 3 deletions tests/kernels/test_activation.py
@@ -16,15 +16,15 @@
]


@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
activation: Type[torch.nn.Module],
activation: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
@@ -36,7 +36,12 @@ def test_act_and_mul(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = activation()
if activation == "silu":
layer = SiluAndMul()
elif activation == "gelu":
layer = GeluAndMul(approximate="none")
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
out = layer(x)
ref_out = layer._forward(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/activation.py
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
return: (batch_size, seq_len, d) or (num_tokens, d)
"""

def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")

def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d]) * x[..., d:]
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out


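At the layer level, callers can now pick the approximation when constructing GeluAndMul. A small usage sketch (shapes follow the docstring above; the CUDA device and tolerances are assumptions, and `_forward` is the PyTorch-native reference also used by the kernel test):

```python
import torch

from vllm.model_executor.layers.activation import GeluAndMul

layer = GeluAndMul(approximate="tanh")
x = torch.randn(8, 2 * 256, dtype=torch.float16, device="cuda")  # [num_tokens, 2 * d]

out = layer(x)           # custom CUDA op: gelu_tanh_and_mul
ref = layer._forward(x)  # PyTorch-native reference
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
```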
