Feat (proxy): scale computation delegated to bias proxy #938

Merged · 5 commits · merged on Apr 19, 2024
7 changes: 1 addition & 6 deletions src/brevitas/nn/mixin/base.py
@@ -147,14 +147,9 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):

     @staticmethod
     def gate_params_fwd(gate, quant_input):
-        acc_scale = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
-            acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
-            acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
-            acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale)
+        quant_bias = gate.bias_quant(gate.bias, quant_input, quant_weight_ih)
         return quant_weight_ih, quant_weight_hh, quant_bias

     def reset_parameters(self) -> None:
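The motivation for passing the quantized input and weight to the bias proxy, rather than a precomputed acc_scale: for an affine op the accumulator lives on the grid given by the product of the input and weight scales, and the bias has to be quantized on that same grid. A minimal self-contained sketch of that relationship (illustrative only, not Brevitas code; the function name and the symmetric rounding scheme are assumptions):

import torch

def quantize_bias(bias: torch.Tensor,
                  input_scale: torch.Tensor,
                  weight_scale: torch.Tensor,
                  bit_width: int = 32) -> torch.Tensor:
    # Accumulator scale: product of input and weight scales (scalar or per channel).
    acc_scale = input_scale * weight_scale
    n = 2 ** (bit_width - 1)
    # Round the bias onto the accumulator grid, clamp to the signed range, dequantize.
    bias_int = torch.clamp(torch.round(bias / acc_scale), -n, n - 1)
    return bias_int * acc_scale

This acc_scale product is the quantity the proxy now derives internally (see the compute_bias_scale method added in parameter_quant.py below), so callers such as gate_params_fwd no longer have to compute it themselves.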
15 changes: 1 addition & 14 deletions src/brevitas/nn/quant_layer.py
@@ -118,14 +118,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
     def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Tensor):
         pass

-    def quant_output_scale_impl(
-            self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
-        channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
-        output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
-        output_scale = quant_weight_scale.view(output_scale_shape)
-        output_scale = output_scale * quant_input_scale.view(output_scale_shape)
-        return output_scale

     @property
     def requires_export_handler(self):
         return (
@@ -150,7 +142,6 @@ def forward_impl(inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             return out

         quant_input = self.input_quant(inp)
-
         quant_weight = self.quant_weight(quant_input)

         compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
@@ -159,12 +150,8 @@ def forward_impl(inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
                 self.output_quant.is_quant_enabled) and self.return_quant_tensor:
             raise RuntimeError("QuantLayer is not correctly configured")

-        output_scale = None
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
-            output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
-
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale)
+            quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
         else:
             quant_bias = None
         output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)
35 changes: 32 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
@@ -17,6 +17,7 @@
 from brevitas.inject import BaseInjector as Injector
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
+from brevitas.utils.torch_utils import compute_channel_view_shape

 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
@@ -234,10 +235,38 @@ def bit_width(self):
         bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width

-    def forward(self,
-                x: Tensor,
-                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def quant_output_scale_impl(
+            self, input: QuantTensor, weight: QuantTensor, module: torch.nn.Module) -> Tensor:
+        channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
Collaborator commented on the channel_dim line:

how come it's -1 if Linear and 1 otherwise?

Collaborator (Author) replied:

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
For a Linear layer the channel dimension is always the last one, whereas for Conv, ConvTranspose, etc. it is always at dim 1, at least for the input tensor. (The sketch after this method illustrates the resulting broadcast shapes.)

+        output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
+        output_scale = weight.scale.view(output_scale_shape)
+        output_scale = output_scale * input.scale.view(output_scale_shape)
+        return output_scale
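To make the reviewer's point above concrete: viewing a per-channel scale with the channel dimension kept free yields a shape that broadcasts against Linear inputs (features on the last dim) as well as Conv inputs (channels at dim 1). A small sketch (illustrative; channel_view_shape is a stand-in written for this example and only approximates what compute_channel_view_shape is expected to return):

import torch

def channel_view_shape(t: torch.Tensor, channel_dim: int) -> tuple:
    # Same number of dims as `t`, all ones except -1 at the channel dim,
    # so a per-channel scale can be .view()-ed into a broadcastable shape.
    shape = [1] * t.dim()
    shape[channel_dim] = -1
    return tuple(shape)

weight_scale = torch.rand(64)          # per-output-channel scale
lin_in = torch.randn(32, 128)          # Linear input: features on the last dim
conv_in = torch.randn(32, 3, 8, 8)     # Conv2d input: channels at dim 1
print(weight_scale.view(channel_view_shape(lin_in, -1)).shape)   # torch.Size([1, 64])
print(weight_scale.view(channel_view_shape(conv_in, 1)).shape)   # torch.Size([1, 64, 1, 1])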

+    def compute_bias_scale(
+            self,
+            input: Optional[Union[Tensor, QuantTensor]],
+            weight: Optional[Union[Tensor, QuantTensor]]) -> Optional[Tensor]:
+        if not self.requires_input_scale:
+            return None
+        if not isinstance(input, QuantTensor) or not isinstance(weight, QuantTensor):
+            return None
+        if len(self.tracked_module_list) > 1:
+            if not all(
+                [type(self.tracked_module_list[0]) == type(x) for x in self.tracked_module_list]):
+                raise RuntimeError(
+                    "Bias quantizer shared across different type of layers with external scale is not supported."
+                )
+        scale = self.quant_output_scale_impl(input, weight, self.tracked_module_list[0])
+        return scale

+    def forward(
+            self,
+            x: Tensor,
+            input: Optional[Union[Tensor, QuantTensor]] = None,
+            weight: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
         out = x
+        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
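From a user's perspective nothing changes in how layers are built; the change is internal to the layer/proxy hand-off. A hedged usage sketch (the quantizer choices are illustrative; the point is that the layer's forward now passes quant_input and quant_weight to bias_quant and the proxy derives the bias scale itself):

import torch
from brevitas.nn import QuantLinear
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int8Bias

# Int8Bias requires an input scale; after this PR the bias proxy computes it from
# the quantized input and weight instead of receiving a precomputed output_scale.
layer = QuantLinear(
    16, 8, bias=True,
    input_quant=Int8ActPerTensorFloat,
    bias_quant=Int8Bias,
    return_quant_tensor=True)

out = layer(torch.randn(4, 16))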