Feat (proxy): scale computation delegated to bias proxy (#938)
Giuseppe5 authored Apr 19, 2024
1 parent e1d5bbe commit 670420f
Showing 3 changed files with 34 additions and 23 deletions.
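
In short: before this change, each quantized layer (and each recurrent gate) computed the bias scale itself, as the product of the input scale and the weight scale broadcast along the channel dimension, and passed it to the bias proxy as input_scale. After this change, the layer hands the quantized input and the quantized weight to the proxy, and the bias proxy (BiasQuantProxyFromInjector) derives the scale internally through the new compute_bias_scale method. The arithmetic is unchanged; the snippet below is only an illustration of the product-of-scales computation that moved, using stand-in values, not code from the commit:

import torch

# Stand-in values for a linear layer with 4 output features:
# a per-tensor input scale and a per-output-channel weight scale.
input_scale = torch.tensor(0.02)
weight_scale = torch.tensor([0.010, 0.012, 0.008, 0.011])

# The bias scale is their product, broadcast so that it lines up with the
# channel dimension of the accumulator (the last dimension for nn.Linear).
bias_scale = weight_scale.view(1, -1) * input_scale.view(1, -1)
print(bias_scale.shape)  # torch.Size([1, 4])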
7 changes: 1 addition & 6 deletions src/brevitas/nn/mixin/base.py
@@ -147,14 +147,9 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
 
     @staticmethod
     def gate_params_fwd(gate, quant_input):
-        acc_scale = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
-            acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
-            acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
-            acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale)
+        quant_bias = gate.bias_quant(gate.bias, quant_input, quant_weight_ih)
         return quant_weight_ih, quant_weight_hh, quant_bias
 
     def reset_parameters(self) -> None:
15 changes: 1 addition & 14 deletions src/brevitas/nn/quant_layer.py
@@ -118,14 +118,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
     def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Tensor):
         pass
 
-    def quant_output_scale_impl(
-            self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
-        channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
-        output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
-        output_scale = quant_weight_scale.view(output_scale_shape)
-        output_scale = output_scale * quant_input_scale.view(output_scale_shape)
-        return output_scale
-
     @property
     def requires_export_handler(self):
         return (
@@ -150,7 +142,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             return out
 
         quant_input = self.input_quant(inp)
-
         quant_weight = self.quant_weight(quant_input)
 
         compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
@@ -159,12 +150,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
                 self.output_quant.is_quant_enabled) and self.return_quant_tensor:
             raise RuntimeError("QuantLayer is not correctly configured")
 
-        output_scale = None
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
-            output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
-
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale)
+            quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
         else:
             quant_bias = None
         output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)
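
From the user's side nothing changes: a bias quantizer that needs an external scale (such as Int32Bias) still gets one, it is just derived inside the proxy now. A minimal usage sketch, assuming the stock Brevitas quantizers Int8ActPerTensorFloat and Int32Bias; it exercises the new code path but is not part of the commit:

import torch
from brevitas.nn import QuantLinear
from brevitas.quant import Int8ActPerTensorFloat, Int32Bias

# Int32Bias requires an external scale; after this commit the bias proxy
# derives it from the quantized input and quantized weight on its own.
layer = QuantLinear(
    in_features=8,
    out_features=4,
    bias=True,
    input_quant=Int8ActPerTensorFloat,
    bias_quant=Int32Bias,
    return_quant_tensor=True)

out = layer(torch.randn(2, 8))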
35 changes: 32 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
Expand Up @@ -17,6 +17,7 @@
from brevitas.inject import BaseInjector as Injector
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO
from brevitas.utils.torch_utils import compute_channel_view_shape

from .quant_proxy import QuantProxyFromInjector
from .quant_proxy import QuantProxyProtocol
@@ -234,10 +235,38 @@ def bit_width(self):
         bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width
 
-    def forward(self,
-                x: Tensor,
-                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def quant_output_scale_impl(
+            self, input: QuantTensor, weight: QuantTensor, module: torch.nn.Module) -> Tensor:
+        channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
+        output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
+        output_scale = weight.scale.view(output_scale_shape)
+        output_scale = output_scale * input.scale.view(output_scale_shape)
+        return output_scale
+
+    def compute_bias_scale(
+            self,
+            input: Optional[Union[Tensor, QuantTensor]],
+            weight: Optional[Union[Tensor, QuantTensor]]) -> Optional[Tensor]:
+        if not self.requires_input_scale:
+            return None
+        if not isinstance(input, QuantTensor) or not isinstance(weight, QuantTensor):
+            return None
+        if len(self.tracked_module_list) > 1:
+            if not all(
+                    [type(self.tracked_module_list[0]) == type(x) for x in self.tracked_module_list]):
+                raise RuntimeError(
+                    "Bias quantizer shared across different types of layers with an external scale is not supported."
+                )
+        scale = self.quant_output_scale_impl(input, weight, self.tracked_module_list[0])
+        return scale
+
+    def forward(
+            self,
+            x: Tensor,
+            input: Optional[Union[Tensor, QuantTensor]] = None,
+            weight: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
         out = x
+        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
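
The broadcast shape used by the new quant_output_scale_impl depends on the layer type: the channel dimension is the last one for torch.nn.Linear and dimension 1 otherwise (e.g. convolutions). The helper below is a simplified stand-in for compute_channel_view_shape, shown only to make that shape selection concrete; it is not the Brevitas implementation:

import torch

def view_shape(ndim: int, channel_dim: int) -> tuple:
    # Shape of all ones with -1 at the channel dimension, so that a per-channel
    # scale can be .view()-ed into it and broadcast against the accumulator.
    shape = [1] * ndim
    shape[channel_dim] = -1
    return tuple(shape)

lin_input = torch.randn(2, 8)            # (batch, features) for nn.Linear
conv_input = torch.randn(2, 3, 16, 16)   # (batch, channels, H, W) for nn.Conv2d

print(view_shape(lin_input.dim(), -1))   # (1, -1): scale broadcast along the last dim
print(view_shape(conv_input.dim(), 1))   # (1, -1, 1, 1): scale broadcast along dim 1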
