Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pixart-sigma test to image example #247

Merged
merged 5 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions examples/vision/text-to-image/quantize_pixart_sigma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import argparse
import gc

import torch
from diffusers import DiffusionPipeline

from optimum.quanto import freeze, qfloat8, qint4, qint8, quantize


NUM_INFERENCE_STEPS = 50

TORCH_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16}
QTYPES = {
"fp8": qfloat8,
"int8": qint8,
"int4": qint4,
"none": None,
}


def load_pipeline(model_id, torch_dtype, qtype=None, device="cpu"):
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True).to(device)

if qtype:
quantize(pipe.transformer, weights=qtype)
freeze(pipe.transformer)
quantize(pipe.text_encoder, weights=qtype)
freeze(pipe.text_encoder)

pipe.set_progress_bar_config(disable=True)
return pipe


def get_device_memory(device):
gc.collect()
if device.type == "cuda":
torch.cuda.empty_cache()
return torch.cuda.memory_allocated()
elif device.type == "mps":
torch.mps.empty_cache()
return torch.mps.current_allocated_memory()
return None


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS")
parser.add_argument("--prompt", type=str, default="ghibli style, a fantasy landscape with castles")
parser.add_argument("--torch_dtype", type=str, default="fp16", choices=list(TORCH_DTYPES.keys()))
parser.add_argument("--qtype", type=str, default=None, choices=list(QTYPES.keys()))
parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
args = parser.parse_args()

if args.device is None:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
else:
device = torch.device(args.device)

pipeline = load_pipeline(
args.model_id, TORCH_DTYPES[args.torch_dtype], QTYPES[args.qtype] if args.qtype else None, device
)

print(f"torch_dtype: {args.torch_dtype}, qtype: {args.qtype}.")
memory = get_device_memory(device)
if memory is not None:
memory_gb = memory / 2**30
print(f"{device.type} device memory: {memory_gb:.2f} GB.")

if args.qtype == "int4" and device.type == "CUDA":
raise ValueError("This example does not work (yet) for int4 on CUDA")

img_name = f"pixart-sigma-dtype@{args.torch_dtype}-qtype@{args.qtype}.png"
image = pipeline(
prompt=args.prompt,
num_inference_steps=NUM_INFERENCE_STEPS,
num_images_per_prompt=1,
generator=torch.manual_seed(0),
).images[0]
image.save(img_name)
6 changes: 4 additions & 2 deletions optimum/quanto/library/qbytes_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def qbytes_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: t
# If one of the terms is an int the matmul might overflow
mm_dtype = torch.float32
activations = activations.to(mm_dtype)
weights = weights.to(mm_dtype)
outputs = torch.matmul(activations, weights.t()) * output_scales.t()
# Apply the scale to the weights before the matrix multiplication to put them back
# into their initial numerical range and avoid overflows
weights = weights.to(mm_dtype) * output_scales
outputs = torch.matmul(activations, weights.t())
return outputs.to(output_scales.dtype)


Expand Down
3 changes: 1 addition & 2 deletions optimum/quanto/tensor/optimizers/absmax_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
class AbsmaxOptimizer(SymmetricOptimizer):

