From 78d44f861b2ef23ceb39d214658b1dd832d9cbd1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 31 Mar 2023 12:17:36 +0000 Subject: [PATCH 1/8] Feat (graph_eq): Add support for LayerNorm and MultiheadAttention --- src/brevitas/graph/equalize.py | 109 +++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 39 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index cd723215f..b9a94dd37 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause + +from collections import namedtuple from copy import deepcopy from functools import partial import operator @@ -23,10 +25,12 @@ nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d, + nn.MultiheadAttention, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, + nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) @@ -51,8 +55,11 @@ operator.mul, operator.imul, operator.__mul__, - operator.__imul__, -) + operator.__imul__) + +_select_op = ( + operator.getitem, + operator.__getitem__) _residual_methods = ('add', 'add_') @@ -61,8 +68,11 @@ _batch_norm = ( nn.BatchNorm1d, nn.BatchNorm2d, - nn.BatchNorm3d, -) + nn.BatchNorm3d) + + +WeightBiasTuple = namedtuple('WeightBiasTuple', ['weight', 'bias'], defaults=[None]) + def _select_scale_computation_fn( @@ -103,8 +113,8 @@ def _get_size(axes: Dict[nn.Module, int]) -> int: return size -def _get_input_axis(module: nn.Module) -> int: - if isinstance(module, nn.Linear): +def _get_input_axis(module: nn.Module) -> Optional[int]: + if isinstance(module, (nn.Linear, nn.MultiheadAttention)): return 1 elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): return 0 @@ -118,26 +128,32 @@ def _get_input_axis(module: nn.Module) -> int: return 0 elif module.groups == module.out_channels: return 1 + elif isinstance(module, nn.LayerNorm): + # We assume normalization happens only along the channel dimension + if len(module.weight.shape) == 1: + return 0 + else: + return None else: return None -def _get_output_axis(module: nn.Module) -> int: - if isinstance( - module, - (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, - nn.BatchNorm3d)): +def _get_output_axis(module: nn.Module) -> Optional[int]: + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.MultiheadAttention, + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): return 0 elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): return 1 + elif isinstance(module, nn.LayerNorm): + # We assume normalization happens only along the channel dimension + if len(module.weight.shape) == 1: + return 0 + else: + return None else: return None - -def _combine_weights_bias( - weight: nn.parameter.Parameter, - bias_shrinkage: Union[float, str], - bias: Optional[nn.parameter.Parameter]): +def _combine_weights_bias(weight: torch.Tensor, bias_shrinkage: Union[float, str], bias: Optional[torch.Tensor]) -> torch.Tensor: """Combine weights and bias before graph equalizattion This method merges the weight and bias of the sources, so that the resulting equalizer scale factor is influenced also by the magnitude of the bias, mitigated by a shrink factor. @@ -193,15 +209,28 @@ def _cross_layer_equalization( device = next(srcs[0].parameters()).device dtype = next(srcs[0].parameters()).dtype - for module_set in [srcs, sinks]: - for module in module_set: - if not isinstance(module, _supported_layers): - return torch.tensor( - 1., dtype=dtype, - device=device) # If module is not supported, do not perform graph equalization - - src_axes = {m: _get_output_axis(m) for m in srcs} - sink_axes = {m: _get_input_axis(m) for m in sinks} + src_axes = {} + sink_axes = {} + + for i, module in enumerate(srcs): + # If module is not supported, do not perform graph equalization + if not isinstance(module, _supported_layers): + return torch.tensor(1., dtype=dtype, device=device) + if isinstance(module, nn.MultiheadAttention): + srcs[i] = module.out_proj + src_axes[srcs[i]] = _get_output_axis(module) + + for i, module in enumerate(sinks): + # If module is not supported, do not perform graph equalization + if not isinstance(module, _supported_layers): + return torch.tensor(1., dtype=dtype, device=device) + # For MultiheadAttention, we support only self-attetion + if isinstance(module, nn.MultiheadAttention) and hasattr(sinks[i], 'in_proj_weight'): + # For sinks, we only need to modify the weight but not the bias + sinks[i] = WeightBiasTuple(module.in_proj_weight) + elif isinstance(module, nn.MultiheadAttention) and not hasattr(sinks[i], 'in_proj_weight'): + return torch.tensor(1., dtype=dtype, device=device) + sink_axes[sinks[i]] = _get_input_axis(module) # Check if any of the axis is None, which means that the module is not supported. # In that case, do not perform graph equalization @@ -216,7 +245,6 @@ def _cross_layer_equalization( # Similarly, exit if source and sink have different different sizes if None in [src_size, sink_size] or src_size != sink_size: return torch.tensor(1., dtype=dtype, device=device) - transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1) scale_fn = _select_scale_computation_fn(scale_computation_type) if merge_bias: @@ -288,9 +316,14 @@ def _equalize( def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: - return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), _supported_layers) - + if node.op == 'call_module': + module = get_module(graph_model, node.target) + if isinstance(module, _supported_layers): + # We support only self-attention + if isinstance(module, nn.MultiheadAttention): + return all([node.all_input_nodes[0].name == n.name for n in node.all_input_nodes]) + return True + return False def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: return node.op == 'call_module' and isinstance( @@ -298,7 +331,7 @@ def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: def _is_scale_invariant_function(node: Node) -> bool: - return node.op == 'call_function' and node.target in _scale_invariant_op + return node.op == 'call_function' and node.target in _scale_invariant_op + _select_op def _is_reshaping_op(node: Node) -> bool: @@ -348,15 +381,13 @@ def walk_region( def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]: regions = set() for node in graph_model.graph.nodes: - if node.op == 'call_module': - module = get_module(graph_model, node.target) - if isinstance(module, _supported_layers): - srcs, sinks = {node.target}, set() - walk_region(graph_model, node, set(), srcs, sinks, walk_forward=True) - if sinks: - # each region should appear only once, so to make it hashable - # we convert srcs and sinks to ordered lists first, and then to tuples - regions.add((tuple(sorted(srcs)), tuple(sorted(sinks)))) + if _is_supported_module(graph_model, node): + srcs, sinks = {node.target}, set() + walk_region(graph_model, node, set(), srcs, sinks, walk_forward=True) + if sinks: + # each region should appear only once, so to make it hashable + # we convert srcs and sinks to ordered lists first, and then to tuples + regions.add((tuple(sorted(srcs)), tuple(sorted(sinks)))) # for clarity, sort by the of the first source regions = sorted(regions, key=lambda region: region[0][0]) return regions From fd807d8b8dd7c8ebc73035afae19a0ebefe0e8bd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 9 Mar 2023 15:27:48 +0000 Subject: [PATCH 2/8] Test (graph_eq): Add test for MultiheadAttention, LayerNorm and ViT --- tests/brevitas/graph/equalization_fixtures.py | 70 ++++++++++++++++--- tests/brevitas/graph/test_equalization.py | 23 +++--- 2 files changed, 76 insertions(+), 17 deletions(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 8da0cef0b..88f8505ef 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -1,12 +1,17 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from packaging import version +import pytest import pytest_cases from pytest_cases import fixture_union import torch import torch.nn as nn +from brevitas import torch_version + MODELS = { + 'vit_b_32': [0.777, 0.793], 'shufflenet_v2_x0_5': [0.8141, 0.8230], 'mobilenet_v2': [0.6571, 0.6571], 'resnet18': [0.9756, 0.9756], @@ -15,6 +20,9 @@ 'alexnet': [0.875, 0.875],} +IN_SIZE_CONV = (1, 3, 224, 224) +IN_SIZE_LINEAR = (1, 224, 3) + @pytest_cases.fixture def bnconv_model(): @@ -39,6 +47,54 @@ def forward(self, x): return BNConvModel +@pytest_cases.fixture +def linearmha_model(): + if torch_version < version.parse('1.9.1'): + pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class LinearMhaModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3,24) + self.mha = nn.MultiheadAttention(24,3,0.1, bias=True, add_bias_kv=True, batch_first=True) + def forward(self, x): + x = self.linear(x) + x, _ = self.mha(x, x, x) + return x + return LinearMhaModel + + +@pytest_cases.fixture +def layernormmha_model(): + if torch_version < version.parse('1.9.1'): + pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class LayerNormMhaModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layernorm = nn.LayerNorm(3) + self.mha = nn.MultiheadAttention(3,3,0.1, bias=True, add_bias_kv=True, batch_first=True) + def forward(self, x): + x = self.layernorm(x) + x, _ = self.mha(x, x, x) + return x + return LayerNormMhaModel + + +@pytest_cases.fixture +def mhalinear_model(): + if torch_version < version.parse('1.9.1'): + pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class MhaLinearModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mha = nn.MultiheadAttention(3,1,0.1, bias=True, add_bias_kv=True, batch_first=True) + self.linear = nn.Linear(3,6) + def forward(self, x): + x, _ = self.mha(x, x, x) + x = self.linear(x) + return x + return MhaLinearModel + + @pytest_cases.fixture def convdepthconv_model(): @@ -146,12 +202,8 @@ def forward(self, x): return ResidualSrcsAndSinkModel -toy_model = fixture_union( - 'toy_model', - [ - 'residual_model', - 'srcsinkconflict_model', - 'mul_model', - 'convbn_model', - 'bnconv_model', - 'convdepthconv_model']) +list_of_fixtures = ['residual_model', 'srcsinkconflict_model', 'mul_model', + 'convbn_model', 'bnconv_model', 'convdepthconv_model', + 'linearmha_model', 'mhalinear_model', 'layernormmha_model'] + +toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 976039a3b..084ae4925 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -14,7 +14,6 @@ from .equalization_fixtures import * SEED = 123456 -IN_SIZE = (1, 3, 224, 224) ATOL = 1e-3 @@ -26,9 +25,9 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): model, coverage = model_dict if model == 'googlenet' and torch_version == version.parse('1.8.1'): - pytest.skip( - 'Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ' - ) + pytest.skip('Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ') + if 'vit' in model and torch_version < version.parse('1.13'): + pytest.skip(f'ViT supported from torch version 1.13, current torch version is {torch_version}') try: model = getattr(models, model)(pretrained=True, transform_input=False) @@ -36,7 +35,7 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): model = getattr(models, model)(pretrained=True) torch.manual_seed(SEED) - inp = torch.randn(IN_SIZE) + inp = torch.randn(IN_SIZE_CONV) model.eval() expected_out = model(inp) @@ -62,9 +61,17 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): @pytest.mark.parametrize("merge_bias", [True, False]) -def test_models(toy_model, merge_bias): - model = toy_model() - inp = torch.randn(IN_SIZE) +def test_models(toy_model, merge_bias, request): + test_id = request.node.callspec.id + + if 'mha' in test_id: + in_shape = IN_SIZE_LINEAR + else: + in_shape = IN_SIZE_CONV + + model_class = toy_model + model = model_class() + inp = torch.randn(in_shape) model.eval() expected_out = model(inp) From 5deab5a83affc2a9c2b6cf75ac739b02c5a28b80 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 9 Mar 2023 15:28:51 +0000 Subject: [PATCH 3/8] Fix (graph_eq): Skip equalization if no regions are found --- src/brevitas/graph/equalize.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b9a94dd37..d5d8805e0 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -395,33 +395,22 @@ def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]: class EqualizeGraph(GraphTransform): - def __init__( - self, - iterations: int = 10, - threshold: float = 0.05, - return_regions: bool = False, - merge_bias: bool = True, - bias_shrinkage: Union[float, str] = 'vaiq', - scale_computation: str = 'maxabs') -> None: + def __init__(self, iterations: int = 10, threshold: float = 0.05, return_regions: bool = False, + merge_bias: bool = True, bias_shrinkage: Union[float, str] = 'vaiq', scale_computation_type: str = 'maxabs') -> None: super(EqualizeGraph, self).__init__() self.iterations = iterations self.return_regions = return_regions self.merge_bias = merge_bias self.bias_shrinkage = bias_shrinkage self.threshold = threshold - self.scale_computation = scale_computation + self.scale_computation_type = scale_computation_type def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: regions = _extract_regions(graph_model) - graph_model = _equalize( - graph_model, - regions, - self.iterations, - self.threshold, - self.merge_bias, - self.bias_shrinkage, - self.scale_computation) + if len(regions) > 0: + graph_model = _equalize(graph_model, regions, self.iterations, self.threshold, + self.merge_bias, self.bias_shrinkage, self.scale_computation_type) if self.return_regions: return graph_model, regions else: From a6dd94bf25cb3e90d57b2d6a69cd2b94f97045f5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Mar 2023 01:44:51 +0000 Subject: [PATCH 4/8] Fix (graph_eq): Do not equalize through sink LayernNorm --- src/brevitas/graph/equalize.py | 6 +++++- tests/brevitas/graph/equalization_fixtures.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index d5d8805e0..b73a408ef 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -325,6 +325,7 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: return True return False + def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: return node.op == 'call_module' and isinstance( get_module(graph_model, node.target), _scale_invariant_layers) @@ -358,7 +359,10 @@ def walk_region( continue if _is_supported_module(graph_model, node): if walk_forward: - sinks.add(node.target) + module = get_module(graph_model, node.target) + # It is not possible to equalize through LayerNorm as sink + if not isinstance(module, nn.LayerNorm): + sinks.add(node.target) else: srcs.add(node.target) walk_region(graph_model, node, history, srcs, sinks, walk_forward=True) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 88f8505ef..4bdc85be7 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -11,7 +11,7 @@ from brevitas import torch_version MODELS = { - 'vit_b_32': [0.777, 0.793], + 'vit_b_32': [0.396, 0.396], 'shufflenet_v2_x0_5': [0.8141, 0.8230], 'mobilenet_v2': [0.6571, 0.6571], 'resnet18': [0.9756, 0.9756], From fc8cd64bfd6859bdc336352c9e81ceb74451b763 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Mar 2023 01:58:42 +0000 Subject: [PATCH 5/8] Fix (graph_eq): fix for early exit threshold --- src/brevitas/graph/equalize.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b73a408ef..493ceec7e 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -299,17 +299,12 @@ def _equalize( for i in range(iterations): scale_factor_max = None for region in regions: - scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], - [name_to_module[n] for n in region[1]], - merge_bias, - bias_shrinkage, - scale_computation_type) - - scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region)) - if scale_factor_max is not None: - scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max) - else: - scale_factor_max = scale_factor_region_max + scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], [name_to_module[n] for n in region[1]], merge_bias, bias_shrinkage, scale_computation_type) + scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region)) + if scale_factor_max is not None: + scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max) + else: + scale_factor_max = scale_factor_region_max if threshold is not None and scale_factor_max < threshold: break return model From d0da62aec6f9bb8ee962bc1784d1bb887a1ad481 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Mar 2023 02:15:23 +0000 Subject: [PATCH 6/8] Test (graph_eq): improve check on intermediate results --- tests/brevitas/graph/equalization_fixtures.py | 24 ++++++--- tests/brevitas/graph/test_equalization.py | 50 +++++++++++++++---- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 4bdc85be7..25262caed 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -48,14 +48,17 @@ def forward(self, x): @pytest_cases.fixture -def linearmha_model(): +@pytest_cases.parametrize('bias', [True, False]) +@pytest_cases.parametrize('add_bias_kv', [True, False]) +@pytest_cases.parametrize('batch_first', [True, False]) +def linearmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") class LinearMhaModel(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(3,24) - self.mha = nn.MultiheadAttention(24,3,0.1, bias=True, add_bias_kv=True, batch_first=True) + self.mha = nn.MultiheadAttention(24,3,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) def forward(self, x): x = self.linear(x) x, _ = self.mha(x, x, x) @@ -64,14 +67,20 @@ def forward(self, x): @pytest_cases.fixture -def layernormmha_model(): +@pytest_cases.parametrize('bias', [True, False]) +@pytest_cases.parametrize('add_bias_kv', [True, False]) +@pytest_cases.parametrize('batch_first', [True, False]) +def layernormmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") class LayerNormMhaModel(nn.Module): def __init__(self) -> None: super().__init__() self.layernorm = nn.LayerNorm(3) - self.mha = nn.MultiheadAttention(3,3,0.1, bias=True, add_bias_kv=True, batch_first=True) + # Simulate learned parameters + self.layernorm.weight.data = torch.randn_like(self.layernorm.weight) + self.layernorm.bias.data = torch.randn_like(self.layernorm.bias) + self.mha = nn.MultiheadAttention(3,3,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) def forward(self, x): x = self.layernorm(x) x, _ = self.mha(x, x, x) @@ -80,13 +89,16 @@ def forward(self, x): @pytest_cases.fixture -def mhalinear_model(): +@pytest_cases.parametrize('bias', [True, False]) +@pytest_cases.parametrize('add_bias_kv', [True, False]) +@pytest_cases.parametrize('batch_first', [True, False]) +def mhalinear_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") class MhaLinearModel(nn.Module): def __init__(self) -> None: super().__init__() - self.mha = nn.MultiheadAttention(3,1,0.1, bias=True, add_bias_kv=True, batch_first=True) + self.mha = nn.MultiheadAttention(3,1,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) self.linear = nn.Linear(3,6) def forward(self, x): x, _ = self.mha(x, x, x) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 084ae4925..38cfe1e8a 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -8,7 +8,8 @@ from brevitas import torch_version from brevitas.fx import symbolic_trace -from brevitas.graph import EqualizeGraph +from brevitas.graph.equalize import _cross_layer_equalization +from brevitas.graph.equalize import _extract_regions from brevitas.graph.equalize import _is_supported_module from .equalization_fixtures import * @@ -17,22 +18,33 @@ ATOL = 1e-3 -@pytest_cases.parametrize( - "model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], - ids=[model_name for model_name, _ in MODELS.items()]) +def equalize(model, regions, merge_bias, bias_shrinkage, scale_computation_type): + name_to_module = {} + name_set = {name for region in regions for module_set in region for name in module_set} + scale_factors_regions = [] + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + for region in regions: + scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], [name_to_module[n] for n in region[1]], merge_bias, bias_shrinkage, scale_computation_type) + scale_factors_regions.append(scale_factors_region) + return scale_factors_regions + + +@pytest_cases.parametrize("model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], ids=[ model_name for model_name, _ in MODELS.items()]) @pytest.mark.parametrize("merge_bias", [True, False]) def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): - model, coverage = model_dict + model_name, coverage = model_dict - if model == 'googlenet' and torch_version == version.parse('1.8.1'): + if model_name == 'googlenet' and torch_version == version.parse('1.8.1'): pytest.skip('Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ') - if 'vit' in model and torch_version < version.parse('1.13'): + if 'vit' in model_name and torch_version < version.parse('1.13'): pytest.skip(f'ViT supported from torch version 1.13, current torch version is {torch_version}') try: - model = getattr(models, model)(pretrained=True, transform_input=False) + model = getattr(models, model_name)(pretrained=True, transform_input=False) except TypeError: - model = getattr(models, model)(pretrained=True) + model = getattr(models, model_name)(pretrained=True) torch.manual_seed(SEED) inp = torch.randn(IN_SIZE_CONV) @@ -40,7 +52,9 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): expected_out = model(inp) model = symbolic_trace(model) - model, regions = EqualizeGraph(3, return_regions=True, merge_bias=merge_bias).apply(model) + regions = _extract_regions(model) + scale_factor_regions = equalize(model, regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') + shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp) srcs = set() @@ -58,6 +72,15 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): assert src_coverage >= coverage[0] assert sink_coverage >= coverage[1] assert torch.allclose(expected_out, out, atol=ATOL) + # Graph equalization can exit in case of shape mismatches or other error without performing any + # equalization and returning a scalar value. We check that the equalized regions are as many as + # expected + print(sum([shape != () for shape in shape_scale_regions])) + if 'alexnet' in model_name: + # In AlexNet, we cannot equalize only through one region + assert sum([shape == () for shape in shape_scale_regions]) == 1 + else: + assert all([shape != () for shape in shape_scale_regions]) @pytest.mark.parametrize("merge_bias", [True, False]) @@ -76,8 +99,13 @@ def test_models(toy_model, merge_bias, request): model.eval() expected_out = model(inp) model = symbolic_trace(model) - model, regions = EqualizeGraph(3, return_regions=True, merge_bias=merge_bias).apply(model) + regions = _extract_regions(model) + scale_factor_regions = equalize(model, regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') + shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp) assert len(regions) > 0 assert torch.allclose(expected_out, out, atol=ATOL) + # Check that at least one region performs "true" equalization + # If all shapes are scalar, no equalization has been performed + assert all([shape != () for shape in shape_scale_regions]) From f5c591fa375ea4ebb1a41655b2850e99ccc5e24d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 20 Mar 2023 16:16:30 +0000 Subject: [PATCH 7/8] Refactor (graph_eq): annotations, comments, and formatting --- src/brevitas/graph/equalize.py | 82 ++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 29 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 493ceec7e..e13de87a2 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,7 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause - from collections import namedtuple from copy import deepcopy from functools import partial @@ -21,6 +20,8 @@ EPSILON = 1e-9 +Region = namedtuple('Region', ['srcs', 'sinks']) + _supported_layers = ( nn.ConvTranspose1d, nn.ConvTranspose2d, @@ -50,31 +51,19 @@ nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) -_scale_invariant_op = ( - torch.mul, - operator.mul, - operator.imul, - operator.__mul__, - operator.__imul__) +_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) -_select_op = ( - operator.getitem, - operator.__getitem__) +_select_op = (operator.getitem, operator.__getitem__) _residual_methods = ('add', 'add_') _residual_fns = (torch.add, operator.add, operator.iadd, operator.__add__, operator.__iadd__) -_batch_norm = ( - nn.BatchNorm1d, - nn.BatchNorm2d, - nn.BatchNorm3d) - +_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) WeightBiasTuple = namedtuple('WeightBiasTuple', ['weight', 'bias'], defaults=[None]) - def _select_scale_computation_fn( scale_computation_type: str) -> Callable[[torch.Tensor], torch.Tensor]: if scale_computation_type == 'maxabs': @@ -114,9 +103,13 @@ def _get_size(axes: Dict[nn.Module, int]) -> int: def _get_input_axis(module: nn.Module) -> Optional[int]: + """ + Given a sink module, determine the axis associated to the input channels. + Return None if not supported. + """ if isinstance(module, (nn.Linear, nn.MultiheadAttention)): return 1 - elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + elif isinstance(module, _batch_norm): return 0 elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): if module.groups == 1: @@ -139,8 +132,19 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: def _get_output_axis(module: nn.Module) -> Optional[int]: - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.MultiheadAttention, - nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + """ + Given a source module, determine the axis associated to the output channels. + Return None if not supported. + """ + if isinstance(module, + (nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.MultiheadAttention, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d)): return 0 elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): return 1 @@ -153,7 +157,10 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: else: return None -def _combine_weights_bias(weight: torch.Tensor, bias_shrinkage: Union[float, str], bias: Optional[torch.Tensor]) -> torch.Tensor: + +def _combine_weights_bias( + weight: torch.Tensor, bias_shrinkage: Union[float, str], + bias: Optional[torch.Tensor]) -> torch.Tensor: """Combine weights and bias before graph equalizattion This method merges the weight and bias of the sources, so that the resulting equalizer scale factor is influenced also by the magnitude of the bias, mitigated by a shrink factor. @@ -271,11 +278,14 @@ def _cross_layer_equalization( for module, axis in sink_axes.items(): src_broadcast_size = [1] * module.weight.ndim src_broadcast_size[axis] = module.weight.size(axis) - if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)): + if isinstance(module, _batch_norm): + # We re-compute the bias as function of running_mean and running_var to adjust the + # additive factor for equalization. additive_factor = module.running_mean.data * module.weight.data / torch.sqrt( module.running_var.data + module.eps) module.bias.data = module.bias.data + additive_factor * (scaling_factors - 1) module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size) + return scaling_factors @@ -299,7 +309,11 @@ def _equalize( for i in range(iterations): scale_factor_max = None for region in regions: - scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], [name_to_module[n] for n in region[1]], merge_bias, bias_shrinkage, scale_computation_type) + scale_factors_region = _cross_layer_equalization( + [name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks], + merge_bias, + bias_shrinkage, + scale_computation_type) scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region)) if scale_factor_max is not None: scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max) @@ -386,16 +400,20 @@ def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]: if sinks: # each region should appear only once, so to make it hashable # we convert srcs and sinks to ordered lists first, and then to tuples - regions.add((tuple(sorted(srcs)), tuple(sorted(sinks)))) - # for clarity, sort by the of the first source - regions = sorted(regions, key=lambda region: region[0][0]) + regions.add(Region(tuple(sorted(srcs)), tuple(sorted(sinks)))) return regions class EqualizeGraph(GraphTransform): - def __init__(self, iterations: int = 10, threshold: float = 0.05, return_regions: bool = False, - merge_bias: bool = True, bias_shrinkage: Union[float, str] = 'vaiq', scale_computation_type: str = 'maxabs') -> None: + def __init__( + self, + iterations: int = 10, + threshold: float = 0.05, + return_regions: bool = False, + merge_bias: bool = True, + bias_shrinkage: Union[float, str] = 'vaiq', + scale_computation_type: str = 'maxabs') -> None: super(EqualizeGraph, self).__init__() self.iterations = iterations self.return_regions = return_regions @@ -408,8 +426,14 @@ def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: regions = _extract_regions(graph_model) if len(regions) > 0: - graph_model = _equalize(graph_model, regions, self.iterations, self.threshold, - self.merge_bias, self.bias_shrinkage, self.scale_computation_type) + graph_model = _equalize( + graph_model, + regions, + self.iterations, + self.threshold, + self.merge_bias, + self.bias_shrinkage, + self.scale_computation_type) if self.return_regions: return graph_model, regions else: From c4117a525552c555ab1f5f4a2a132927a2eeedfc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 Mar 2023 14:15:24 +0000 Subject: [PATCH 8/8] Test (graph_eq): extended equalization tests on resnet18 --- tests/brevitas/graph/equalization_fixtures.py | 126 ++++++++++++++++-- tests/brevitas/graph/test_equalization.py | 93 ++++++++----- 2 files changed, 174 insertions(+), 45 deletions(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 25262caed..94a77e299 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -7,8 +7,13 @@ from pytest_cases import fixture_union import torch import torch.nn as nn +from torchvision import models from brevitas import torch_version +from brevitas.graph.equalize import _cross_layer_equalization + +SEED = 123456 +ATOL = 1e-3 MODELS = { 'vit_b_32': [0.396, 0.396], @@ -16,13 +21,55 @@ 'mobilenet_v2': [0.6571, 0.6571], 'resnet18': [0.9756, 0.9756], 'googlenet': [0.4956, 0.4956], - 'inception_v3': [0.4973, 0.4973], + 'inception_v3': [0.4948, 0.4948], 'alexnet': [0.875, 0.875],} - IN_SIZE_CONV = (1, 3, 224, 224) IN_SIZE_LINEAR = (1, 224, 3) + +def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type): + name_to_module = {} + name_set = {name for region in regions for module_set in region for name in module_set} + scale_factors_regions = [] + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + for i in range(3): + for region in regions: + scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], + [name_to_module[n] for n in region[1]], + merge_bias, + bias_shrinkage, + scale_computation_type) + if i == 0: + scale_factors_regions.append(scale_factors_region) + return scale_factors_regions + + +@pytest_cases.fixture +@pytest_cases.parametrize( + "model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], + ids=[model_name for model_name, _ in MODELS.items()]) +def model_coverage(model_dict: dict): + model_name, coverage = model_dict + + if model_name == 'googlenet' and torch_version == version.parse('1.8.1'): + pytest.skip( + 'Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ' + ) + if 'vit' in model_name and torch_version < version.parse('1.13'): + pytest.skip( + f'ViT supported from torch version 1.13, current torch version is {torch_version}') + + kwargs = dict() + if model_name in ('inception_v3', 'googlenet'): + kwargs['transform_input'] = False + model = getattr(models, model_name)(pretrained=True, **kwargs) + + return model, coverage + + @pytest_cases.fixture def bnconv_model(): @@ -54,15 +101,20 @@ def forward(self, x): def linearmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class LinearMhaModel(nn.Module): + def __init__(self) -> None: super().__init__() - self.linear = nn.Linear(3,24) - self.mha = nn.MultiheadAttention(24,3,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) + self.linear = nn.Linear(3, 24) + self.mha = nn.MultiheadAttention( + 24, 3, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) + def forward(self, x): x = self.linear(x) x, _ = self.mha(x, x, x) return x + return LinearMhaModel @@ -73,18 +125,23 @@ def forward(self, x): def layernormmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class LayerNormMhaModel(nn.Module): + def __init__(self) -> None: super().__init__() self.layernorm = nn.LayerNorm(3) # Simulate learned parameters self.layernorm.weight.data = torch.randn_like(self.layernorm.weight) self.layernorm.bias.data = torch.randn_like(self.layernorm.bias) - self.mha = nn.MultiheadAttention(3,3,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) + self.mha = nn.MultiheadAttention( + 3, 3, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) + def forward(self, x): x = self.layernorm(x) x, _ = self.mha(x, x, x) return x + return LayerNormMhaModel @@ -95,15 +152,20 @@ def forward(self, x): def mhalinear_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + class MhaLinearModel(nn.Module): + def __init__(self) -> None: super().__init__() - self.mha = nn.MultiheadAttention(3,1,0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) - self.linear = nn.Linear(3,6) + self.mha = nn.MultiheadAttention( + 3, 1, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first) + self.linear = nn.Linear(3, 6) + def forward(self, x): x, _ = self.mha(x, x, x) x = self.linear(x) return x + return MhaLinearModel @@ -214,8 +276,52 @@ def forward(self, x): return ResidualSrcsAndSinkModel -list_of_fixtures = ['residual_model', 'srcsinkconflict_model', 'mul_model', - 'convbn_model', 'bnconv_model', 'convdepthconv_model', - 'linearmha_model', 'mhalinear_model', 'layernormmha_model'] +list_of_fixtures = [ + 'residual_model', + 'srcsinkconflict_model', + 'mul_model', + 'convbn_model', + 'bnconv_model', + 'convdepthconv_model', + 'linearmha_model', + 'mhalinear_model', + 'layernormmha_model'] toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures) + +RESNET_18_REGIONS = [ + [('conv1',), ('bn1',)], + [('layer4.0.conv2',), ('layer4.0.bn2',)], + [('layer2.0.conv2',), ('layer2.0.bn2',)], + [('layer3.0.bn1',), ('layer3.0.conv2',)], + [('layer4.1.bn1',), ('layer4.1.conv2',)], + [('layer1.1.conv1',), ('layer1.1.bn1',)], + [('layer3.0.conv2',), ('layer3.0.bn2',)], + [('layer2.1.bn1',), ('layer2.1.conv2',)], + [('layer2.1.conv2',), ('layer2.1.bn2',)], + [('layer3.1.conv1',), ('layer3.1.bn1',)], + [('layer3.1.bn1',), ('layer3.1.conv2',)], + [('layer4.0.conv1',), ('layer4.0.bn1',)], + [('layer3.0.downsample.0',), ('layer3.0.downsample.1',)], + [('layer1.0.bn1',), ('layer1.0.conv2',)], + [('layer3.0.bn2', 'layer3.0.downsample.1', 'layer3.1.bn2'), + ('layer3.1.conv1', 'layer4.0.conv1', 'layer4.0.downsample.0')], + [('layer4.1.conv2',), ('layer4.1.bn2',)], + [('layer3.0.conv1',), ('layer3.0.bn1',)], + [('layer4.0.bn1',), ('layer4.0.conv2',)], + [('layer3.1.conv2',), ('layer3.1.bn2',)], + [('layer1.0.conv1',), ('layer1.0.bn1',)], + [('layer4.0.downsample.0',), ('layer4.0.downsample.1',)], + [('layer2.0.bn2', 'layer2.0.downsample.1', 'layer2.1.bn2'), + ('layer2.1.conv1', 'layer3.0.conv1', 'layer3.0.downsample.0')], + [('layer1.1.bn1',), ('layer1.1.conv2',)], + [('layer1.1.conv2',), ('layer1.1.bn2',)], + [('layer4.1.conv1',), ('layer4.1.bn1',)], + [('layer2.1.conv1',), ('layer2.1.bn1',)], + [('bn1', 'layer1.0.bn2', 'layer1.1.bn2'), + ('layer1.0.conv1', 'layer1.1.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0')], + [('layer2.0.bn1',), ('layer2.0.conv2',)], + [('layer2.0.conv1',), ('layer2.0.bn1',)], + [('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')], + [('layer2.0.downsample.0',), ('layer2.0.downsample.1',)], + [('layer1.0.conv2',), ('layer1.0.bn2',)],] diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 38cfe1e8a..f0bfb20ee 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -1,59 +1,78 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from packaging import version -import pytest +import copy + import torch from torchvision import models -from brevitas import torch_version from brevitas.fx import symbolic_trace -from brevitas.graph.equalize import _cross_layer_equalization from brevitas.graph.equalize import _extract_regions from brevitas.graph.equalize import _is_supported_module +from brevitas.graph.utils import get_module from .equalization_fixtures import * -SEED = 123456 -ATOL = 1e-3 +def test_resnet18_equalization(): + model = models.resnet18(pretrained=True) -def equalize(model, regions, merge_bias, bias_shrinkage, scale_computation_type): - name_to_module = {} - name_set = {name for region in regions for module_set in region for name in module_set} - scale_factors_regions = [] - for name, module in model.named_modules(): - if name in name_set: - name_to_module[name] = module - for region in regions: - scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]], [name_to_module[n] for n in region[1]], merge_bias, bias_shrinkage, scale_computation_type) - scale_factors_regions.append(scale_factors_region) - return scale_factors_regions + torch.manual_seed(SEED) + inp = torch.randn(IN_SIZE_CONV) + model.eval() + model = symbolic_trace(model) + expected_out = model(inp) + model_orig = copy.deepcopy(model) + regions = _extract_regions(model) + _ = equalize_test( + model, regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') + out = model(inp) -@pytest_cases.parametrize("model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], ids=[ model_name for model_name, _ in MODELS.items()]) -@pytest.mark.parametrize("merge_bias", [True, False]) -def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): - model_name, coverage = model_dict + # Check that equalization is not introducing FP variations + assert torch.allclose(expected_out, out, atol=ATOL) + + regions = sorted(regions, key=lambda region: region[0][0]) + resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) + equalized_layers = set() + for region in resnet_18_regions: + equalized_layers.update(region[0]) + equalized_layers.update(region[1]) + + # Check that we found all the expected regions + for region, expected_region in zip(regions, resnet_18_regions): + sources_check = set(region[0]) == set(expected_region[0]) + sinks_check = set(region[1]) == set(expected_region[1]) + assert sources_check + assert sinks_check + + # Check that all layers were equalized and weights changed + for layer in equalized_layers: + eq_module = get_module(model, layer) + orig_module = get_module(model_orig, layer) + assert not torch.allclose(eq_module.weight, orig_module.weight) - if model_name == 'googlenet' and torch_version == version.parse('1.8.1'): - pytest.skip('Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ') - if 'vit' in model_name and torch_version < version.parse('1.13'): - pytest.skip(f'ViT supported from torch version 1.13, current torch version is {torch_version}') - try: - model = getattr(models, model_name)(pretrained=True, transform_input=False) - except TypeError: - model = getattr(models, model_name)(pretrained=True) +@pytest_cases.parametrize("merge_bias", [True, False]) +def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool): + model, coverage = model_coverage torch.manual_seed(SEED) inp = torch.randn(IN_SIZE_CONV) model.eval() + # The isistance does not work after symbolic trace + is_alexnet = isinstance(model, models.AlexNet) + model = symbolic_trace(model) + expected_out = model(inp) - model = symbolic_trace(model) regions = _extract_regions(model) - scale_factor_regions = equalize(model, regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') + scale_factor_regions = equalize_test( + model, + regions, + merge_bias=merge_bias, + bias_shrinkage='vaiq', + scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp) @@ -75,15 +94,14 @@ def test_equalization_torchvision_models(model_dict: dict, merge_bias: bool): # Graph equalization can exit in case of shape mismatches or other error without performing any # equalization and returning a scalar value. We check that the equalized regions are as many as # expected - print(sum([shape != () for shape in shape_scale_regions])) - if 'alexnet' in model_name: + if is_alexnet: # In AlexNet, we cannot equalize only through one region assert sum([shape == () for shape in shape_scale_regions]) == 1 else: assert all([shape != () for shape in shape_scale_regions]) -@pytest.mark.parametrize("merge_bias", [True, False]) +@pytest_cases.parametrize("merge_bias", [True, False]) def test_models(toy_model, merge_bias, request): test_id = request.node.callspec.id @@ -100,7 +118,12 @@ def test_models(toy_model, merge_bias, request): expected_out = model(inp) model = symbolic_trace(model) regions = _extract_regions(model) - scale_factor_regions = equalize(model, regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') + scale_factor_regions = equalize_test( + model, + regions, + merge_bias=merge_bias, + bias_shrinkage='vaiq', + scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp)