Skip to content

Commit

Permalink
Change to use weight_loader_v2 ModelWeightParameter and PerTensorScal…
Browse files Browse the repository at this point in the history
…eParameter
  • Loading branch information
pavanimajety committed Sep 6, 2024
1 parent 6b9ac06 commit 877fa68
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 64 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod"
"TPUInt8LinearMethod", "ModelOptFp8LinearMethod"
]


Expand Down
101 changes: 38 additions & 63 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)

logger = init_logger(__name__)

Expand All @@ -26,16 +25,11 @@ class ModelOptFp8Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "static",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")

@classmethod
def get_name(cls) -> str:
Expand All @@ -47,7 +41,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:

@classmethod
def get_min_capability(cls) -> int:
return 80
return 89

@classmethod
def get_config_filenames(cls) -> List[str]:
Expand All @@ -58,9 +52,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
activation_scheme = "static"
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
if not is_checkpoint_fp8_serialized:
raise ValueError("ModelOpt currently only supports static FP8"
"quantization in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
return cls(is_checkpoint_fp8_serialized)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
Expand Down Expand Up @@ -112,66 +109,44 @@ def create_weights(
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

layer.process_after_load = True
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype,
),
requires_grad=False,
)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(
weight,
{
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
},
)

# If checkpoint is fp8, load them.
# Otherwise, wait until process_weights_after_loading.

if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **extra_weight_attrs)
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **extra_weight_attrs)
layer.register_parameter("input_scale", input_scale)
# INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)

scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)

def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp/bf16 and not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.input_scale = None
return

else:
# weight_dtype = torch.float8_e4m3fn
weight_scale = layer.weight_scale.to(torch.float32)
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)

def apply(
self,
Expand Down

0 comments on commit 877fa68

Please sign in to comment.