From 51ce9f53e9cef0b84b77715ff1ae5708979a9fb5 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Thu, 21 Mar 2024 15:39:24 +0000
Subject: [PATCH] add use case for custom kernel for matvec operation

---
 vllm/model_executor/layers/linear.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 40e681df48f86..77e03aba573ad 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -13,6 +13,7 @@
                                                divide, split_tensor_along_last_dim)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.logger import init_logger
+from vllm import custom_ops
 
 logger = init_logger(__name__)
 
@@ -72,6 +73,24 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = weights["weight"]
+        # Fast path for mat-vec (GEMV): a single-row input can be handled by
+        # the custom LLMM1 kernel for the problem shapes it supports.
+        if x.shape[0] == 1:
+            m, n, k = weight.shape[0], x.shape[0], x.shape[1]
+            # Allocate the output on the same device as the input; LLMM1
+            # writes the result into `out` in place.
+            out = torch.empty(n, m, dtype=x.dtype, device=x.device)
+            if k == 8192 and m in (1280, 7168):
+                custom_ops.LLMM1(weight, x, out, 8)
+            elif k == 3584 and m == 8192:
+                custom_ops.LLMM1(weight, x, out, 8)
+            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+                custom_ops.LLMM1(weight, x, out, 4)
+            else:
+                # Shape not covered by the custom kernel: fall back to GEMM.
+                out = F.linear(x, weight)
+            if bias is not None:
+                out = out + bias
+            return out
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias