New QuantTensor Structure
Giuseppe5 committed Apr 9, 2024
1 parent 1365296 · commit 40ea5b4
Showing 12 changed files with 232 additions and 174 deletions.
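The hunks below implement one mechanical restructuring: the old QuantTensorBase becomes the general QuantTensor base, and code that previously built or checked for QuantTensor now uses the integer-specific IntQuantTensor. A minimal sketch of the new class as it appears from this diff (the constructor order is taken from the calls in the hunks below; the concrete values are illustrative only):

    import torch
    from brevitas.quant_tensor import IntQuantTensor  # new integer-specific class

    # Field order as used by the constructor calls in this commit:
    # (value, scale, zero_point, bit_width, signed, training).
    qt = IntQuantTensor(
        torch.randn(2, 4),    # value: dequantized payload
        torch.tensor(0.1),    # scale
        torch.tensor(0.),     # zero_point
        torch.tensor(8.),     # bit_width
        signed=True,
        training=False)
    print(qt.scale, qt.bit_width)  # quantization metadata travels with the tensor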
4 changes: 2 additions & 2 deletions src/brevitas/fx/value_tracer.py
@@ -57,7 +57,7 @@
 import torch.utils._pytree as pytree

 from brevitas import torch_version
-from brevitas.quant_tensor import QuantTensorBase
+from brevitas.quant_tensor import QuantTensor

 from . import *
 from . import _assert_is_none
@@ -82,7 +82,7 @@
 from . import ScopeContextManager

 _UNSET = object()
-extended_base_types = base_types + (QuantTensorBase,)
+extended_base_types = base_types + (QuantTensor,)

 FRAME_FILES = [
     'fx/brevitas_tracer.py',
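extended_base_types tells the value tracer which types to treat as opaque leaf values instead of tracing into them. Keying the tuple on the QuantTensor base means one isinstance check now covers every quant-tensor variant. A self-contained sketch of that idea (the stand-in classes below are hypothetical; only the tuple-plus-isinstance pattern comes from the hunk above):

    import torch

    base_types = (torch.Tensor, int, float, bool)  # stand-in for fx's base types

    class QuantTensor:                   # general base, as in this commit
        pass

    class IntQuantTensor(QuantTensor):   # integer-specific subclass (assumed)
        pass

    extended_base_types = base_types + (QuantTensor,)

    def is_leaf(value) -> bool:
        # A single check covers torch.Tensor, Python scalars, and any QuantTensor.
        return isinstance(value, extended_base_types)

    print(is_leaf(IntQuantTensor()))  # True, via the QuantTensor base class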
10 changes: 5 additions & 5 deletions src/brevitas/graph/gpfq.py
@@ -17,7 +17,7 @@
 from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
-from brevitas.quant_tensor import QuantTensor
+from brevitas.quant_tensor import IntQuantTensor


 class gpfq_mode(gpxq_mode):
@@ -319,11 +319,11 @@ def process_input(self, inp):

         is_quant_enabled = self.layer.weight_quant.is_quant_enabled

-        # If using quantized activations, inp could be QuantTensor. In
+        # If using quantized activations, inp could be IntQuantTensor. In
         # this case, we overwrite the metadata.
-        if isinstance(inp, QuantTensor):
+        if isinstance(inp, IntQuantTensor):
             if is_quant_enabled and self.quant_input is None:
-                self.quant_input = QuantTensor(
+                self.quant_input = IntQuantTensor(
                     value=torch.empty(
                         1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
                     scale=inp.scale,
@@ -339,7 +339,7 @@ def single_layer_update(self):
         # raise error in case no quant-input is here
         if self.quant_input is None:
             raise ValueError('Expected self.quant_input to calculate L1-norm upper bound, but recevied None. ' + \
-                'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
+                'Make sure that either the input to the model is a IntQuantTensor or the layer has an input quant enabled. ' \
                 'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
                 'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
         weight = self.layer.weight.data
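The process_input hunk keeps only the quantization metadata of an incoming IntQuantTensor, pairing it with a one-element placeholder value on the layer's dtype and device. A hedged sketch of that pattern outside gpfq (the helper name is made up; the constructor call mirrors the hunk):

    import torch
    from brevitas.quant_tensor import IntQuantTensor

    def capture_input_metadata(inp, weight):
        # Hypothetical helper: retain scale/zero-point/bit-width of the input
        # while replacing the payload with a 1-element placeholder tensor.
        if not isinstance(inp, IntQuantTensor):
            return None
        return IntQuantTensor(
            value=torch.empty(1, dtype=weight.dtype, device=weight.device),
            scale=inp.scale,
            zero_point=inp.zero_point,
            bit_width=inp.bit_width,
            signed=inp.signed,
            training=inp.training)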
14 changes: 8 additions & 6 deletions src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
 from brevitas.quant_tensor import _unpack_quant_tensor
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.torch_utils import compute_channel_view_shape
@@ -74,8 +75,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         # Hack to recognize a QuantTensor that has decayed to a tuple
         # when used as input to tracing (e.g. during ONNX export)
         if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
-                len(inp) == len(QuantTensor._fields) and all([isinstance(t, Tensor) for t in inp])):
-            inp = QuantTensor(*inp)
+                len(inp) == len(IntQuantTensor._fields) and
+                all([isinstance(t, Tensor) for t in inp])):
+            inp = IntQuantTensor(*inp)
         if not torch._C._get_tracing_state():
             if isinstance(inp, QuantTensor):
                 inp = inp.set(value=inp.value.rename(None))
@@ -186,7 +188,7 @@ def pack_quant_outputs(self, quant_outputs):
         # inner layers in a deep network overrides it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and self.io_quant.is_quant_enabled:
-                return QuantTensor(
+                return IntQuantTensor(
                     quant_outputs,
                     self.io_quant.scale(),
                     self.io_quant.zero_point(),
@@ -198,7 +200,7 @@ def pack_quant_outputs(self, quant_outputs):
         seq_dim = 1 if self.cell.batch_first else 0
         if self.return_quant_tensor and self.io_quant.is_quant_enabled:
             outputs = [
-                QuantTensor(
+                IntQuantTensor(
                     torch.unsqueeze(quant_output[0], dim=seq_dim),
                     quant_output[1],
                     quant_output[2],
@@ -217,7 +219,7 @@ def pack_quant_state(self, quant_state, quant):
         # inner layers in a deep network overrides it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and quant.is_quant_enabled:
-                quant_state = QuantTensor(
+                quant_state = IntQuantTensor(
                     torch.unsqueeze(quant_state, dim=0),
                     quant.scale(),
                     quant.zero_point(),
@@ -228,7 +230,7 @@ def pack_quant_state(self, quant_state, quant):
             quant_state = torch.unsqueeze(quant_state, dim=0)
         else:
             if self.return_quant_tensor and quant.is_quant_enabled:
-                quant_state = QuantTensor(
+                quant_state = IntQuantTensor(
                     torch.unsqueeze(quant_state[0], dim=0),
                     quant_state[1],
                     quant_state[2],
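The unpack_input hack above exists because tracing (e.g. during ONNX export) can flatten a quant tensor into a plain tuple of Tensors; the tuple is rebuilt when its arity matches IntQuantTensor._fields. A standalone sketch with a NamedTuple stand-in (the six field names are inferred from the constructor calls in this commit):

    import torch
    from typing import NamedTuple

    class IntQuantTensor(NamedTuple):
        # Stand-in; field names inferred from the constructor calls above.
        value: torch.Tensor
        scale: torch.Tensor
        zero_point: torch.Tensor
        bit_width: torch.Tensor
        signed: torch.Tensor
        training: torch.Tensor

    def unpack(inp):
        # Rebuild a quant tensor that tracing decayed into a flat Tensor tuple.
        if (isinstance(inp, tuple) and len(inp) == len(IntQuantTensor._fields) and
                all(isinstance(t, torch.Tensor) for t in inp)):
            return IntQuantTensor(*inp)
        return inp

    decayed = tuple(torch.zeros(1) for _ in IntQuantTensor._fields)
    print(type(unpack(decayed)).__name__)  # IntQuantTensor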
30 changes: 17 additions & 13 deletions src/brevitas/proxy/parameter_quant.py
@@ -15,6 +15,7 @@
 from brevitas import config
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO

@@ -103,11 +104,11 @@ def bit_width(self):
         bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
         return bit_width

-    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width = impl(x)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -128,11 +129,11 @@ def pre_zero_point(self):
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_zero_point

-    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -157,11 +158,12 @@ def pre_zero_point(self):
         raise NotImplementedError

     def forward(
-            self,
-            x: torch.Tensor,
-            quant_input: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
+            self,
+            x: torch.Tensor,
+            quant_input: Optional[Union[Tensor,
+                                        IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
         if isinstance(quant_input,
-                      QuantTensor) and not self.training and self.cache_inference_quant_act:
+                      IntQuantTensor) and not self.training and self.cache_inference_quant_act:
             cached_inp = _CachedIO(quant_input.detach(), self.cache_quant_io_metadata_only)
             self._cached_act = cached_inp

@@ -170,14 +172,14 @@ def forward(
                 assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
                 quant_input = self._cached_act
             else:
-                assert isinstance(quant_input, QuantTensor), "Input must be quantized"
+                assert isinstance(quant_input, IntQuantTensor), "Input must be quantized"

             input_bit_width = quant_input.bit_width
             input_is_signed = quant_input.signed

             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -236,7 +238,7 @@ def bit_width(self):

     def forward(self,
                 x: Tensor,
-                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+                input_scale: Optional[Tensor] = None) -> Union[Tensor, IntQuantTensor]:
         out = x
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
@@ -251,10 +253,12 @@ def forward(self,
             else:
                 out, out_scale, out_zp, out_bit_width = impl(x)

-            out = QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
+            out = IntQuantTensor(
+                out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         else:
             out = x
-        if isinstance(out, QuantTensor) and not self.training and self.cache_inference_quant_bias:
+        if isinstance(out,
+                      IntQuantTensor) and not self.training and self.cache_inference_quant_bias:
             cached_bias = _CachedIO(out.detach(), metadata_only=False)
             self._cached_bias = cached_bias
         return out
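Every parameter proxy in this file now follows the same contract: with quantization enabled, forward returns an IntQuantTensor built from the quantizer's output tuple; with it disabled, the raw Tensor passes through untouched. A minimal sketch of that contract, with a fake quantizer standing in for self.tensor_quant:

    import torch
    from brevitas.quant_tensor import IntQuantTensor

    def proxy_forward(x, tensor_quant, is_quant_enabled, signed=True, training=False):
        # Shared shape of the forward methods above (proxy internals omitted).
        if is_quant_enabled:
            out, scale, zero_point, bit_width = tensor_quant(x)
            return IntQuantTensor(out, scale, zero_point, bit_width, signed, training)
        return x  # quantization disabled

    # Fake 8-bit quantizer standing in for self.tensor_quant:
    fake_quant = lambda t: (t, torch.tensor(0.1), torch.tensor(0.), torch.tensor(8.))
    print(type(proxy_forward(torch.randn(3), fake_quant, True)).__name__)   # IntQuantTensor
    print(type(proxy_forward(torch.randn(3), fake_quant, False)).__name__)  # Tensor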
26 changes: 14 additions & 12 deletions src/brevitas/proxy/runtime_quant.py
@@ -10,6 +10,7 @@
 from typing_extensions import runtime_checkable

 import brevitas
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO

@@ -166,11 +167,11 @@ def bit_width(self, force_eval=True):
         elif self._cached_act is None:
             return None

-    def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]:
         out = x
         if self.fused_activation_quant_proxy is not None:
             y = x
-            if isinstance(y, QuantTensor):
+            if isinstance(y, IntQuantTensor):
                 y = y.value

             if self.export_mode:
@@ -180,15 +181,15 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
                 y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
-            # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
+            # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
             if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
-                out = QuantTensor(*y, signed=self.is_signed, training=self.training)
+                out = IntQuantTensor(*y, signed=self.is_signed, training=self.training)
             elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
                 if isinstance(y, tuple):
                     y = y[0]
-                if isinstance(x, QuantTensor):
-                    out = QuantTensor(
+                if isinstance(x, IntQuantTensor):
+                    out = IntQuantTensor(
                         y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
                 else:
                     out = y
@@ -199,7 +200,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             else:
                 # If fused activation quant proxy is not enabled, return the input
                 out = x
-        if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
+        if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor):
             cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only)
             self._cached_act = cached_out
         return out
@@ -216,11 +217,11 @@ def zero_point(self, force_eval=True):

 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

-    def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
-            return QuantTensor(
+            return IntQuantTensor(
                 out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         return x

@@ -232,19 +233,20 @@ def bit_width(self):
             return None
         zhs = self._zero_hw_sentinel()
         # Signed might or might not be defined. We just care about retrieving the bitwidth
-        empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
+        empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
         bit_width = self.__call__(empty_imp).bit_width
         return bit_width

-    def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
+    def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
             if self.export_mode:
                 out_tuple = self.export_handler(
                     x.value, x.scale, x.zero_point, x.bit_width, x.signed)
             else:
                 out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
-            return QuantTensor(out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
+            return IntQuantTensor(
+                out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
         else:
             return x
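The passthrough branch is the subtle case in this file: an activation proxy with no quantizer of its own still re-attaches the input's scale, zero-point, bit-width and sign to its output, so metadata survives shape-preserving ops. A sketch of just that branch (the function is hypothetical; the field access mirrors the hunk):

    import torch
    from brevitas.quant_tensor import IntQuantTensor

    def passthrough(y, x, training=False):
        # y: raw activation output; x: the proxy's original input.
        # Re-attach the input's quantization metadata when it is available.
        if isinstance(x, IntQuantTensor):
            return IntQuantTensor(
                y, x.scale, x.zero_point, x.bit_width, x.signed, training)
        return y  # unquantized input: nothing to preserve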
4 changes: 2 additions & 2 deletions src/brevitas/quant_tensor/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+from .base_quant_tensor import *
 from .base_quant_tensor import _unpack_quant_tensor
-from .base_quant_tensor import QuantTensorBase
-from .int_quant_tensor import QuantTensor
+from .int_quant_tensor import *
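After this re-export change, both names still come from brevitas.quant_tensor, and the isinstance checks in the hunks above only work if the integer class is (or behaves as) a QuantTensor. A quick check of that assumed relationship:

    from brevitas.quant_tensor import IntQuantTensor, QuantTensor

    # Assumed subclass relationship, consistent with the isinstance checks in
    # the hunks above (e.g. unpack_input builds an IntQuantTensor and then
    # tests it against QuantTensor).
    assert issubclass(IntQuantTensor, QuantTensor)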