Skip to content

Commit

Permalink
Refactor: added annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 20, 2023
1 parent 6d37776 commit 6c80e6d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 54 deletions.
105 changes: 54 additions & 51 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from functools import partial
import operator
from typing import Dict
from typing import Dict, List, Set, Tuple, Union

import torch
import torch.nn as nn

from brevitas.fx import GraphModule
from brevitas.fx import Node
Expand All @@ -21,31 +22,31 @@
EPSILON = 1e-9

_supported_layers = (
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.Linear,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d)
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.Linear,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d)

_scale_invariant_layers = (
torch.nn.Dropout,
torch.nn.Dropout2d,
torch.nn.Dropout3d,
torch.nn.ReLU,
torch.nn.MaxPool1d,
torch.nn.MaxPool2d,
torch.nn.MaxPool3d,
torch.nn.AvgPool1d,
torch.nn.AvgPool2d,
torch.nn.AvgPool3d,
torch.nn.AdaptiveAvgPool1d,
torch.nn.AdaptiveAvgPool2d,
torch.nn.AdaptiveAvgPool3d)
nn.Dropout,
nn.Dropout2d,
nn.Dropout3d,
nn.ReLU,
nn.MaxPool1d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.AdaptiveAvgPool1d,
nn.AdaptiveAvgPool2d,
nn.AdaptiveAvgPool3d)

_scale_invariant_op = (
torch.mul,
Expand All @@ -69,18 +70,18 @@
)

_batch_norm = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
)


def _channel_range(inp):
def _channel_range(inp: torch.Tensor) -> torch.Tensor:
mins, _ = inp.min(dim=1)
maxs, _ = inp.max(dim=1)
out = maxs - mins
# correct corner case where where all weights along a channel have the same value
# e.g. when a mean/torch.nn.AvgPool/torch.nn.AdaptiveAvgPool is converted to a depth-wise conv
# e.g. when a mean/nn.AvgPool/nn.AdaptiveAvgPool is converted to a depth-wise conv
out = torch.where(out == 0., torch.mean(inp, dim=1), out)

# convert to positive range, in case any of the values are negative,
Expand All @@ -89,7 +90,7 @@ def _channel_range(inp):
return out


def _get_size(axes):
def _get_size(axes: Dict[nn.Module, int]) -> int:
m0, axis0 = list(axes.items())[0]
size = m0.weight.size(axis0)
for m, axis in axes.items():
Expand All @@ -98,19 +99,19 @@ def _get_size(axes):
return size


def _get_input_axis(module):
if isinstance(module, torch.nn.Linear):
def _get_input_axis(module: nn.Module) -> int:
if isinstance(module, nn.Linear):
return 1
elif isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
return 0
elif isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if module.groups == 1:
return 1
elif module.groups == module.out_channels:
return 0
else:
raise RuntimeError("Group convolution not supported")
elif isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if module.groups == 1:
return 0
elif module.groups == module.out_channels:
Expand All @@ -121,17 +122,17 @@ def _get_input_axis(module):
raise RuntimeError(f"Module {module} not supported.")


def _get_output_axis(module):
if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
def _get_output_axis(module: nn.Module) -> int:
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
return 0
elif isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranpose3d)):
elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
return 1
else:
raise RuntimeError(f"Module {module} not supported.")


def _cross_layer_equalization(srcs, sinks):
def _cross_layer_equalization(srcs: List[nn.Module], sinks: List[nn.Module]):
"""
Given two adjacent tensors', the weights are scaled such that
the ranges of the first tensors' output channel are equal to the
Expand Down Expand Up @@ -174,11 +175,11 @@ def _cross_layer_equalization(srcs, sinks):
module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)


def _equalize(model, regions, iterations):
def _equalize(model: GraphModule, regions: Set[Tuple[str]], iterations: int) -> GraphModule:
"""
Generalized version of section 4.1 of https://arxiv.org/pdf/1906.04721.pdf
"""
name_to_module : Dict[str, torch.nn.Module] = {}
name_to_module : Dict[str, nn.Module] = {}
name_set = {name for region in regions for module_set in region for name in module_set}

for name, module in model.named_modules():
Expand All @@ -190,22 +191,24 @@ def _equalize(model, regions, iterations):
return model


def _is_supported_module(graph_model, node):
def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(get_module(graph_model, node.target), _supported_layers)


def _is_scale_invariant_module(graph_model, node):
def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(get_module(graph_model, node.target), _scale_invariant_layers)

def _is_scale_invariant_function(node):

def _is_scale_invariant_function(node: Node) -> bool:
return node.op == 'call_function' and node.target in _scale_invariant_op

def _is_reshaping_op(node):

def _is_reshaping_op(node: Node) -> bool:
return (node.op == 'call_function' and node.target in [torch.flatten, torch.reshape]
or node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten'])


def walk_region(graph_model: GraphModule, starting_node: Node, history, srcs, sinks, walk_forward):
def walk_region(graph_model: GraphModule, starting_node: Node, history: Set[Node], srcs: Set[str], sinks: Set[str], walk_forward: bool):
node_list = starting_node.users if walk_forward else starting_node.all_input_nodes
for node in node_list:
# we keep a history of how the graph has been walked already, invariant to the direction,
Expand Down Expand Up @@ -237,7 +240,7 @@ def walk_region(graph_model: GraphModule, starting_node: Node, history, srcs, si
continue


def _extract_regions(graph_model: GraphModule):
def _extract_regions(graph_model: GraphModule) -> Set[Tuple[str]]:
regions = set()
for node in graph_model.graph.nodes:
if node.op == 'call_module':
Expand All @@ -256,12 +259,12 @@ def _extract_regions(graph_model: GraphModule):

class EqualizeGraph(GraphTransform):

def __init__(self, iterations, return_regions=False) -> None:
def __init__(self, iterations: int, return_regions: bool = False) -> None:
super(EqualizeGraph, self).__init__()
self.iterations = iterations
self.return_regions = return_regions

def apply(self, graph_model: GraphModule):
def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(graph_model)
graph_model = _equalize(graph_model, regions, self.iterations)
if self.return_regions:
Expand All @@ -281,7 +284,7 @@ def add_to_bias(self, module, tensor):
if module.bias is not None:
module.bias.data += tensor.view_as(module.bias)
else:
module.bias = torch.nn.Parameter(tensor)
module.bias = nn.Parameter(tensor)

def absorb_biases(self, groups):
for layer, bn, (next_layer_name, next_layer) in groups:
Expand Down
3 changes: 0 additions & 3 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def forward(self, x):

@pytest_cases.fixture
def mul_model():
"""
In this example, conv_0 is both a src and sink.
"""
class ResidualSrcsAndSinkModel(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down

0 comments on commit 6c80e6d

Please sign in to comment.