Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 22, 2024
1 parent acbaa6a commit 6c6e0c4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 2 additions & 3 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,9 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)

if not self.is_output_quant_enabled:
if not self.is_output_quant_enabled and self.return_quant_tensor:
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
if (quant_input.zero_point != 0.0
).any() or (quant_weight.zero_point != 0.0).any() and self.return_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
"Computing zero point of output accumulator not supported yet.")
elif output_zero_point is None:
Expand Down
8 changes: 7 additions & 1 deletion tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ def test_quant_mha(model_input, current_cases):
with pytest.raises(RuntimeError, match='Input scale required'):
output, _ = model(inp, inp, inp)
return

elif kwargs['weight_quant'] is not None and kwargs['io_quant'] is None:
if kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor']:
with pytest.raises(
RuntimeError,
match='Computing zero point of output accumulator not supported yet.'):
output, _ = model(inp, inp, inp)
return
output, _ = model(inp, inp, inp)

if kwargs['return_quant_tensor']:
Expand Down

0 comments on commit 6c6e0c4

Please sign in to comment.