Fp8 e4m3_fnuz support for rocm #2588

Open · wants to merge 1 commit into base: main
185 changes: 162 additions & 23 deletions server/text_generation_server/layers/fp8.py
@@ -1,7 +1,7 @@
import torch

from dataclasses import dataclass
from typing import Optional, Union, List
from typing import Optional, Tuple, Union, List
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
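
For context, a minimal sketch (not part of the diff; it assumes a PyTorch build where both float8 dtypes and their casts are available) verifying that the bit remap plus the doubled scale preserves dequantized values — the same bit pattern decodes to half the value in e4m3fnuz because its exponent bias is 8 rather than 7:

import torch

# Dequantize a tensor before and after the e4m3fn -> e4m3fnuz normalization;
# the results should match because value_fnuz(bits) == value_fn(bits) / 2 and
# the scale is doubled. Illustrative check only.
w_fn = torch.randn(4, 4).to(torch.float8_e4m3fn)
scale = torch.tensor(0.5)

ref = w_fn.float() * scale  # reference, computed before the in-place bit remap

w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn, scale)
out = w_fnuz.float() * scale_fnuz

torch.testing.assert_close(ref, out)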


def fp8_quantize(
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
@@ -62,8 +86,11 @@ def fp8_quantize(

# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

if scale is None:
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()

if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
Member:
We should wire up scale at some point for CUDA as well.


return qweight, scale
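
The new scale argument lets callers pass a precomputed quantization scale in multiplier form (the reciprocal of the dequantization scale, which is how Fp8Linear.forward passes self.input_scale) instead of deriving it from the tensor's absmax. A rough usage sketch with illustrative names, not taken from the diff:

import torch

x = torch.randn(8, 16, dtype=torch.float16)

# Dynamic per-tensor quantization: the scale is derived from x.abs().max().
qx_dyn, dequant_scale_dyn = fp8_quantize(x, scalar=True)

# Static quantization: reuse a precomputed multiplier-form scale (e.g. the
# reciprocal of a checkpoint's input_scale) so the absmax pass is skipped.
static_scale = torch.tensor(48.0)  # hypothetical precomputed value
qx_static, dequant_scale_static = fp8_quantize(x, scale=static_scale, scalar=True)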


@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)
try:
input_scale = weights.get_tensor(
Member:
Weights also has _has_tensor; maybe we should make it public and use it here?

Member:
Same for the try: [...] get_tensor blocks below.

f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
except Exception:
input_scale = None
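
A possible shape of the reviewer's suggestion above (hypothetical: it assumes Weights._has_tensor were exposed publicly, e.g. as has_tensor) instead of the try/except:

# Hypothetical replacement for the try/except, assuming a public has_tensor:
if weights.has_tensor(f"{prefix}.input_scale"):
    input_scale = weights.get_tensor(
        f"{prefix}.input_scale", to_dtype=False
    ).reshape(-1)
else:
    input_scale = None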

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -124,10 +163,25 @@ def get_weights_col_packed(
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
try:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
except Exception:
input_scale = None

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -153,10 +207,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
try:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
]
input_scale = torch.cat(input_scale, dim=0).reshape(-1).max()
except Exception:
input_scale = None

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -174,9 +237,17 @@ def get_weights_row(self, weights: "Weights", prefix: str):
.reshape(-1)
.expand(w.shape[0])
)
try:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
except Exception:
input_scale = None

return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@@ -191,6 +262,7 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None

def get_linear(self, bias: torch.Tensor):
@@ -200,56 +272,89 @@ def get_linear(self, bias: torch.Tensor):
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
self.weight,
self.weight_scale,
self.input_scale,
self.activation_scale_ub,
bias,
self.dtype,
)


class Fp8Linear(torch.nn.Module):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be cleaner to have a separate Fp8LinearRocm?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.

_device_identity_cache = {}

def __init__(
self,
qweight,
scale,
input_scale,
scale_upper_bound,
bias,
dtype,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)

self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
self.scale = scale.float()
self.input_scale = (
input_scale.float().reciprocal() if input_scale is not None else None
)

if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
else:
self.scale_upper_bound = scale_upper_bound

self.bias = bias if bias is not None else None

@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
qweight=qweight,
scale=scale,
input_scale=None,
scale_upper_bound=None,
bias=bias,
dtype=dtype,
)

@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
def from_fp8(cls, weight, scale, input_scale, scale_upper_bound, bias, dtype):
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
)

@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]

def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
@@ -266,15 +371,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)
return y.to(self.dtype)

qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)

per_tensor_weights = self.scale.numel() == 1
per_tensor_activations = scale.numel() == 1

if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
)

if type(output) is tuple and len(output) == 2:
Member:
Did this change between torch versions or is output for AMD different?

Suggested change:
- if type(output) is tuple and len(output) == 2:
+ if isinstance(output, tuple) and len(output) == 2:

output = output[0]
else:
device_identity = None
if SYSTEM == "rocm":
device_identity = self.get_shared_device_identity(self.qweight.device)

output = torch._scaled_mm(
qinput,
self.qweight.t(),
scale_a=device_identity,
scale_b=device_identity,
out_dtype=torch.float32,
)
if type(output) is tuple and len(output) == 2:
output = output[0]

output = output * scale * self.scale.t()
if self.bias is not None:
output = output + self.bias

output = output.to(dtype=self.dtype)

return output
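
The ROCm fallback above relies on the fact that applying the dequantization scales after a plain matmul is equivalent to dequantizing first. A small standalone check of that identity (illustrative only, not from the PR):

import torch

M, K, N = 4, 8, 3
qa = torch.randint(-3, 4, (M, K)).float()  # stand-in for FP8 activations
qb = torch.randint(-3, 4, (K, N)).float()  # stand-in for transposed FP8 weights
s_a = torch.tensor(0.02)                   # per-tensor activation dequant scale
s_b = torch.rand(N)                        # per-output-channel weight dequant scales

dequant_first = (qa * s_a) @ (qb * s_b)    # dequantize, then matmul
scale_after = (qa @ qb) * s_a * s_b        # matmul with identity scales, scale afterwards
torch.testing.assert_close(dequant_first, scale_after)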


2 changes: 1 addition & 1 deletion server/text_generation_server/layers/marlin/fp8.py
@@ -62,7 +62,7 @@ def from_unquant(cls, weight, bias, dtype):
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
Member:
Type.

These arguments get a bit messy. It's easy to mix up a tensor or a float (which was already happening here?). Maybe we should switch these to kwargs-only so that the call sites need to be explicit (+ type annotations).

return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
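
One possible shape of the kwargs-only idea from the comment above (a sketch with an illustrative class name and annotations, not part of this PR):

from typing import Optional
import torch

class _Fp8LinearSketch:  # hypothetical stand-in for the real linear class
    @classmethod
    def from_fp8(
        cls,
        *,
        weight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
        bias: Optional[torch.Tensor] = None,
        dtype: torch.dtype = torch.float16,
    ) -> "_Fp8LinearSketch":
        # Call sites now have to name every argument, which avoids mixing up
        # the tensor scales and the float upper bound.
        ...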

def forward(self, A: torch.Tensor) -> torch.Tensor:
2 changes: 1 addition & 1 deletion server/text_generation_server/models/__init__.py
@@ -348,7 +348,7 @@ def get_model(
if method in {"gptq", "awq", "exl2"}:
log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method
elif method == "fbgemm_fp8":
elif method == "fbgemm_fp8" or method == "fp8":
log_master(logger.info, "Auto selecting quantization method fp8")
quantize = "fp8"
else: