Fixing a2q pytest cases
i-colbert committed Apr 4, 2023
1 parent d4e2e46 commit 3898a65
Showing 3 changed files with 13 additions and 51 deletions.
1 change: 1 addition & 0 deletions src/brevitas/nn/utils.py
@@ -106,6 +106,7 @@ def calculate_min_accumulator_bit_width(
     Example (weight-level bound):
     >> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm)
     """
+    input_is_signed = float(input_is_signed)
     # if the l1-norm of the weights is specified, then use the weight-level bound
     if weight_max_l1_norm is not None:
         assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance."
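
For context on the one-line fix above: input_is_signed can arrive as a Python bool or a bool Tensor, and the cast to float lets it participate cleanly in the floating-point arithmetic of the bound. Below is a minimal standalone sketch of the weight-level bound this helper evaluates, following the A2Q formulation; the function is illustrative only, not the Brevitas implementation.

# Illustrative sketch of the weight-level accumulator bound (A2Q formulation).
# Names and exact details here are assumptions, not the Brevitas code itself.
import torch

def min_accumulator_bit_width_sketch(input_bit_width: float,
                                     input_is_signed: bool,
                                     weight_max_l1_norm: torch.Tensor) -> torch.Tensor:
    input_is_signed = float(input_is_signed)  # the cast added in this commit
    # alpha = log2(||w||_1) + N - is_signed, phi(x) = log2(1 + 2^-x)
    alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
    phi = torch.log2(1.0 + 2.0 ** (-alpha))
    return torch.ceil(alpha + phi + 1.0)

# 8-bit signed inputs with an integer weight l1-norm of 1 need a 9-bit accumulator
print(min_accumulator_bit_width_sketch(8.0, True, torch.tensor(1.0)))  # tensor(9.)
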
27 changes: 4 additions & 23 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -205,53 +205,34 @@ def case_model(
         is_training)


-@pytest_cases.parametrize(
-    'input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
-@pytest_cases.parametrize(
-    'bias_quantizer',
-    BIAS_QUANTIZER.items(),
-    ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
 @pytest_cases.parametrize(
     'io_quantizer',
     WBIOL_IO_QUANTIZER.items(),
     ids=[f'io_quant${c}' for c, _ in WBIOL_IO_QUANTIZER.items()])
-@pytest_cases.parametrize(
-    'return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
 @pytest_cases.parametrize(
     'module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
-@pytest_cases.parametrize(
-    'is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
 @pytest_cases.parametrize(
     'accumulator_bit_width',
     ACC_BIT_WIDTHS,
     ids=[f'accumulator_bit_width${bw}' for bw in ACC_BIT_WIDTHS])
 def case_model_a2q(
-        bias_quantizer,
         io_quantizer,
-        return_quant_tensor,
         module,
         request,
-        input_quantized,
-        is_training,
         accumulator_bit_width):
     set_case_id(request.node.callspec.id, case_model_a2q)
     case_id = get_case_id(case_model_a2q)
     # forcing test to only use accumulator-aware weight quantizer
     weight_quantizer = ('quant_a2q', Int8AccumulatorAwareWeightQuant)
-    io_name, _ = io_quantizer
-    if io_name == 'batch_quant':
-        pytest.skip(
-            "Skipping batch_quant tests with A2Q."
-        )
     return build_case_model(
         weight_quantizer,
-        bias_quantizer,
+        ("None", None), # bias_quantizer = None
         io_quantizer,
-        return_quant_tensor,
+        True, # return_quant_tensor = True
         module,
         case_id,
-        input_quantized,
-        is_training,
+        True, # input_quantized = True
+        True, # is_training = True
         accumulator_bit_width=accumulator_bit_width)


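
The case above now pins everything except the io quantizer, the layer type, and the target accumulator bit width. For readers unfamiliar with what these fixtures build, a rough usage sketch of an accumulator-aware layer follows; the import paths and the weight_accumulator_bit_width keyword are assumed from Brevitas' usual quantizer-override convention and are not shown by this commit, so check the fixture's own imports for the exact names.

# Rough sketch of an A2Q-constrained layer, approximating what these fixtures construct.
# Import paths and the weight_accumulator_bit_width override are assumptions.
import torch
from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat

layer = QuantConv2d(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    weight_quant=Int8AccumulatorAwareWeightQuant,
    weight_accumulator_bit_width=16,  # assumed keyword for the quantizer override
    input_quant=Int8ActPerTensorFloat,  # A2Q requires a quantized input
    return_quant_tensor=True)

out = layer(torch.randn(1, 3, 32, 32))
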
36 changes: 8 additions & 28 deletions tests/brevitas/nn/test_a2q.py
@@ -27,8 +27,9 @@ def parse_args(args):

 @pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
 def test_quant_wbiol_a2q(model_input, current_cases):
-    """This test verifies that the accumulator-aware weight quantization constraints the l1-norm of
-    the weights enough use the user-specified accumulator bit-width."""
+    """This test only verifies that the accumulator-aware weight quantization constrains the l1-norm
+    of the weights enough to use the user-specified accumulator bit-width. Baseline functionality is
+    covered in test_nn_quantizers."""
     model, input = model_input

     cases_generator_func = current_cases['model_input'][1]
@@ -38,33 +39,19 @@ def test_quant_wbiol_a2q(model_input, current_cases):

     # A2Q needs to have a quantized input, which can be done by input quantizer or returning
     # a quantized tensor from the preceding layer
-    is_input_quant_tensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
+    is_input_quant_tensor = kwargs['io_quant'] is not None or isinstance(input, QuantTensor)
     assert is_input_quant_tensor, "All A2Q models require quantized inputs."

     # testing the forward pass
     output = model(input)

     # bit-width and sign need to come from the quant tensor of the preceding layer if no io_quant
-    if kwargs['io_quant'] is None:
-        input_bit_width = input.bit_width
-        input_is_signed = input.signed
-    # else bit-width and sign come from the io_quant
-    else:
-        input_bit_width = model.conv.quant_input_bit_width()
-        input_is_signed = model.conv.is_quant_input_signed
+    quant_input = model.conv.input_quant(input)
+    input_bit_width = quant_input.bit_width
+    input_is_signed = quant_input.signed

-    # if the input is not quantized already, the wrap it in a QuantTensor
-    if not kwargs['input_quantized']:
-        input = QuantTensor(
-            input,
-            None, # not used by weight_quant
-            None, # not used by weight_quant
-            input_bit_width,
-            input_is_signed,
-            None # note used by weight_quant
-        )
     # the tensor quantizer requires a QuantTensor with specified bit-width and sign
-    quant_weight = model.conv.quant_weight(input)
+    quant_weight = model.conv.quant_weight(quant_input)
     quant_weight = quant_weight.int().float()
     if kwargs['model_type'] == 'QuantLinear': # shape = (out_features, in_features)
         quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=1)
@@ -89,10 +76,3 @@ def test_quant_wbiol_a2q(model_input, current_cases):
     assert cur_acc_bit_width <= exp_acc_bit_width, \
         f"Model does not satisfy accumulator bit-width bounds. Expected {exp_acc_bit_width}, got {cur_acc_bit_width}"

-    # might as well also check again if the output is a quant tensor if it is specified to be
-    if kwargs['return_quant_tensor']:
-        assert isinstance(output, QuantTensor)
-        assert output.scale is not None
-        assert output.bit_width is not None
-    else:
-        assert isinstance(output, torch.Tensor)
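
The assertion kept above compares the bound computed from the integer weights against the target accumulator width; the per-channel l1-norm feeding that bound reduces over every dimension except the output channel. A small self-contained sketch of that reduction (the helper name is made up for illustration):

# Sketch of the per-output-channel l1-norm used by the bound check above.
# Linear weights have shape (out_features, in_features); conv weights have
# shape (out_channels, in_channels, *kernel), so everything past dim 0 is flattened.
import torch

def per_channel_l1_norm(int_weight: torch.Tensor) -> torch.Tensor:
    return int_weight.flatten(start_dim=1).norm(p=1, dim=1)

w = torch.randint(-8, 8, (4, 3, 3, 3)).float()  # e.g. an integer-valued 4-channel conv kernel
print(per_channel_l1_norm(w).shape)  # torch.Size([4])
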
