3 changes: 2 additions & 1 deletion tests/models/language/generation/test_common.py
@@ -92,7 +92,8 @@
pytest.param(
"allenai/OLMoE-1B-7B-0924-Instruct",
marks=[pytest.mark.cpu_model],
)
),
pytest.param("swiss-ai/Apertus-8B"), # apertus
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
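For context, these pytest.param entries feed the parametrized generation test in this file. The sketch below shows the general shape of how they are consumed; the test name and body are illustrative placeholders, not the actual code in test_common.py.

import pytest

@pytest.mark.parametrize("model", [
    pytest.param("allenai/OLMoE-1B-7B-0924-Instruct",
                 marks=[pytest.mark.cpu_model]),
    pytest.param("swiss-ai/Apertus-8B"),  # apertus
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_sketch(model: str, max_tokens: int, num_logprobs: int) -> None:
    # Placeholder body: the real test compares vLLM generations (and their
    # top-num_logprobs logprobs) against the HF transformers reference.
    ...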
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -137,6 +137,9 @@ def check_available_online(
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B",
min_transformers_version="4.56.0",
trust_remote_code=True),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
3 changes: 3 additions & 0 deletions tests/models/test_registry.py
@@ -24,6 +24,9 @@

@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Skip if transformers version is incompatible
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")
# Ensure all model classes can be imported successfully
model_cls = ModelRegistry._try_load_model_cls(model_arch)
assert model_cls is not None
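The min_transformers_version / check_transformers_version(on_fail="skip") pairing above keeps the registry import test green on older transformers releases. A rough sketch of that gating logic, under the assumption that it reduces to a simple version comparison (this is not the actual registry.py implementation):

import pytest
import transformers
from packaging.version import Version

def check_transformers_version_sketch(min_transformers_version: str,
                                      on_fail: str = "skip") -> None:
    # Skip (or fail) when the installed transformers is older than what the
    # example model needs, e.g. 4.56.0 for ApertusForCausalLM above.
    installed = Version(transformers.__version__)
    if installed < Version(min_transformers_version):
        msg = (f"transformers>={min_transformers_version} is required, "
               f"but {installed} is installed")
        if on_fail == "skip":
            pytest.skip(msg)
        raise RuntimeError(msg)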
111 changes: 111 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -10,11 +10,14 @@

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict

logger = init_logger(__name__)


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
@@ -363,6 +366,112 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)


@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""

def __init__(
self,
alpha_p_init: float = 0.8,
alpha_n_init: float = 0.8,
beta: float = 0.5,
eps: float = -1e-6,
dtype: torch.dtype = torch.bfloat16,
with_vector_loads: bool = False,
):
super().__init__()
self.alpha_p = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
1).unsqueeze(0))
self.alpha_n = nn.Parameter(
torch.log(
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
1).unsqueeze(0))
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())

self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401

self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph

self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
"this may result in slower performance.")
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) –"
" falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)

def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
self.beta * x,
)

def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert self._xielu_cuda_obj is not None, (
"XIELU CUDA object must not be None")
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
            logger.warning_once(
                "xIELU CUDA kernel expects a 3D input but got shape %s;"
                " reshaping to %s.",
                original_shape,
                x.shape,
            )
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once(
"torch._dynamo is compiling, using Python version of xIELU."
)
return self._xielu_python(input)


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.

@@ -426,6 +535,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
"xielu":
lambda: XIELU(),
})
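To sanity-check the math without vLLM or the CUDA extension, the following self-contained sketch mirrors the _xielu_python path above; the function name and the inline re-derivation of the softplus-parameterized alphas are illustrative, not part of this PR.

import torch
import torch.nn.functional as F

def xielu_reference(x: torch.Tensor,
                    alpha_p_init: float = 0.8,
                    alpha_n_init: float = 0.8,
                    beta: float = 0.5,
                    eps: float = -1e-6) -> torch.Tensor:
    # Same parameterization as XIELU.__init__: the learnable parameters store
    # softplus-inverses of the initial alphas, so softplus recovers them here.
    alpha_p = F.softplus(torch.log(torch.expm1(torch.tensor(alpha_p_init))))
    alpha_n = beta + F.softplus(
        torch.log(torch.expm1(torch.tensor(alpha_n_init - beta))))
    eps_t = torch.tensor(eps, dtype=x.dtype)
    # Piecewise form: alpha_p*x^2 + beta*x for x > 0,
    # alpha_n*(expm1(min(x, eps)) - x) + beta*x otherwise.
    return torch.where(
        x > 0,
        alpha_p * x * x + beta * x,
        (torch.expm1(torch.minimum(x, eps_t)) - x) * alpha_n + beta * x,
    )

x = torch.randn(4, 8)
assert xielu_reference(x).shape == x.shape

On a CUDA device with the nickjbrowning/XIELU extension installed, XIELU.forward dispatches to the fused kernel instead; the Python path is used on CPU or while torch._dynamo is compiling.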

