
Commit a0462b4

Run pre-commit checks
Signed-off-by: Han Qi <hanq@google.com>
1 parent bf881f6 commit a0462b4

File tree

5 files changed: +188, -333 lines


tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 11 additions & 11 deletions
@@ -2,19 +2,19 @@
 
 from jax.sharding import Mesh
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
 
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import (
-    VllmCompressedTensorsConfig,
-)  # noqa: E501
-from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
 
 
-def get_tpu_quantization_config(
-    vllm_config: VllmConfig, mesh: Mesh
-) -> QuantizationConfig:
+def get_tpu_quantization_config(vllm_config: VllmConfig,
+                                mesh: Mesh) -> QuantizationConfig:
     model_config = copy.deepcopy(vllm_config.model_config)
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
@@ -29,11 +29,11 @@ def get_tpu_quantization_config(
     # breakpoint()
     if model_config.quantization not in method_to_config:
         raise NotImplementedError(
-            f"{model_config.quantization} quantization method not supported."
-        )
+            f"{model_config.quantization} quantization method not supported.")
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
 
     model_config.quantization = quant_config.get_name()
-    return VllmConfig.get_quantization_config(model_config, vllm_config.load_config)
+    return VllmConfig.get_quantization_config(model_config,
+                                              vllm_config.load_config)
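
For context, a minimal usage sketch of the reformatted helper above. This is not part of the commit; the 1-D mesh layout and the `build_quant_config` wrapper name are assumptions for illustration only.

# Hypothetical usage sketch (not from this commit): calling the helper from a
# caller that already has a vLLM config. The mesh construction below is an
# illustrative assumption; real callers may shard differently.
import jax
from jax.sharding import Mesh
from vllm.config import VllmConfig

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config


def build_quant_config(vllm_config: VllmConfig):
    # Lay out every local accelerator on a single "model" axis.
    mesh = Mesh(jax.devices(), axis_names=("model", ))
    # Returns the QuantizationConfig subclass selected from
    # model_config.quantization (e.g. the jax-compressed-tensors config).
    return get_tpu_quantization_config(vllm_config, mesh)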

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 36 additions & 45 deletions
@@ -6,45 +6,41 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization import register_quantization_config
-from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-    CompressedTensorsKVCacheMethod,
-    CompressedTensorsLinearMethod,
-    CompressedTensorsScheme,
-)
-from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import (
-    CompressedTensorsW8A8Fp8MoEMethod,
-)
+    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target,
-    is_activation_quantization_format,
-    should_ignore_layer,
-)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
-    VllmCompressedTensorsW8A8Fp8,
-)
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import (
-    VllmCompressedTensorsW8A8Int8,
-)
-from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
     @classmethod
     def get_name(cls) -> str:
         return "jax-compressed-tensors"
 
-    def get_scheme(
-        self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional["CompressedTensorsScheme"]:
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
         """
         compressed-tensors supports non uniform in the following way:
 
@@ -71,24 +67,18 @@ def get_scheme(
             scheme_dict = self.target_scheme_map[matched_target]
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
-            format = scheme_dict.get("format")
 
         if weight_quant is None:
-            logger.warning_once(
-                "Acceleration for non-quantized schemes is "
-                "not supported by Compressed Tensors. "
-                "Falling back to UnquantizedLinearMethod"
-            )
+            logger.warning_once("Acceleration for non-quantized schemes is "
+                                "not supported by Compressed Tensors. "
+                                "Falling back to UnquantizedLinearMethod")
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = (
-            is_activation_quantization_format(  # noqa: F841
-                format
-            )
-            if format is not None
-            else is_activation_quantization_format(self.quant_format)
-        )
+        # act_quant_format = (
+        #     is_activation_quantization_format(  # noqa: F841
+        #         format) if format is not None else
+        #     is_activation_quantization_format(self.quant_format))
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -105,28 +95,29 @@ def get_scheme(
                 input_symmetric=input_quant.symmetric,
                 jax_config=linear_config,
             )
-        raise NotImplementedError("No compressed-tensors compatible scheme was found.")
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
 
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
-        ):
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
             if scheme is None:
-                return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+                return VllmUnquantizedConfig.get_quant_method(
+                    self, layer, prefix)
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
             print("HERE", layer)
-            return CompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh
-            )
+            return CompressedTensorsW8A8Fp8MoEMethod(self, layer.quant_config,
+                                                     self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
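
A small sketch of how the per-layer dispatch above is exercised; not part of the commit. `config` is assumed to be a VllmCompressedTensorsConfig that has already had set_configs(vllm_config, mesh) applied, and the helper name is illustrative.

# Hypothetical sketch (not from this commit): inspecting which quantize method
# VllmCompressedTensorsConfig.get_quant_method assigns to a given layer.
from typing import Optional

import torch
from vllm.model_executor.layers.quantization.base_config import \
    QuantizeMethodBase


def pick_quant_method(config, layer: torch.nn.Module,
                      prefix: str) -> Optional[QuantizeMethodBase]:
    # Per the dispatch above: LinearBase layers get
    # CompressedTensorsLinearMethod (or the unquantized fallback when no
    # scheme matches), FusedMoE layers get CompressedTensorsW8A8Fp8MoEMethod,
    # Attention layers get the KV-cache method, and anything else gets None.
    return config.get_quant_method(layer, prefix)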
