Graph Eq: annotations refactor #532

Merged 1 commit on Feb 20, 2023

105 changes: 54 additions & 51 deletions src/brevitas/graph/equalize.py
@@ -4,9 +4,10 @@

from functools import partial
import operator
-from typing import Dict
+from typing import Dict, List, Set, Tuple, Union

import torch
+import torch.nn as nn

from brevitas.fx import GraphModule
from brevitas.fx import Node
@@ -21,31 +22,31 @@
EPSILON = 1e-9

_supported_layers = (
-torch.nn.ConvTranspose1d,
-torch.nn.ConvTranspose2d,
-torch.nn.ConvTranspose3d,
-torch.nn.Conv1d,
-torch.nn.Conv2d,
-torch.nn.Conv3d,
-torch.nn.Linear,
-torch.nn.BatchNorm1d,
-torch.nn.BatchNorm2d,
-torch.nn.BatchNorm3d)
+nn.ConvTranspose1d,
+nn.ConvTranspose2d,
+nn.ConvTranspose3d,
+nn.Conv1d,
+nn.Conv2d,
+nn.Conv3d,
+nn.Linear,
+nn.BatchNorm1d,
+nn.BatchNorm2d,
+nn.BatchNorm3d)

_scale_invariant_layers = (
-torch.nn.Dropout,
-torch.nn.Dropout2d,
-torch.nn.Dropout3d,
-torch.nn.ReLU,
-torch.nn.MaxPool1d,
-torch.nn.MaxPool2d,
-torch.nn.MaxPool3d,
-torch.nn.AvgPool1d,
-torch.nn.AvgPool2d,
-torch.nn.AvgPool3d,
-torch.nn.AdaptiveAvgPool1d,
-torch.nn.AdaptiveAvgPool2d,
-torch.nn.AdaptiveAvgPool3d)
+nn.Dropout,
+nn.Dropout2d,
+nn.Dropout3d,
+nn.ReLU,
+nn.MaxPool1d,
+nn.MaxPool2d,
+nn.MaxPool3d,
+nn.AvgPool1d,
+nn.AvgPool2d,
+nn.AvgPool3d,
+nn.AdaptiveAvgPool1d,
+nn.AdaptiveAvgPool2d,
+nn.AdaptiveAvgPool3d)

_scale_invariant_op = (
torch.mul,
@@ -69,18 +70,18 @@
)

_batch_norm = (
-torch.nn.BatchNorm1d,
-torch.nn.BatchNorm2d,
-torch.nn.BatchNorm3d,
+nn.BatchNorm1d,
+nn.BatchNorm2d,
+nn.BatchNorm3d,
)


-def _channel_range(inp):
+def _channel_range(inp: torch.Tensor) -> torch.Tensor:
mins, _ = inp.min(dim=1)
maxs, _ = inp.max(dim=1)
out = maxs - mins
# correct corner case where all weights along a channel have the same value
-# e.g. when a mean/torch.nn.AvgPool/torch.nn.AdaptiveAvgPool is converted to a depth-wise conv
+# e.g. when a mean/nn.AvgPool/nn.AdaptiveAvgPool is converted to a depth-wise conv
out = torch.where(out == 0., torch.mean(inp, dim=1), out)

# convert to positive range, in case any of the values are negative,
@@ -89,7 +90,7 @@ def _channel_range(inp):
return out


-def _get_size(axes):
+def _get_size(axes: Dict[nn.Module, int]) -> int:
m0, axis0 = list(axes.items())[0]
size = m0.weight.size(axis0)
for m, axis in axes.items():
@@ -98,19 +99,19 @@ def _get_size(axes):
return size


-def _get_input_axis(module):
-if isinstance(module, torch.nn.Linear):
+def _get_input_axis(module: nn.Module) -> int:
+if isinstance(module, nn.Linear):
return 1
-elif isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
+elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
return 0
-elif isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
+elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if module.groups == 1:
return 1
elif module.groups == module.out_channels:
return 0
else:
raise RuntimeError("Group convolution not supported")
-elif isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
+elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if module.groups == 1:
return 0
elif module.groups == module.out_channels:
@@ -121,17 +122,17 @@ def _get_input_axis(module):
raise RuntimeError(f"Module {module} not supported.")


-def _get_output_axis(module):
-if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
+def _get_output_axis(module: nn.Module) -> int:
+if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
return 0
-elif isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranpose3d)):
+elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
return 1
else:
raise RuntimeError(f"Module {module} not supported.")


-def _cross_layer_equalization(srcs, sinks):
+def _cross_layer_equalization(srcs: List[nn.Module], sinks: List[nn.Module]):
"""
Given two adjacent tensors', the weights are scaled such that
the ranges of the first tensors' output channel are equal to the
@@ -174,11 +175,11 @@
module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)


-def _equalize(model, regions, iterations):
+def _equalize(model: GraphModule, regions: Set[Tuple[str]], iterations: int) -> GraphModule:
"""
Generalized version of section 4.1 of https://arxiv.org/pdf/1906.04721.pdf
"""
-name_to_module : Dict[str, torch.nn.Module] = {}
+name_to_module : Dict[str, nn.Module] = {}
name_set = {name for region in regions for module_set in region for name in module_set}

for name, module in model.named_modules():
@@ -190,22 +191,24 @@ def _equalize(model, regions, iterations):
return model


-def _is_supported_module(graph_model, node):
+def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(get_module(graph_model, node.target), _supported_layers)


-def _is_scale_invariant_module(graph_model, node):
+def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(get_module(graph_model, node.target), _scale_invariant_layers)

-def _is_scale_invariant_function(node):
+
+def _is_scale_invariant_function(node: Node) -> bool:
return node.op == 'call_function' and node.target in _scale_invariant_op

-def _is_reshaping_op(node):
+
+def _is_reshaping_op(node: Node) -> bool:
return (node.op == 'call_function' and node.target in [torch.flatten, torch.reshape]
or node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten'])


-def walk_region(graph_model: GraphModule, starting_node: Node, history, srcs, sinks, walk_forward):
+def walk_region(graph_model: GraphModule, starting_node: Node, history: Set[Node], srcs: Set[str], sinks: Set[str], walk_forward: bool):
node_list = starting_node.users if walk_forward else starting_node.all_input_nodes
for node in node_list:
# we keep a history of how the graph has been walked already, invariant to the direction,
@@ -237,7 +240,7 @@ def walk_region(graph_model: GraphModule, starting_node: Node, history, srcs, sinks, walk_forward):
continue


-def _extract_regions(graph_model: GraphModule):
+def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]:
regions = set()
for node in graph_model.graph.nodes:
if node.op == 'call_module':
@@ -256,12 +259,12 @@ def _extract_regions(graph_model: GraphModule):

class EqualizeGraph(GraphTransform):

-def __init__(self, iterations, return_regions=False) -> None:
+def __init__(self, iterations: int, return_regions: bool = False) -> None:
super(EqualizeGraph, self).__init__()
self.iterations = iterations
self.return_regions = return_regions

-def apply(self, graph_model: GraphModule):
+def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(graph_model)
graph_model = _equalize(graph_model, regions, self.iterations)
if self.return_regions:
@@ -281,7 +284,7 @@ def add_to_bias(self, module, tensor):
if module.bias is not None:
module.bias.data += tensor.view_as(module.bias)
else:
-module.bias = torch.nn.Parameter(tensor)
+module.bias = nn.Parameter(tensor)

def absorb_biases(self, groups):
for layer, bn, (next_layer_name, next_layer) in groups:
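Note on the algorithm behind `_cross_layer_equalization` and `_equalize` (this PR only adds annotations to them): the cited paper (https://arxiv.org/pdf/1906.04721.pdf, Sec. 4.1) equalizes each channel $i$ shared by a source and a sink layer with a scale $s_i = \frac{1}{r_i^{(2)}}\sqrt{r_i^{(1)}\, r_i^{(2)}}$, where $r_i^{(1)}$ is the range of output channel $i$ of the source and $r_i^{(2)}$ the range of input channel $i$ of the sink; dividing the source weights by $s_i$ and multiplying the sink weights by it makes both ranges equal to $\sqrt{r_i^{(1)} r_i^{(2)}}$. This is a recap of the referenced paper, not text from the diff; the generalized multi-source/multi-sink version implemented here may differ in detail.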
3 changes: 0 additions & 3 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -82,9 +82,6 @@ def forward(self, x):

@pytest_cases.fixture
def mul_model():
"""
In this example, conv_0 is both a src and sink.
"""
class ResidualSrcsAndSinkModel(nn.Module):
def __init__(self) -> None:
super().__init__()
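To close, a minimal usage sketch of the `EqualizeGraph` entry point whose signatures are annotated above. This is not part of the PR: the toy model, `torch.fx.symbolic_trace` as the tracer, and the iteration count are assumptions for illustration; in practice brevitas may expect a GraphModule produced by its own `brevitas.fx` tracing front end.

import torch.nn as nn
from torch.fx import symbolic_trace

from brevitas.graph.equalize import EqualizeGraph

# Toy Conv -> ReLU -> Conv chain: the first conv acts as a source, the ReLU is
# scale-invariant, and the second conv is the sink of one equalization region.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
)
graph_model = symbolic_trace(model)

# With return_regions=True, apply() returns the rewritten GraphModule together
# with the extracted region name tuples, matching the Union return annotation above.
graph_model, regions = EqualizeGraph(iterations=10, return_regions=True).apply(graph_model)
print(regions)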