Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Activation eq extension #642

Merged
merged 8 commits into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 67 additions & 39 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,21 @@
_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


# To be moved in Pytorch utils
class SmartHook(nn.Module):

def __init__(self, module, hook_fn):
super().__init__()
self.module = module
self.hook_fn = hook_fn

def forward(self, *args, **kwargs):
out = self.module(*args, **kwargs)
args = args + (out,)
self.hook_fn(*args, **kwargs)
return out


# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasTuple:
Expand Down Expand Up @@ -576,8 +591,6 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
find_sinks(graph_model, node, state)
elif _is_scale_invariant_module(
graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
if _is_scale_invariant_activation(graph_model, node):
state.acts.add(node.target)
find_srcs(graph_model, node, state)
find_sinks(graph_model, node, state)
elif (node.op == 'call_method' and node.target in _residual_methods or
Expand Down Expand Up @@ -608,8 +621,6 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
state.sinks.add(node.target)
elif _is_scale_invariant_module(
graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
if _is_scale_invariant_activation(graph_model, node):
state.acts.add(node.target)
find_sinks(graph_model, node, state)
elif (node.op == 'call_method' and node.target in _residual_methods or
node.op == 'call_function' and node.target in _residual_fns):
Expand Down Expand Up @@ -723,11 +734,22 @@ def setup(self):
if hasattr(region, 'batch_first'):
batch_dim = 0 if region.batch_first == True else 1

hook_fn = partial(self.forward_stats_hook, name=region, batch_dim=batch_dim)
self.hooks.append(region.register_forward_pre_hook(hook_fn))
# hook_fn = partial(self.forward_stats_hook, name=region, batch_dim=batch_dim)
# self.hooks.append(region.register_forward_pre_hook(hook_fn))
kwarg_name = 'query' if isinstance(region, torch.nn.MultiheadAttention) else None
hook_fn = partial(
self.forward_stats_hook,
name=region,
batch_dim=batch_dim,
kwarg_name=kwarg_name,
use_inp=True)
new_instance = SmartHook(region, hook_fn)
ModuleInstanceToModuleInstance(region, new_instance).apply(self.model)
self.hooks.append(new_instance)

def apply(self, alpha):
scale_factors = []
self.remove_hooks()
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
for region in self.regions:
if self.float_act_map[region] == None:
continue
Expand All @@ -745,30 +767,21 @@ def apply(self, alpha):

def remove_hooks(self):
for hook in self.hooks:
hook.remove()
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model)

def forward_stats_hook(self, module, inp, name, batch_dim=0):
if len(inp) == 0:
warnings.warn(
"Cannot perform layerwise activation equalization with only kwargs as input")
self.float_act_map[name] = None
return
def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs):

if use_inp and len(args) > 1:
inp = args[0]
elif not use_inp:
inp = args[-1]
elif len(kwargs) > 0:
inp = kwargs[kwarg_name]
# Extra check for batch_dim
if hasattr(inp, 'names') and 'N' in inp.names:
batch_dim = inp.names.index('N')
inp = inp.transpose(0, batch_dim)

if len(inp) > 1:
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
# check that they are all equal for self-attention MHA
self_attention = all([inp[0].data_ptr() == i.data_ptr() for i in inp])
if not self_attention:
self.float_act_map[name] = None
return
inp = inp[0]
else:
inp = inp[0]

self.batch_dim_act_map[name] = batch_dim

if name not in self.float_act_map:
Expand Down Expand Up @@ -826,35 +839,45 @@ def setup(self):
# because we can't propagate scaling factors
regions_to_drop = []
for region in self.regions:
if len(region.acts) == 0:
regions_to_drop.append(region)
# This condition is for redudancy, since
# a region with two scale-varying activations cannot be detected in the first place
elif len(region.acts) > 1 and any([isinstance(name_to_module[act_name],
_scale_varying_activations)
for act_name in region.acts]):
if len(region.acts) > 1 and any([isinstance(name_to_module[act_name],
_scale_varying_activations)
for act_name in region.acts]):
regions_to_drop.append(region)
else:
# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first == True else 1

for act_name in region.acts:
act_module = name_to_module[act_name]
hook_fn = partial(self.forward_stats_hook, name=act_name, batch_dim=batch_dim)
self.hooks.append(act_module.register_forward_hook(hook_fn))
for name in region_to_search:
act_module = name_to_module[name]
kwarg_name = 'query' if isinstance(
act_module, torch.nn.MultiheadAttention) else None
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook,
name=name,
batch_dim=batch_dim,
kwarg_name=kwarg_name,
use_inp=use_inp)
new_instance = SmartHook(act_module, hook_fn)
ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model)
self.hooks.append(new_instance)

self.regions = [x for x in self.regions if x not in regions_to_drop]

def apply(self, alpha):
scale_factors = []
self.remove_hooks()
name_to_module = dict_name_to_module(self.graph_model, self.regions)
for region in self.regions:
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
act_module = [name_to_module[act_name] for act_name in region.acts]
list_of_act_val = [self.float_act_map[act_name] for act_name in region.acts]
list_of_act_val = [self.float_act_map[name] for name in region_to_search]
sinks = [name_to_module[sink] for sink in region.sinks]
# Filter out scale_varying activations from the srcs
srcs = [
Expand Down Expand Up @@ -888,11 +911,16 @@ def apply(self, alpha):

def remove_hooks(self):
for hook in self.hooks:
hook.remove()
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model)

def forward_stats_hook(self, module, inp, out, name, batch_dim=0):
inp = inp[0]
def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs):

if use_inp and len(args) > 1:
inp = args[0]
elif not use_inp:
inp = args[-1]
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
elif len(kwargs) > 0:
inp = kwargs[kwarg_name]
# Extra check for batch_dim
if hasattr(inp, 'names') and 'N' in inp.names:
batch_dim = inp.names.index('N')
Expand All @@ -901,9 +929,9 @@ def forward_stats_hook(self, module, inp, out, name, batch_dim=0):
self.batch_dim_act_map[name] = batch_dim

if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(out, dim=batch_dim)
self.float_act_map[name] = self.scale_fn(inp, dim=batch_dim)
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), out],
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), inp],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)

Expand Down
23 changes: 18 additions & 5 deletions src/brevitas/nn/equalized_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,17 @@ def __init__(self, scale_module, layer) -> None:
self.scale = scale_module
self.layer = layer

def forward(self, x, *args, **kwargs):
def forward(self, *args, **kwargs):
args = list(args)

if len(args) > 0:
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
x = args[0]
args.pop(0)
elif len(kwargs) > 0 and 'query' in kwargs:
x = kwargs['query']
else:
raise ValueError("Unsupported input type")

out = x
if 'key' in kwargs:
if kwargs['key'].data_ptr() != out.data_ptr():
Expand All @@ -25,16 +34,20 @@ def forward(self, x, *args, **kwargs):
# We need to preserve the correctness of the forward even after
# quantization has been applied
if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)):
if 'key' not in kwargs.items():
if 'query' not in kwargs.keys():
pos_inputs.append(out)
args.pop(0)
else:
kwargs['query'] = out
if 'key' not in kwargs.keys():
pos_inputs.append(out)
args.pop(0)
else:
kwargs['key'] = out
if 'value' not in kwargs.items():
if 'value' not in kwargs.keys():
pos_inputs.append(out)
args.pop(0)
else:
kwargs['value'] = out

out = self.layer(*pos_inputs, *args, **kwargs)
out = self.layer(*(pos_inputs + args), **kwargs)
return out