Various SDXL quantization fixes #977

Merged
24 commits merged on Jul 17, 2024 (changes shown as of commit 16 of 24)

Commits
9181cff
Fix (examples/sdxl): Fix issue setting device when checkpoint is loaded.
nickfraser Jun 25, 2024
4ff1b43
Fix (example/sdxl): Added argument for linear output bitwidth.
nickfraser Jun 25, 2024
ce38f86
Fix (example/sdxl): Fix when replacing `diffusers.models.lora.LoRACom…
nickfraser Jun 25, 2024
975c9ee
Fix (example/sdxl): print output directory.
nickfraser Jun 26, 2024
b3ed0d8
Feat (example/sdxl): add extra option to quantize conv layers like SDP
nickfraser Jun 26, 2024
11f9dc8
fix (example/sdxl): Updated usage README.
nickfraser Jun 26, 2024
f77bf6c
Fix (example/sdxl): print which checkpoint is loaded.
nickfraser Jun 26, 2024
606ddec
Fix (example/sdxl): Move to CPU before 'param_only' export
nickfraser Jun 26, 2024
fb2fc87
Fix (example/sdxl): Added pandas requirement with specific version.
nickfraser Jun 28, 2024
cb3593b
Fix (example/sdxl): pre-commit
nickfraser Jun 28, 2024
fe30b66
Fix (example/sdxl): pre-commit fix to requirements.
nickfraser Jun 28, 2024
4546ffe
Fix model loading
Giuseppe5 Jul 3, 2024
99fc16f
Fix latents dtype
Giuseppe5 Jul 3, 2024
b8b31f8
Fix bitwidth
Giuseppe5 Jul 3, 2024
4fd4998
Fix export
Giuseppe5 Jul 3, 2024
6f0d2f9
Fix tests
Giuseppe5 Jul 3, 2024
41bba9c
Feat (example/sdxl): Added fix for VAE @ FP16
nickfraser Jul 10, 2024
3241a5a
Fix (example/sdxl): Only apply VAE fix for SDXL
nickfraser Jul 10, 2024
307b128
Docs (example/sdxl): Updated usage
nickfraser Jul 10, 2024
b9cc9c1
Fix (example/generative): Added missing `use_fnuz` arg.
nickfraser Jul 12, 2024
fdfcafc
Update
Giuseppe5 Jul 15, 2024
ade7c6b
Lambda inspection
Giuseppe5 Jul 15, 2024
ce2993e
fix
Giuseppe5 Jul 15, 2024
5b0beb6
Add license
Giuseppe5 Jul 17, 2024
15 changes: 9 additions & 6 deletions src/brevitas/graph/base.py
@@ -120,16 +120,19 @@ def _module_attributes(self, module):
attrs['bias'] = module.bias
return attrs

def _evaluate_new_kwargs(self, new_kwargs, old_module):
def _evaluate_new_kwargs(self, new_kwargs, old_module, name):
update_dict = dict()
for k, v in self.new_module_kwargs.items():
if islambda(v):
v = v(old_module)
if name is not None:
v = v(old_module, name)
else:
v = v(old_module)
update_dict[k] = v
new_kwargs.update(update_dict)
return new_kwargs

def _init_new_module(self, old_module: Module):
def _init_new_module(self, old_module: Module, name=None):
# get attributes of original module
new_kwargs = self._module_attributes(old_module)
# transforms attribute of original module, e.g. bias Parameter -> bool
@@ -138,7 +141,7 @@ def _init_new_module(self, old_module: Module):
new_module_signature_keys = signature_keys(self.new_module_class)
new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
# update with kwargs passed to the rewriter
new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module)
new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module, name)
# init the new module
new_module = self.new_module_class(**new_kwargs)
return new_module
@@ -204,10 +207,10 @@ def __init__(self, old_module_instance, new_module_class, **kwargs):
self.old_module_instance = old_module_instance

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
for name, old_module in model.named_modules():
if old_module is self.old_module_instance:
# init the new module based on the old one
new_module = self._init_new_module(old_module)
new_module = self._init_new_module(old_module, name)
self._replace_old_module(model, old_module, new_module)
break
return model
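For context, the `named_modules()` change above means a kwarg lambda passed to a rewriter can now also receive the target module's name. A minimal sketch modeled on the updated test in this PR (the tiny model and import path are illustrative):

```python
import torch.nn as nn

from brevitas.graph.base import ModuleToModuleByInstance


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)


model = TinyModel()
# With this PR, the lambda may take (module, name); `name` is the key
# reported by model.named_modules(), i.e. 'conv' here.
kwargs = {'stride': lambda module, name: 2 if 'conv' in name else 1}
model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)
```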
19 changes: 14 additions & 5 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -77,6 +77,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
[--act-eq-alpha ACT_EQ_ALPHA]
[--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
[--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--input-param-method {stats,mse}]
[--input-scale-stats-op {minmax,percentile}]
@@ -96,15 +97,16 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--quantize-input-zero-point | --no-quantize-input-zero-point]
[--export-cpu-float32 | --no-export-cpu-float32]
[--use-mlperf-inference | --no-use-mlperf-inference]
[--use-ocp | --no-use-ocp] [--use-nfuz | --no-use-nfuz]
[--use-ocp | --no-use-ocp] [--use-fnuz | --no-use-fnuz]
[--use-negative-prompts | --no-use-negative-prompts]
[--dry-run | --no-dry-run]
[--quantize-sdp-1 | --no-quantize-sdp-1]
[--quantize-sdp-2 | --no-quantize-sdp-2]
[--override-conv-quant-config | --no-override-conv-quant-config]

Stable Diffusion quantization

options:
optional arguments:
-h, --help show this help message and exit
-m MODEL, --model MODEL
Path or name of the model.
@@ -176,6 +178,8 @@ options:
Alpha for activation equalization. Default: 0.9
--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
Input bit width. Default: 0 (not quantized).
--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH
Input bit width. Default: 0 (not quantized).
--weight-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--input-param-method {stats,mse}
@@ -241,9 +245,9 @@ options:
True
--no-use-ocp Disable Use OCP format for float quantization.
Default: True
--use-nfuz Enable Use NFUZ format for float quantization.
--use-fnuz Enable Use FNUZ format for float quantization.
Default: True
--no-use-nfuz Disable Use NFUZ format for float quantization.
--no-use-fnuz Disable Use FNUZ format for float quantization.
Default: True
--use-negative-prompts
Enable Use negative prompts during
@@ -259,5 +263,10 @@ options:
--no-quantize-sdp-1 Disable Quantize SDP. Default: Disabled
--quantize-sdp-2 Enable Quantize SDP. Default: Disabled
--no-quantize-sdp-2 Disable Quantize SDP. Default: Disabled

--override-conv-quant-config
Enable Quantize Convolutions in the same way as SDP
(i.e., FP8). Default: Disabled
--no-override-conv-quant-config
Disable Quantize Convolutions in the same way as SDP
(i.e., FP8). Default: Disabled
```
51 changes: 37 additions & 14 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -12,6 +12,7 @@

from dependencies import value
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
@@ -37,7 +38,6 @@
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
@@ -152,13 +152,14 @@ def main(args):

latents = None
if args.path_to_latents is not None:
latents = torch.load(args.path_to_latents).to(torch.float16)
latents = torch.load(args.path_to_latents).to(dtype)

# Create output dir. Move to tmp if None
ts = datetime.fromtimestamp(time.time())
str_ts = ts.strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(args.output_path, f'{str_ts}')
os.mkdir(output_dir)
print(f"Saving results in {output_dir}")

# Dump args to json
with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
@@ -169,7 +170,11 @@

# Load model from float checkpoint
print(f"Loading model from {args.model}...")
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
variant = 'fp16' if dtype == torch.float16 else None
pipe = DiffusionPipeline.from_pretrained(
args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.vae.config.force_upcast = True
print(f"Model loaded from {args.model}.")

# Move model to target device
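Pieced together, the post-PR loading path pins the fp16 safetensors variant, swaps in the Euler scheduler, and forces the VAE to decode in fp32. A rough standalone sketch (the model id is illustrative, not from this PR):

```python
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

dtype = torch.float16
# The 'fp16' weights variant is only requested when running in half precision.
variant = 'fp16' if dtype == torch.float16 else None
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',  # illustrative model id
    torch_dtype=dtype,
    variant=variant,
    use_safetensors=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# The SDXL VAE tends to overflow in fp16; force_upcast runs it in fp32.
pipe.vae.config.force_upcast = True
```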
@@ -212,7 +217,7 @@ def main(args):

if args.activation_equalization:
pipe.set_progress_bar_config(disable=True)
with activation_equalization_mode(
with torch.no_grad(), activation_equalization_mode(
pipe.unet,
alpha=args.act_eq_alpha,
layerwise=True,
@@ -261,8 +266,6 @@ def input_bit_width(module):
return args.linear_input_bit_width
elif isinstance(module, nn.Conv2d):
return args.conv_input_bit_width
elif isinstance(module, QuantIdentity):
return args.quant_identity_bit_width
else:
raise RuntimeError(f"Module {module} not supported.")

@@ -345,7 +348,7 @@ def input_zp_stats_type():
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=input_bit_width,
input_bit_width=args.linear_output_bit_width,
input_quant_format='e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
@@ -358,7 +361,6 @@ def input_zp_stats_type():
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
input_quant = float_sdpa_quantizers[0]
input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width})
if args.quantize_sdp_2:
rewriter = ModuleToModuleByClass(
Attention,
@@ -374,14 +376,22 @@ def input_zp_stats_type():
config.IGNORE_MISSING_KEYS = False
pipe.unet = pipe.unet.to(args.device)
pipe.unet = pipe.unet.to(dtype)
quant_kwargs = layer_map[torch.nn.Linear][1]
quant_kwargs = layer_map['diffusers.models.lora.LoRACompatibleLinear'][1]
what_to_quantize = []
if args.quantize_sdp_1:
what_to_quantize.extend(['to_q', 'to_k'])
if args.quantize_sdp_2:
what_to_quantize.extend(['to_v'])
quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None
layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs)

if args.override_conv_quant_config:
print(
f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}"
)
conv_qkwargs = layer_map[torch.nn.Conv2d][1]
conv_qkwargs['input_quant'] = float_sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1]
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

pipe.unet = layerwise_quantize(
model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
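The `output_quant` entry above is a name-aware lambda of the kind enabled by the `base.py` change: each `Linear` gets the SDPA input quantizer only when its name contains one of the selected projections (`to_q`/`to_k` from `--quantize-sdp-1`, `to_v` from `--quantize-sdp-2`). A toy illustration of that selection with both flags enabled (the layer names below are hypothetical):

```python
what_to_quantize = ['to_q', 'to_k', 'to_v']
select = lambda name: any(ending in name for ending in what_to_quantize)

# Hypothetical UNet attention sub-module names:
for name in ['mid_block.attentions.0.transformer_blocks.0.attn1.to_q',
             'mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0',
             'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v']:
    print(name, '-> FP8 output quant' if select(name) else '-> left unquantized')
```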
@@ -405,11 +415,13 @@ def input_zp_stats_type():
if args.load_checkpoint is not None:
with load_quant_model_mode(pipe.unet):
pipe = pipe.to('cpu')
print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
pipe = pipe.to(args.device)
print(f"Checkpoint loaded!")
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width is not None or
args.conv_input_bit_width is not None) and args.input_scale_type == 'static':
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -447,7 +459,7 @@ def input_zp_stats_type():
torch.cuda.empty_cache()
if args.bias_correction:
print("Applying bias correction")
with bias_correction_mode(pipe.unet):
with torch.no_grad(), bias_correction_mode(pipe.unet):
run_val_inference(
pipe,
args.resolution,
@@ -530,6 +542,7 @@ def input_zp_stats_type():
export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node)
export_onnx(pipe, trace_inputs, output_dir, export_manager)
if args.export_target == 'params_only':
pipe.to('cpu')
export_quant_params(pipe, output_dir)


@@ -648,6 +661,11 @@ def input_zp_stats_type():
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--linear-output-bit-width',
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
@@ -775,6 +793,11 @@ def input_zp_stats_type():
help='Generate a quantized model without any calibration. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. Default: Disabled')
add_bool_arg(
parser,
'override-conv-quant-config',
default=False,
help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled')
args = parser.parse_args()
print("Args: " + str(vars(args)))
main(args)
@@ -615,7 +615,7 @@ def compute_mlperf_fid(

if model_to_replace is not None:
model.pipe = model_to_replace

model.pipe.vae.config.force_upcast = True
ds = Coco(
data_path=path_to_coco,
name="coco-1024",
@@ -2,6 +2,7 @@ accelerate==0.23.0
diffusers==0.21.2
open-clip-torch==2.7.0
opencv-python==4.8.1.78
pandas==2.2.2
pycocotools==2.0.7
scipy==1.9.1
torchmetrics[image]==1.2.0
2 changes: 2 additions & 0 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -93,11 +93,13 @@ def export_quant_params(pipe, output_dir):
elif isinstance(
module,
QuantWeightBiasInputOutputLayer) and id(module) not in handled_quant_layers:
full_name = name
layer_dict = dict()
layer_dict = handle_quant_param(module, layer_dict)
quant_params[full_name] = layer_dict
handled_quant_layers.add(id(module))
elif isinstance(module, QuantNonLinearActLayer):
full_name = name
layer_dict = dict()
act_scale = module.act_quant.export_handler.symbolic_kwargs[
'dequantize_symbolic_kwargs']['scale'].data
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_transforms.py
@@ -287,6 +287,6 @@ def forward(self, x):
model = TestModel()
assert model.conv.stride == (1, 1)

kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1}
kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1}
Collaborator Author:

Check if anything in brevitas_examples will be affected by this change.
Collaborator:

Lambdas get inspected to decide which signature to use, so we keep backward compatibility.

model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)
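The backward compatibility mentioned in the comment above presumably hinges on inspecting the lambda's arity before calling it, so pre-existing single-argument kwargs keep working. A sketch of that idea (not the exact Brevitas implementation):

```python
import inspect

import torch.nn as nn


def call_kwarg_fn(fn, old_module, name):
    # Old-style rewriter kwargs take only the module; new-style ones also
    # receive the name from named_modules(). Dispatch on the lambda's arity.
    n_params = len(inspect.signature(fn).parameters)
    return fn(old_module, name) if n_params == 2 else fn(old_module)


conv = nn.Conv2d(3, 8, kernel_size=3)
legacy = lambda module: 2 if module.in_channels == 3 else 1
named = lambda module, name: 2 if 'conv' in name else 1
print(call_kwarg_fn(legacy, conv, 'conv'))  # -> 2
print(call_kwarg_fn(named, conv, 'conv'))   # -> 2
```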