
Feat (mx): gptq compatibility and quant tests #1013

Merged
merged 13 commits into from Sep 5, 2024
24 changes: 12 additions & 12 deletions src/brevitas/core/function_wrapper/shape.py
@@ -16,6 +16,7 @@
from brevitas.function.shape import over_output_channels
from brevitas.function.shape import over_output_features
from brevitas.function.shape import over_tensor
from brevitas.utils.torch_utils import padding


class PermuteDims(brevitas.jit.ScriptModule):
@@ -154,17 +155,19 @@ def forward(self, x: torch.Tensor):


class OverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['expanded_scaling_shape']
__constants__ = ['expanded_groupwise_shape', 'group_size', 'group_dim']

def __init__(self, expanded_scaling_shape, padding) -> None:
def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:
super(OverSubChannelBlockView, self).__init__()
self.expanded_scaling_shape = expanded_scaling_shape
self.padding = padding
self.expanded_groupwise_shape = expanded_groupwise_shape
self.group_dim = group_dim
self.group_size = group_size

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
y = y.view(self.expanded_scaling_shape)
y = torch.nn.functional.pad(
x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
y = y.view(self.expanded_groupwise_shape)
return y
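For orientation, a minimal standalone sketch of what this pad-then-view achieves, using made-up shapes (plain PyTorch, not the Brevitas module itself):

```python
import torch
import torch.nn.functional as F

# Hypothetical weight of shape (OC=64, IC=100), group_size=32 along dim 1.
x = torch.randn(64, 100)
group_size, group_dim = 32, 1

# Pad the group dim up to the next multiple of group_size: 100 -> 128.
pad_amount = (group_size - x.shape[group_dim] % group_size) % group_size
y = F.pad(x, (0, pad_amount), mode='constant', value=0.)

# expanded_groupwise_shape splits the padded dim into (n_groups, group_size).
y = y.view(64, 128 // group_size, group_size)  # (64, 4, 32)
```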


@@ -181,12 +184,9 @@ def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
padding = [0, 0] * len(tensor_shape_list)
if tensor_shape_list[self.group_dim] % self.group_size != 0:
padding[2 * self.group_dim] = self.group_size - tensor_shape_list[
self.group_dim] % self.group_size
padding = list(reversed(padding))
x = torch.nn.functional.pad(x, padding, mode='constant', value=0)
pad = padding(x, self.group_size, self.group_dim)

x = torch.nn.functional.pad(x, pad, mode='constant', value=0.)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
2 changes: 1 addition & 1 deletion src/brevitas/function/ops.py
@@ -189,7 +189,7 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
return value


@brevitas.jit.script
@brevitas.jit.ignore
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
max_mantissa = torch.sum((
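As context for the decorator swap above, the arithmetic inside max_float can be restated in plain Python; a hedged sketch with a worked example (not the library function, and it ignores any encodings a given format may reserve for NaN/inf):

```python
def max_float_sketch(exponent_bits: int, mantissa_bits: int, exponent_bias: int) -> float:
    # Largest unbiased exponent: all exponent bits set, minus the bias.
    max_exponent = 2.0 ** exponent_bits - 1.0 - exponent_bias
    # Largest significand: implicit leading 1 plus every stored fraction bit set.
    max_mantissa = sum(2.0 ** -i for i in range(mantissa_bits + 1))
    return max_mantissa * 2.0 ** max_exponent

# e4m3 with bias 7 under this naive formula: 1.875 * 2**8 = 480.0
print(max_float_sketch(4, 3, 7))
```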
44 changes: 29 additions & 15 deletions src/brevitas/graph/gpxq.py
@@ -11,14 +11,17 @@
from typing import List, Optional, Set
import warnings

import torch
from torch.fx import GraphModule as TorchGraphModule

from brevitas.fx import GraphModule
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.utils import is_conv_transposed
import brevitas.nn as qnn
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
@@ -194,26 +197,29 @@ def __init__(
self.layer = layer
self.name = name
self.act_order = act_order
if self.layer.weight_quant.is_groupwise:
weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape)
self.layer.weight.data = weight.data
self.layer.in_channels = weight.shape[1] if is_conv_transposed(
self.layer) else weight.shape[0]

weight = layer.weight.data
weight_shape = torch.tensor(layer.weight.shape)

if create_weight_orig and not hasattr(self.layer, 'weight_orig'):
self.layer.register_buffer('weight_orig', layer.weight.detach().clone())

# By default, use groups = 1
self.groups = 1
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
if is_conv_transposed(self.layer):
weight_shape[1], weight_shape[0] = weight_shape[0], weight_shape[1]
self.groups = self.layer.groups

# Number of rows is equal to the output channels (OC)
self.rows = weight.shape[0]
self.rows = weight_shape[0]
# Number of columns is equal to the input channels (IC)
self.columns = weight.shape[1]
self.columns = torch.prod(weight_shape[1:])
self.len_parallel_layers = len_parallel_layers

self.disable_pre_forward_hook = False
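To make the rows/columns bookkeeping above concrete, a small sketch with assumed shapes (not Brevitas code): GPxQ treats the weight as a 2-D matrix with one row per output channel and one column per remaining element (transposed convolutions swap the first two dims first, as in the code above).

```python
import torch

# Hypothetical Conv2d weight: (OC=16, IC=8, kH=3, kW=3).
weight_shape = torch.tensor([16, 8, 3, 3])

rows = weight_shape[0]                   # 16 output channels
columns = torch.prod(weight_shape[1:])   # 8 * 3 * 3 = 72 columns per row
print(int(rows), int(columns))
```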
@@ -262,17 +268,25 @@ def get_quant_weights(self, i, i1, permutation_list):
# For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
# of quantizing only a subset of the entire matrix, speeding up the computation of GPxQ
if isinstance(self.layer, qnn.QuantLinear):
index = permutation_list[0][i]
subtensor_slice_list = [None, (index, index + 1)]
q = self.layer.quant_weight(
subtensor_slice_list=subtensor_slice_list,
quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1]
if self.layer.weight_quant.is_groupwise:
# No slicing, not optimized
index = permutation_list[0][i]
q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze(
0) # [1, OC, 1]
q = q[:, :, i:i + 1] # [groups, OC/groups, 1]
else:
index = permutation_list[0][i]
subtensor_slice_list = [None, (index, index + 1)]
q = self.layer.quant_weight(
subtensor_slice_list=subtensor_slice_list,
quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1]
elif isinstance(self.layer, SUPPORTED_CONV_OP):
# For depthwise and ConvTranspose we fall back to quantizing the entire matrix.
# For all other cases, we create a mask that represents the slicing we will perform on the weight matrix
# and we quantize only the selected dimensions.
if self.groups > 1 or (self.groups == 1 and isinstance(
self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
if self.layer.weight_quant.is_groupwise or self.groups > 1 or (
self.groups == 1 and
isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):

quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
quant_weight = quant_weight.value
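A toy illustration (assumed shapes and scale rule, not Brevitas code) of why the per-column slicing shortcut is skipped for groupwise quantization: a column's scale depends on the other columns in its group, so a single column cannot be quantized in isolation and the whole tensor is quantized instead.

```python
import torch

# Hypothetical weight (OC=8, IC=16) with int4-style symmetric per-group scales, group_size=4 on dim 1.
w = torch.randn(8, 16)
group_size, i = 4, 5

# The scale of column i comes from its whole group, not from column i alone.
start = (i // group_size) * group_size
group = w[:, start:start + group_size]              # (8, 4)
scale = group.abs().amax(dim=1, keepdim=True) / 7.  # (8, 1), one scale per row and group

q_col = torch.clamp(torch.round(w[:, i:i + 1] / scale), -8, 7) * scale  # fake-quantized column i
```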
6 changes: 3 additions & 3 deletions src/brevitas/graph/utils.py
@@ -26,13 +26,13 @@
'get_output_channels',
'get_output_channel_dim']

CONV_TRANSPOSED = [
CONV_TRANSPOSED = (
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d]
qnn.QuantConvTranspose3d)


def module_class_name(m: torch.nn.Module):
@@ -146,7 +146,7 @@ def matches_module_pattern(pattern: Iterable, node: Node, modules: Dict[str, Any


def is_conv_transposed(module):
return isinstance(module, tuple(CONV_TRANSPOSED))
return isinstance(module, CONV_TRANSPOSED)


def get_output_channel_dim(module):
2 changes: 2 additions & 0 deletions src/brevitas/jit.py
@@ -14,12 +14,14 @@ def _disabled(fn):

script_method = torch.jit.script_method
script = torch.jit.script
ignore = torch.jit.ignore
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

else:

script_method = _disabled
script = _disabled
ignore = _disabled
ScriptModule = torch.nn.Module
Attribute = lambda val, type: val
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -22,6 +22,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)
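A sketch of the intended effect, with assumed shapes (the actual grouped view comes from the injector's input_view_impl): the parent call splits group_dim into (number_of_groups, group_size), padding if needed, and the flatten here merges those two dims back, so the result keeps the original rank but has a length that is a multiple of group_size. When group_dim is -1, the two group dims sit at positions -2 and -1, hence start_dim = -2.

```python
import torch

# Hypothetical: weight (64, 100), group_size=32, group_dim=1.
x = torch.randn(64, 100)
grouped = torch.nn.functional.pad(x, (0, 28)).view(64, 4, 32)  # what the grouped view produces
flattened = grouped.flatten(1, 2)                              # back to rank 2: (64, 128), padded
print(flattened.shape)
```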

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -21,6 +21,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Union[torch.Tensor, Tuple[Any]],
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -22,6 +22,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
Collaborator:
Worth making a GroupwiseQuantProxyMixin which provides these functions? Looks like a lot is shared between GroupwiseWeightQuantProxyFromInjector & GroupwiseActQuantProxyFromInjector...

Collaborator:
After further review, I seem to recall there was some issue with Proxy Mixins (perhaps with dependency injection) that resulted in the ExportMixin needing to be in a certain location. I'll leave this for now - we can revisit at another time.

x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
out, scale, zero_point, bit_width = qt_args
return GroupwiseIntQuantTensor(
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -21,6 +21,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Union[torch.Tensor, Tuple[Any]],
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -132,7 +132,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
self._cached_weight = self.cache_class(
out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = x
out = self.apply_input_view(x)
return out


9 changes: 5 additions & 4 deletions src/brevitas/proxy/quant_proxy.py
@@ -11,6 +11,7 @@

from brevitas import config
from brevitas.common import ExportMixin
from brevitas.core.scaling import ScalingPerOutputType
from brevitas.core.utils import StatelessBuffer
from brevitas.inject import BaseInjector as Injector
from brevitas.utils.quant_utils import float_to_int_impl_to_enum
@@ -21,10 +22,7 @@


def _is_groupwise(quant_injector):
if 'group_size' in quant_injector:
return True
else:
return False
return 'scaling_per_output' in quant_injector and quant_injector.scaling_per_output == ScalingPerOutputType.GROUP


def _is_narrow_range(quant_injector):
@@ -123,6 +121,9 @@ def add_tracked_module(self, module: nn.Module) -> None:
else:
raise RuntimeError("Trying to add None as a parent module.")

def apply_input_view(self, x):
return self.quant_injector.input_view_impl(x)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
4 changes: 3 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -168,7 +168,9 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
y = (self.fused_activation_quant_proxy.activation_impl(y), None)
# If quant is not enabled, we still apply input_view in the case of groupwise + padding
y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
y = (y, None)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
54 changes: 29 additions & 25 deletions src/brevitas/quant/solver/parameter.py
@@ -111,45 +111,49 @@ def scaling_impl(scaling_impl_type):
class SolveParameterScalingShape(ExtendedInjector):

@value
def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
def scaling_shape(scaling_per_output, expanded_groupwise_shape=None, group_dim=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return this.scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
assert group_size is not None, "Per Group scaling requires group size"
assert group_dim is not None, "Per Group scaling requires group dim"
size = list(module.weight.shape)
size[group_dim] = (size[group_dim] + group_size - 1) // group_size
size.insert(group_dim + 1, 1)
return size
# Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
assert group_dim is not None, "Per Group scaling not correctly configured"
size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1
return tuple(size)

@value
def reshaped_scaling_shape(module):
return module.weight.shape
def reshaped_groupwise_shape(expanded_groupwise_shape, group_dim, group_size):
new_shape = list(expanded_groupwise_shape)
del new_shape[group_dim + 1] # delete the group_size shape
# Expand the group_dim shape, accounting for padding
new_shape[group_dim] = new_shape[group_dim] * group_size
return new_shape

@value
def expanded_scaling_shape(module, group_dim, group_size=None):
assert group_size is not None, "Per Group scaling requires group size"
size = list(module.weight.shape)
def expanded_groupwise_shape(tracked_parameter_list, group_dim, group_size=None):
# expanded_groupwise_shape is always evaluated to create scaling_shape, but it is only needed
# for groupwise quantization. All other groupwise shape info is derived from this.

# If conditions do not allow for groupwise quantization, early exit and return None
if group_size is None:
return

# If group_size is specified and shared quantization is used, raise an error.
assert len(tracked_parameter_list) == 1, "Shared groupwise quantization is not currently supported"

weight_shape = tracked_parameter_list[0].shape
size = list(weight_shape)
size[group_dim] = (size[group_dim] + group_size - 1) // group_size
size.insert(group_dim + 1, group_size)
return size

@value
def padding(module, group_dim, group_size):
padding = [0, 0] * len(module.weight.shape)
size = list(module.weight.shape)
if size[group_dim] % group_size != 0:
# Padding is done on the left side
padding[2 * group_dim] = group_size - size[group_dim] % group_size
# Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0)
# so we need to reverse the order
padding = list(reversed(padding))
return padding
return tuple(size)

@value
def group_dim(module, group_size=None):
# group_dim is always evaluated to create scaling_shape, but it is only needed
# for groupwise quantization.
if group_size is not None:
return 1 if not hasattr(module, 'transposed') or not module.transposed else 0
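To make the shape helpers above concrete, a worked example with assumed numbers (a single tracked weight of shape (64, 100), group_size=32, group_dim=1):

```python
group_size, group_dim = 32, 1
weight_shape = [64, 100]

# expanded_groupwise_shape: ceil(100 / 32) = 4 groups of 32  -> (64, 4, 32)
expanded = list(weight_shape)
expanded[group_dim] = (expanded[group_dim] + group_size - 1) // group_size
expanded.insert(group_dim + 1, group_size)

# scaling_shape: one scale per group                          -> (64, 4, 1)
scaling = list(expanded)
scaling[group_dim + 1] = 1

# reshaped_groupwise_shape: groups folded back, with padding  -> (64, 128)
reshaped = list(expanded)
del reshaped[group_dim + 1]
reshaped[group_dim] *= group_size
```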

13 changes: 12 additions & 1 deletion src/brevitas/utils/torch_utils.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from torch.nn import Sequential
@@ -102,3 +102,14 @@ def float_internal_scale(
internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)
internal_scale = torch.exp2(internal_scale)
return internal_scale


@brevitas.jit.ignore
def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
# Given a tensor x, compute the padding along group_dim so that groupwise reshaping is possible
padding = [0, 0] * len(x.shape)
size = x.shape
if size[group_dim] % group_size != 0:
padding[2 * group_dim] = group_size - size[group_dim] % group_size
padding = list(reversed(padding))
return padding
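A brief usage note (assumed shapes): the returned list follows torch.nn.functional.pad's convention, pairs ordered from the last dimension backwards, which is why the per-dimension list is reversed before being returned.

```python
import torch

x = torch.randn(64, 100)
pad = padding(x, group_size=32, group_dim=1)   # [0, 28, 0, 0]: 28 zeros appended along dim 1
y = torch.nn.functional.pad(x, pad, mode='constant', value=0.)
assert y.shape == (64, 128)
```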