From 0f60606111fe00893ff405fda90892ed53aefe13 Mon Sep 17 00:00:00 2001
From: Fabian Grob <34524155+fabianandresgrob@users.noreply.github.com>
Date: Fri, 31 May 2024 15:27:51 +0200
Subject: [PATCH] Feat (graph/gpfq): compression with random projection (#964)

---
 src/brevitas/graph/gpfq.py                    | 55 ++++++++++++++++---
 .../imagenet_classification/ptq/ptq_common.py | 12 +++-
 .../ptq/ptq_evaluate.py                       | 16 +++++-
 3 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 17d6a9c89..fd7df9223 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -2,10 +2,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from copy import deepcopy
+import math
+from math import pi
 from typing import Callable, List, Optional
 
 import numpy as np
 import torch
+from torch.fft import fft
+from torch.fft import fftn
 import torch.nn as nn
 import unfoldNd
 
@@ -19,6 +23,24 @@
 import brevitas.nn as qnn
 
 
+def random_projection(
+        float_input: torch.Tensor, quantized_input: torch.Tensor, compression_rate: float):
+    # use random projection to reduce dimensionality
+    n = quantized_input.size(1)
+    target_dim = int(compression_rate * n)
+    dev = float_input.device
+    # create gaussian random matrix
+    R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
+    quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T
+    float_input = torch.transpose(float_input, 1, 2) @ R.T
+    del R
+    # reshape back
+    quantized_input = torch.transpose(quantized_input, 1, 2)
+    float_input = torch.transpose(float_input, 1, 2)
+
+    return float_input, quantized_input
+
+
 class gpfq_mode(gpxq_mode):
     """
     Apply GPFQ algorithm.
@@ -64,7 +86,8 @@ def __init__(
             act_order: bool = False,
             use_gpfa2q: bool = False,
             accumulator_bit_width: Optional[int] = None,
-            a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
+            a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
+            compression_rate: Optional[float] = 0.0) -> None:
         if not inplace:
             model = deepcopy(model)
         super().__init__(
@@ -83,6 +106,11 @@ def __init__(
         self.accumulator_bit_width = accumulator_bit_width
         self.a2q_layer_filter_fnc = a2q_layer_filter_fnc  # returns true when to use GPFA2Q
 
+        # store and validate the compression rate for random projection
+        self.compression_rate = compression_rate
+        if self.compression_rate < 0.0 or self.compression_rate > 1.0:
+            raise ValueError('Compression rate for random projection must be between 0 and 1.')
+
     def catch_stopfwd(self, *args, **kwargs):
         # Collect quant input
         try:
@@ -127,7 +155,8 @@ def initialize_module_optimizer(
                 act_order=act_order,
                 len_parallel_layers=len_parallel_layers,
                 create_weight_orig=create_weight_orig,
-                p=self.p)
+                p=self.p,
+                compression_rate=self.compression_rate)
         else:
             return GPFA2Q(
                 layer=layer,
@@ -136,7 +165,8 @@ def initialize_module_optimizer(
                 len_parallel_layers=len_parallel_layers,
                 create_weight_orig=create_weight_orig,
                 p=self.p,
-                accumulator_bit_width=self.accumulator_bit_width)
+                accumulator_bit_width=self.accumulator_bit_width,
+                compression_rate=self.compression_rate)
 
 
 class GPFQ(GPxQ):
@@ -144,7 +174,9 @@ class GPFQ(GPxQ):
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
     """
 
-    def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig, p,
+            compression_rate) -> None:
 
         super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
@@ -152,6 +184,7 @@ def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_or
         self.quantized_input = None
         self.index_computed = False
         self.p = p
+        self.compression_rate = compression_rate
 
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
@@ -246,10 +279,12 @@ def single_layer_update(self):
         weight = weight.transpose(1, 0)  # This performs a view
         weight = weight.flatten(1)
         weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
-        U = torch.zeros(
-            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
+        if self.compression_rate > 0.0:
+            self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
         self.float_input = self.float_input.to(dev)
         self.quantized_input = self.quantized_input.to(dev)
+        U = torch.zeros(
+            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         # We don't need full Hessian, we just need the diagonal
         self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
             2)  # summing over Batch dimension
@@ -300,7 +335,8 @@ def __init__(
             len_parallel_layers,
             create_weight_orig,
             accumulator_bit_width,
-            p) -> None:
+            p,
+            compression_rate) -> None:
         GPFQ.__init__(
             self,
             layer=layer,
@@ -308,7 +344,8 @@ def __init__(
             act_order=act_order,
             len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
-            p=p)
+            p=p,
+            compression_rate=compression_rate)
         self.accumulator_bit_width = accumulator_bit_width
         assert self.accumulator_bit_width is not None
 
@@ -329,6 +366,8 @@ def single_layer_update(self):
         weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
+        if self.compression_rate > 0.0:
+            self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
         U = torch.zeros(
             weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         self.float_input = self.float_input.to(dev)
         self.quantized_input = self.quantized_input.to(dev)
 
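Note on the new helper: random_projection sketches both calibration tensors along dim 1 (the batch/sample dimension) with a Gaussian matrix whose entries are drawn from N(0, 1/n). This is the classic Johnson-Lindenstrauss construction: E[R.T @ R] = (target_dim / n) * I, so the float/quantized activation correlations that GPFQ consumes are approximately preserved while dim 1 shrinks to a compression_rate fraction of its size, and the common scale factor cancels in the per-layer update. A minimal, self-contained sketch of that shape behavior (the sizes below are illustrative, not taken from the patch):

    import math

    import torch

    # Illustrative shapes [groups, n_samples, features]; the sizes are made up.
    groups, n, feats = 1, 1024, 64
    float_input = torch.randn(groups, n, feats)
    quantized_input = torch.randn(groups, n, feats)

    compression_rate = 0.5
    target_dim = int(compression_rate * n)  # 512

    # Gaussian sketch with entries ~ N(0, 1/n), as in random_projection above.
    R = torch.normal(mean=0.0, std=1.0 / math.sqrt(n), size=(target_dim, n))

    # Same transpose -> matmul -> transpose dance as the helper.
    compressed = torch.transpose(torch.transpose(float_input, 1, 2) @ R.T, 1, 2)
    print(compressed.shape)  # torch.Size([1, 512, 64]): dim 1 shrunk 1024 -> 512

Because the same R multiplies both tensors, inner products taken along the sample dimension between the float and quantized inputs are preserved in expectation up to the factor target_dim / n, which is what makes the update tolerant to the compression.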
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 2ac2af250..fcb0be367 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -535,7 +535,14 @@ def apply_gptq(calib_loader, model, act_order=False):
             gptq.update()
 
 
-def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
+def apply_gpfq(
+        calib_loader,
+        model,
+        act_order,
+        p=1.0,
+        use_gpfa2q=False,
+        accumulator_bit_width=None,
+        compression_rate=0.0):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
@@ -545,7 +552,8 @@ def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumula
             use_quant_activations=True,
             act_order=act_order,
             use_gpfa2q=use_gpfa2q,
-            accumulator_bit_width=accumulator_bit_width) as gpfq:
+            accumulator_bit_width=accumulator_bit_width,
+            compression_rate=compression_rate) as gpfq:
         gpfq_model = gpfq.model
         for i in tqdm(range(gpfq.num_layers)):
             for i, (images, target) in enumerate(calib_loader):
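With apply_gpfq extended as above, enabling the projection from user code is a single extra keyword argument. A minimal usage sketch; the toy quantized layer and random calibration batches are placeholders for the real ImageNet flow, not part of the patch:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    import brevitas.nn as qnn
    from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq

    # Toy stand-ins: one quantized layer and a few random calibration batches.
    quant_model = nn.Sequential(qnn.QuantLinear(64, 10, bias=True, weight_bit_width=4))
    calib_loader = DataLoader(
        TensorDataset(torch.randn(16, 64), torch.zeros(16, dtype=torch.long)), batch_size=4)

    # compression_rate=0.5 halves the collected sample dimension via random
    # projection before each per-layer update; 0.0 (the default) disables it.
    apply_gpfq(calib_loader, quant_model, act_order=False, p=1.0, compression_rate=0.5)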
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 377e705ab..7e2bf6ee5 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -239,6 +239,12 @@ def parse_type(v, default_type):
     help=
     'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
 )
+parser.add_argument(
+    '--compression-rate',
+    default=0.0,
+    type=float,
+    help='Compression rate for random projection, between 0.0 and 1.0. When set to 0.0, random projection will not be applied. (default: 0.0)'
+)
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -426,7 +432,12 @@ def main():
 
     if args.gpfq:
         print("Performing GPFQ:")
-        apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order)
+        apply_gpfq(
+            calib_loader,
+            quant_model,
+            p=args.gpfq_p,
+            act_order=args.gpxq_act_order,
+            compression_rate=args.compression_rate)
 
     if args.gpfa2q:
         print("Performing GPFA2Q:")
@@ -436,7 +447,8 @@ def main():
             p=args.gpfq_p,
             act_order=args.gpxq_act_order,
             use_gpfa2q=args.gpfa2q,
-            accumulator_bit_width=args.accumulator_bit_width)
+            accumulator_bit_width=args.accumulator_bit_width,
+            compression_rate=args.compression_rate)
 
     if args.gptq:
         print("Performing GPTQ:")
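The new --compression-rate flag is forwarded unchanged into gpfq_mode, whose added range check rejects bad values before any calibration work starts. A small sketch of that failure mode, assuming construction alone performs no graph rewriting (the one-layer model is a placeholder; any nn.Module would do):

    import torch.nn as nn

    from brevitas.graph.gpfq import gpfq_mode

    # Values outside [0, 1] fail fast in gpfq_mode.__init__.
    try:
        gpfq_mode(nn.Sequential(nn.Linear(8, 8)), compression_rate=1.5)
    except ValueError as exc:
        print(exc)  # Compression rate for random projection must be between 0 and 1.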