@@ -8,6 +8,7 @@
 
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym,
     kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -100,17 +101,29 @@ def __call__(
 
 class MatcherQuant:
 
-    def __init__(self, quant_key: QuantKey):
+    def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None):
+
         self.quant_key = quant_key
         assert quant_key in QUANT_OPS, \
             f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
 
-    def forward(
+        assert quant_key.scale2 is None
+        self.quant_fp8 = QuantFP8(quant_key.scale.static,
+                                  quant_key.scale.group_shape)
+
+        if enabled is None:
+            # TODO either pass config to enabled or set it globally
+            # (global during pass init seems reasonable)
+            enabled = self.quant_fp8.enabled()
+
+        self.forward = self.forward_custom if enabled else self.forward_native
+
+    def forward_custom(
         self,
         input: torch.Tensor,
         scale: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # TODO: why does empty_like produce a permute but
         # empty via shape doesn't?
         result = torch.empty(input.shape,
@@ -123,7 +136,7 @@ def forward(
                 result=result,
                 input=input,
                 scale=scale)
-            return result
+            return result, scale
         else:
             assert scale is None
             scale = self.make_scale(input)
@@ -134,6 +147,13 @@ def forward(
                 scale_ub=None)
             return result, scale
 
+    def forward_native(
+        self,
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.quant_fp8(input, scale)
+
     def make_scale(self, input: torch.Tensor):
         normalized_group_shape = _normalize_quant_group_shape(
             input, self.quant_key.scale.group_shape)
@@ -146,9 +166,8 @@ def make_scale(self, input: torch.Tensor):
                            device=input.device,
                            dtype=torch.float32)
 
-    def __call__(
-        self,
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def __call__(self,
+                 input: torch.Tensor,
+                 scale: Optional[torch.Tensor] = None
+                 ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.forward(input, scale)
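
A quick usage sketch of the new dispatch (not part of the commit): `MatcherQuant` now always returns a `(quantized, scale)` tuple, and `enabled` selects between the custom-op path (`forward_custom`) and the torch-native `QuantFP8` fallback (`forward_native`). The import path for `MatcherQuant`, the example shapes, and the CUDA/FP8 device assumption are all illustrative.

```python
# Hedged sketch: exercises the enabled-based dispatch added in __init__.
# Assumes MatcherQuant is in scope (its module is not shown in this diff).
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8DynamicTokenSym)

# enabled=None: resolved via QuantFP8.enabled() for the current config.
matcher = MatcherQuant(kFp8DynamicTokenSym)

# enabled=False: force the torch-native QuantFP8 fallback.
native = MatcherQuant(kFp8DynamicTokenSym, enabled=False)

x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")

# Both paths now share the same contract: a (quantized, scale) tuple.
# With a dynamic per-token key, scale is computed rather than passed in;
# per make_scale's math it should have one entry per token, e.g. (16, 1).
q, s = matcher(x)
q2, s2 = native(x)
```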