Skip to content

Commit

Permalink
Fix for output quant metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 4, 2024
1 parent a32b72a commit 7e16d9c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ def merge_bn_in(self, bn):
merge_bn(self, bn, output_channel_dim=self.output_channel_dim)

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
output_scale = None
output_bit_width = None
output_signed = None
output_zero_point = None

inp = self.unpack_input(inp)

# shortcut execution through the export impl during export
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/nn/quant_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def forward(self, inp, state):
quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd(
self.gate_params, quant_input)
quant_input_value = _unpack_quant_tensor(quant_input)
if getattr(quant_bias, 'value', quant_bias) is None:
if quant_bias is None:
quant_bias = torch.tensor(0., device=quant_input_value.device)
else:
quant_bias = _unpack_quant_tensor(quant_bias)
Expand Down

0 comments on commit 7e16d9c

Please sign in to comment.