Skip to content

Commit 622f1c4

Browse files
author
maleksan85
committed
Small improvement for linear
1 parent 474014a commit 622f1c4

File tree

2 files changed: +4 additions, −2 deletions

vllm/model_executor/layers/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.tuned_gemm import tgemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
@@ -137,7 +138,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return torch.mm(x, torch.transpose(layer.weight, 0, 1))
+        return tgemm.mm(x, layer.weight, bias)


 class LinearBase(torch.nn.Module):

vllm/model_executor/layers/tuned_gemm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
 import torch.nn.functional as F

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_mi250, is_navi

 support_tuned_gemms = False
-if current_platform.is_rocm():
+if current_platform.is_rocm() and not envs.VLLM_USE_V1:
     import vllm._gradlib_C  # noqa: F401
     support_tuned_gemms = True

Comments (0)