Graph Eq: Transformers, LayerNorm, Better BN #555

Merged: 8 commits, Mar 31, 2023
183 changes: 113 additions & 70 deletions src/brevitas/graph/equalize.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from collections import namedtuple
from copy import deepcopy
from functools import partial
import operator
@@ -19,14 +20,18 @@

EPSILON = 1e-9

Region = namedtuple('Region', ['srcs', 'sinks'])

_supported_layers = (
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.MultiheadAttention,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.Linear,
nn.LayerNorm,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d)
@@ -46,23 +51,17 @@
nn.AdaptiveAvgPool2d,
nn.AdaptiveAvgPool3d)

_scale_invariant_op = (
torch.mul,
operator.mul,
operator.imul,
operator.__mul__,
operator.__imul__,
)
_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__)

_select_op = (operator.getitem, operator.__getitem__)

_residual_methods = ('add', 'add_')

_residual_fns = (torch.add, operator.add, operator.iadd, operator.__add__, operator.__iadd__)

_batch_norm = (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
)
_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

WeightBiasTuple = namedtuple('WeightBiasTuple', ['weight', 'bias'], defaults=[None])

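WeightBiasTuple lets a bare parameter (used further down for MultiheadAttention.in_proj_weight) travel through code paths that expect an object exposing .weight and .bias, with bias defaulting to None. A minimal, self-contained illustration (not part of the diff):

import torch
from collections import namedtuple

WeightBiasTuple = namedtuple('WeightBiasTuple', ['weight', 'bias'], defaults=[None])

wrapped = WeightBiasTuple(torch.randn(3, 3))  # bias is filled in by the default
assert wrapped.bias is None and wrapped.weight.shape == (3, 3)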

def _select_scale_computation_fn(
@@ -103,10 +102,14 @@ def _get_size(axes: Dict[nn.Module, int]) -> int:
return size


def _get_input_axis(module: nn.Module) -> int:
if isinstance(module, nn.Linear):
def _get_input_axis(module: nn.Module) -> Optional[int]:
"""
Given a sink module, determine the axis associated to the input channels.
Return None if not supported.
"""
if isinstance(module, (nn.Linear, nn.MultiheadAttention)):
return 1
elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
elif isinstance(module, _batch_norm):
return 0
elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if module.groups == 1:
@@ -118,26 +121,46 @@ def _get_input_axis(module: nn.Module) -> int:
return 0
elif module.groups == module.out_channels:
return 1
elif isinstance(module, nn.LayerNorm):
# We assume normalization happens only along the channel dimension
if len(module.weight.shape) == 1:
return 0
else:
return None
else:
return None


def _get_output_axis(module: nn.Module) -> int:
if isinstance(
module,
(nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d,
nn.BatchNorm3d)):
def _get_output_axis(module: nn.Module) -> Optional[int]:
"""
Given a source module, determine the axis associated to the output channels.
Return None if not supported.
"""
if isinstance(module,
(nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.MultiheadAttention,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d)):
return 0
elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
return 1
elif isinstance(module, nn.LayerNorm):
# We assume normalization happens only along the channel dimension
if len(module.weight.shape) == 1:
return 0
else:
return None
else:
return None

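For intuition on the axis conventions above, a quick sketch (illustrative sizes only) of the weight layouts that drive these choices: Linear and Conv keep output channels on axis 0 and input channels on axis 1, ConvTranspose keeps output channels on axis 1, and a 1-D LayerNorm weight uses axis 0 in both directions.

import torch.nn as nn

linear = nn.Linear(8, 16)             # weight: (16, 8)       -> output axis 0, input axis 1
conv = nn.Conv2d(8, 16, 3)            # weight: (16, 8, 3, 3) -> output axis 0, input axis 1 (groups == 1)
convt = nn.ConvTranspose2d(8, 16, 3)  # weight: (8, 16, 3, 3) -> output axis 1
ln = nn.LayerNorm(16)                 # weight: (16,)         -> 1-D, axis 0
print(linear.weight.shape, conv.weight.shape, convt.weight.shape, ln.weight.shape)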

def _combine_weights_bias(
weight: nn.parameter.Parameter,
bias_shrinkage: Union[float, str],
bias: Optional[nn.parameter.Parameter]):
weight: torch.Tensor, bias_shrinkage: Union[float, str],
bias: Optional[torch.Tensor]) -> torch.Tensor:
"""Combine weights and bias before graph equalizattion
This method merges the weight and bias of the sources, so that the resulting equalizer scale factor
is influenced also by the magnitude of the bias, mitigated by a shrink factor.
@@ -193,15 +216,28 @@ def _cross_layer_equalization(
device = next(srcs[0].parameters()).device
dtype = next(srcs[0].parameters()).dtype

for module_set in [srcs, sinks]:
for module in module_set:
if not isinstance(module, _supported_layers):
return torch.tensor(
1., dtype=dtype,
device=device) # If module is not supported, do not perform graph equalization

src_axes = {m: _get_output_axis(m) for m in srcs}
sink_axes = {m: _get_input_axis(m) for m in sinks}
src_axes = {}
sink_axes = {}

for i, module in enumerate(srcs):
# If module is not supported, do not perform graph equalization
if not isinstance(module, _supported_layers):
return torch.tensor(1., dtype=dtype, device=device)
if isinstance(module, nn.MultiheadAttention):
srcs[i] = module.out_proj
src_axes[srcs[i]] = _get_output_axis(module)

for i, module in enumerate(sinks):
# If module is not supported, do not perform graph equalization
if not isinstance(module, _supported_layers):
return torch.tensor(1., dtype=dtype, device=device)
# For MultiheadAttention, we support only self-attention
if isinstance(module, nn.MultiheadAttention) and hasattr(sinks[i], 'in_proj_weight'):
# For sinks, we only need to modify the weight but not the bias
sinks[i] = WeightBiasTuple(module.in_proj_weight)
elif isinstance(module, nn.MultiheadAttention) and not hasattr(sinks[i], 'in_proj_weight'):
return torch.tensor(1., dtype=dtype, device=device)
sink_axes[sinks[i]] = _get_input_axis(module)

# Check if any of the axis is None, which means that the module is not supported.
# In that case, do not perform graph equalization
@@ -216,7 +252,6 @@
# Similarly, exit if source and sink have different sizes
if None in [src_size, sink_size] or src_size != sink_size:
return torch.tensor(1., dtype=dtype, device=device)

transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1)
scale_fn = _select_scale_computation_fn(scale_computation_type)
if merge_bias:
@@ -243,11 +278,14 @@
for module, axis in sink_axes.items():
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
if isinstance(module, _batch_norm):
# We re-compute the bias as function of running_mean and running_var to adjust the
# additive factor for equalization.
additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(
module.running_var.data + module.eps)
module.bias.data = module.bias.data + additive_factor * (scaling_factors - 1)
module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)

return scaling_factors

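The BatchNorm branch above keeps the sink's output unchanged once the source has rescaled its output channels: with gamma scaled by s, shifting beta by running_mean * gamma / sqrt(running_var + eps) * (s - 1) cancels the extra term introduced by the rescaled input. A small numerical check, as a sketch with made-up sizes that emulates the source-side division by s rather than running the full transform:

import torch

torch.manual_seed(0)
num_channels = 4
s = torch.rand(num_channels) + 0.5               # per-channel scaling factors
bn = torch.nn.BatchNorm1d(num_channels).eval()   # eval: use running statistics
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

x = torch.randn(8, num_channels)
y_ref = bn(x)

x_scaled = x / s                                 # the source divides its output channels by s ...
additive = bn.running_mean * bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
bn.bias.data += additive * (s - 1)               # ... and the BN sink compensates as in the diff
bn.weight.data *= s
assert torch.allclose(bn(x_scaled), y_ref, atol=1e-5)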

@@ -271,25 +309,30 @@ def _equalize(
for i in range(iterations):
scale_factor_max = None
for region in regions:
scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]],
[name_to_module[n] for n in region[1]],
merge_bias,
bias_shrinkage,
scale_computation_type)

scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region))
if scale_factor_max is not None:
scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max)
else:
scale_factor_max = scale_factor_region_max
scale_factors_region = _cross_layer_equalization(
[name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks],
merge_bias,
bias_shrinkage,
scale_computation_type)
scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region))
if scale_factor_max is not None:
scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max)
else:
scale_factor_max = scale_factor_region_max
if threshold is not None and scale_factor_max < threshold:
break
return model


def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(
get_module(graph_model, node.target), _supported_layers)
if node.op == 'call_module':
module = get_module(graph_model, node.target)
if isinstance(module, _supported_layers):
# We support only self-attention
if isinstance(module, nn.MultiheadAttention):
return all([node.all_input_nodes[0].name == n.name for n in node.all_input_nodes])
return True
return False

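The self-attention check above works because, in an fx trace, a call such as mha(x, x, x) produces a single call_module node whose (deduplicated) input nodes all point back to the same producer, while cross-attention has distinct query/key/value producers. A minimal sketch (module and names are made up for illustration):

import torch.nn as nn
from torch.fx import symbolic_trace

class SelfAttnBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(16, 4)

    def forward(self, x):
        # query == key == value -> self-attention; index instead of tuple-unpacking to keep fx happy
        return self.mha(x, x, x, need_weights=False)[0]

graph_model = symbolic_trace(SelfAttnBlock())
for node in graph_model.graph.nodes:
    if node.op == 'call_module':
        print(node.target, [n.name for n in node.all_input_nodes])  # e.g. mha ['x']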

def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool:
@@ -298,7 +341,7 @@ def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool:


def _is_scale_invariant_function(node: Node) -> bool:
return node.op == 'call_function' and node.target in _scale_invariant_op
return node.op == 'call_function' and node.target in _scale_invariant_op + _select_op


def _is_reshaping_op(node: Node) -> bool:
@@ -325,7 +368,10 @@ def walk_region(
continue
if _is_supported_module(graph_model, node):
if walk_forward:
sinks.add(node.target)
module = get_module(graph_model, node.target)
# It is not possible to equalize through LayerNorm as sink
if not isinstance(module, nn.LayerNorm):
sinks.add(node.target)
else:
srcs.add(node.target)
walk_region(graph_model, node, history, srcs, sinks, walk_forward=True)
@@ -348,17 +394,13 @@ def walk_region(
def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]:
regions = set()
for node in graph_model.graph.nodes:
if node.op == 'call_module':
module = get_module(graph_model, node.target)
if isinstance(module, _supported_layers):
srcs, sinks = {node.target}, set()
walk_region(graph_model, node, set(), srcs, sinks, walk_forward=True)
if sinks:
# each region should appear only once, so to make it hashable
# we convert srcs and sinks to ordered lists first, and then to tuples
regions.add((tuple(sorted(srcs)), tuple(sorted(sinks))))
# for clarity, sort by the name of the first source
regions = sorted(regions, key=lambda region: region[0][0])
if _is_supported_module(graph_model, node):
srcs, sinks = {node.target}, set()
walk_region(graph_model, node, set(), srcs, sinks, walk_forward=True)
if sinks:
# each region should appear only once, so to make it hashable
# we convert srcs and sinks to ordered lists first, and then to tuples
regions.add(Region(tuple(sorted(srcs)), tuple(sorted(sinks))))
return regions

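For intuition, on a small sequential network the extracted regions pair each source with the sinks that absorb its inverse scale, walking through any scale-invariant ops in between. A hedged sketch (exact node names depend on the trace, and _extract_regions is a private helper used here only for illustration):

import torch.nn as nn
from torch.fx import symbolic_trace

from brevitas.graph.equalize import _extract_regions

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
graph_model = symbolic_trace(model)
regions = _extract_regions(graph_model)
# Expected to contain something like Region(srcs=('0',), sinks=('2',)):
# the first conv is the source, ReLU is walked through as scale-invariant,
# and the second conv is the sink.
print(regions)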

@@ -371,26 +413,27 @@ def __init__(
return_regions: bool = False,
merge_bias: bool = True,
bias_shrinkage: Union[float, str] = 'vaiq',
scale_computation: str = 'maxabs') -> None:
scale_computation_type: str = 'maxabs') -> None:
super(EqualizeGraph, self).__init__()
self.iterations = iterations
self.return_regions = return_regions
self.merge_bias = merge_bias
self.bias_shrinkage = bias_shrinkage
self.threshold = threshold
self.scale_computation = scale_computation
self.scale_computation_type = scale_computation_type

def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(graph_model)
graph_model = _equalize(
graph_model,
regions,
self.iterations,
self.threshold,
self.merge_bias,
self.bias_shrinkage,
self.scale_computation)
if len(regions) > 0:
graph_model = _equalize(
graph_model,
regions,
self.iterations,
self.threshold,
self.merge_bias,
self.bias_shrinkage,
self.scale_computation_type)
if self.return_regions:
return graph_model, regions
else:
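Finally, a hedged end-to-end sketch of how the rewritten transform might be applied to an fx-traced model (argument values are illustrative; the tracing entry point may differ in a full Brevitas flow):

import torch.nn as nn
from torch.fx import symbolic_trace

from brevitas.graph.equalize import EqualizeGraph

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).eval()
graph_model = symbolic_trace(model)

# return_regions=True also exposes the Regions that were equalized
graph_model, regions = EqualizeGraph(
    iterations=10,
    return_regions=True,
    merge_bias=True,
    bias_shrinkage='vaiq',
    scale_computation_type='maxabs').apply(graph_model)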