def optimize(
self, base: torch.Tensor, bits: int, axis: Optional[int] = None
self, base: torch.Tensor, qmax: float, axis: Optional[int] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
base = torch.abs(base)
if axis is None:
rmax = torch.max(base)
else:
dim = list(range(1, base.ndim)) if (axis == 0) else list(range(0, base.ndim - 1))
rmax = torch.amax(torch.abs(base), dim=dim, keepdim=True)
qmax = 2 ** (bits - 1) - 1
return rmax / qmax
10 changes: 7 additions & 3 deletions optimum/quanto/tensor/optimizers/symmetric_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@

class SymmetricOptimizer(Optimizer):

def __call__(self, base: torch.Tensor, bits: int, axis: Optional[int] = None) -> torch.Tensor:
def __call__(self, base: torch.Tensor, qmax: float, axis: Optional[int] = None) -> torch.Tensor:
if axis not in [None, 0, -1]:
raise ValueError("axis parameter must be None, 0 (first axis) or -1 (last axis)")
scale = self.optimize(base, bits, axis)
if qmax <= 0.0:
raise ValueError(
"qmax must be set to the maximum positive value that can be represented by the quantized type."
)
scale = self.optimize(base, qmax, axis)
assert scale.dtype == base.dtype
return scale

def optimize(self, base: torch.Tensor, bits: int, axis: Optional[int] = None) -> torch.Tensor:
def optimize(self, base: torch.Tensor, qmax: float, axis: Optional[int] = None) -> torch.Tensor:
raise NotImplementedError
35 changes: 29 additions & 6 deletions optimum/quanto/tensor/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class qtype:
bits: int
# This defines the storage dtype
dtype: torch.dtype
qmin: float
qmax: float

def __str__(self):
return f"quanto.{self.name}"
Expand All @@ -34,13 +36,34 @@ def __hash__(self):
return hash(str(self))


qint2 = qtype("qint2", is_floating_point=False, bits=2, dtype=torch.int8)
qint4 = qtype("qint4", is_floating_point=False, bits=4, dtype=torch.int8)
qint8 = qtype("qint8", is_floating_point=False, bits=8, dtype=torch.int8)
# Integer qtypes


def qint(bits):
qmin = -(2 ** (bits - 1))
qmax = 2 ** (bits - 1) - 1
return qtype(f"qint{bits}", is_floating_point=False, bits=bits, dtype=torch.int8, qmin=qmin, qmax=qmax)


qint2 = qint(2)
qint4 = qint(4)
qint8 = qint(8)

# Float qtypes


def qfloat(dtype: torch.dtype):
finfo = torch.finfo(dtype)
qmin = finfo.min
qmax = finfo.max
return qtype(f"q{finfo.dtype}", is_floating_point=True, bits=8, dtype=dtype, qmin=qmin, qmax=qmax)


qfloat8_e4m3fn = qfloat(torch.float8_e4m3fn)
qfloat8_e5m2 = qfloat(torch.float8_e5m2)

# Alias the float8 representation that has the better support and inference efficiency
qfloat8 = qtype("qfloat8", is_floating_point=True, bits=8, dtype=torch.float8_e4m3fn)
qfloat8_e4m3fn = qtype("qfloat8_e4m3fn", is_floating_point=True, bits=8, dtype=torch.float8_e4m3fn)
qfloat8_e5m2 = qtype("qfloat8_e5m2", is_floating_point=True, bits=8, dtype=torch.float8_e5m2)
qfloat8 = qfloat8_e4m3fn

# Convenience dict to get a dtype from its name
qtypes = {name: q for (name, q) in locals().items() if isinstance(q, qtype)}
Expand Down
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/qweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def quantize_weight(
if axis is not None and t.shape[axis] == 1:
# Quantizing along an axis of dimension 1 means quantizing per-tensor
axis = None
scale = optimizer(t, qtype.bits, axis)
scale = optimizer(t, qtype.qmax, axis)
return SymmetricQuantizer.apply(t, qtype, axis, scale)
if optimizer is None:
optimizer = default_affine_optimizer
Expand Down
4 changes: 2 additions & 2 deletions test/library/test_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def test_qbytes_mm(batch_size, input_features, input_dtype, weight_dtype, output
input = random_tensor(input_shape, dtype=input_dtype, device=device)
weight = random_tensor((output_features, input_features), dtype=weight_dtype, device=device)
# Use a scale small enough to prevent overflows
scale = random_tensor((output_features,), dtype=output_dtype, device=device) / 1e3
scale = random_tensor((output_features, 1), dtype=output_dtype, device=device) / 1e3
output = torch.ops.quanto.qbytes_mm(input, weight, scale)
expected = torch.matmul(input.to(scale.dtype), weight.to(scale.dtype).t() * scale)
expected = torch.matmul(input.to(scale.dtype), (weight.to(scale.dtype) * scale).t())
assert_similar(expected, output)


Expand Down
Loading