
Commit a73e34a

address comments
Signed-off-by: Han Qi <hanq@google.com>
1 parent: c0bd708

File tree

2 files changed (+16, -137 lines):

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py


tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 3 additions & 7 deletions
@@ -18,7 +18,7 @@
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -75,10 +75,6 @@ def get_scheme(self,
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        # act_quant_format = (
-        #     is_activation_quantization_format( # noqa: F841
-        #         format) if format is not None else
-        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -115,8 +111,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
-                                                     self.mesh)
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
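For readers skimming the diff, here is a small usage sketch of the updated FusedMoE branch. It is assembled from the hunk above and is not part of the commit; `config`, `layer`, and `mesh` are illustrative placeholders for the compressed-tensors config instance, the module being wrapped, and the JAX device mesh that the real `get_quant_method` already holds.

```python
# Hypothetical sketch (not part of this commit): how a FusedMoE layer now resolves
# to the renamed TPU MoE method.
from vllm.model_executor.layers.fused_moe import FusedMoE

from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
    VllmCompressedTensorsW8A8Fp8MoEMethod


def moe_quant_method(config, layer, mesh):
    # Mirrors the FusedMoE branch in the hunk above: the TPU wrapper is built from
    # the layer's quant config plus the device mesh.
    if isinstance(layer, FusedMoE):
        return VllmCompressedTensorsW8A8Fp8MoEMethod(config, layer.quant_config, mesh)
    return None
```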

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 13 additions & 130 deletions
@@ -12,27 +12,30 @@
 from torchax.interop import call_jax, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
-                                                  FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
-from vllm.model_executor.layers.quantization.compressed_tensors import \
-    compressed_tensors_moe as vllm_ct_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
     CompressedTensorsConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
-from vllm.model_executor.utils import set_weight_attrs
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 
 logger = init_logger(__name__)
 
 
-class CompressedTensorsW8A8Fp8MoEMethod(vllm_ct_moe.CompressedTensorsMoEMethod
-                                        ):
+class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                            JaxCommonConfig):
 
     def __init__(self, quant_config: "CompressedTensorsConfig",
                  moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(moe)
+        super().__init__(quant_config, moe)
+
+        self.use_marlin = False
+        self.use_cutlass = False
+        self.is_fp8_w8a8_sm100 = False
+
         self.mesh = mesh
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -226,123 +229,3 @@ def apply(
             preferred_element_type=jnp.bfloat16.dtype)
 
         return out
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-
-        params_dtype = torch.float8_e4m3fn
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They are combined to a single scale after weight loading.
-            w13_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, 2, dtype=torch.float32),
-                                                  requires_grad=False)
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    2 * intermediate_size_per_partition,
-                    1,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
-            raise AssertionError('Blockwise quant for MoE not supported yet')
-
-        # INPUT_SCALES
-        if self.static_input_scales:
-            w13_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def get_fused_moe_quant_config(
-            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
-            return None
-
-        per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
-        per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-
-        return fp8_w8a8_moe_quant_config(
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            per_act_token_quant=per_act_token,
-            per_out_ch_quant=per_channel_quant,
-            block_shape=layer.weight_block_size,
-        )
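The large deletion above is the substance of the change: because the class now subclasses vLLM's upstream `CompressedTensorsW8A8Fp8MoEMethod`, the local copies of `create_weights` and `get_fused_moe_quant_config` are picked up from the parent instead of being maintained in tpu_inference. Below is a consolidated sketch of the new constructor, assembled only from the lines shown in the hunks above; the rest of the class body is unchanged and omitted.

```python
# Sketch assembled from the diff above, not the full class. GPU-only backends are
# pinned off so the inherited code paths (e.g. the use_marlin early-return that the
# deleted local get_fused_moe_quant_config also carried) stay on the plain fp8 w8a8
# route on TPU.
class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                            JaxCommonConfig):

    def __init__(self, quant_config: "CompressedTensorsConfig",
                 moe: FusedMoEConfig, mesh: Mesh):
        super().__init__(quant_config, moe)

        # No Marlin, CUTLASS, or SM100-specific fp8 kernels on TPU.
        self.use_marlin = False
        self.use_cutlass = False
        self.is_fp8_w8a8_sm100 = False

        self.mesh = mesh
        self.quant_config = quant_config
        # ... remaining assignments continue as in the original __init__.
```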
