From 37faf0f02eb2404e9756b8065ac80382eea7cb9e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 27 Jun 2023 01:47:22 +0100 Subject: [PATCH 1/8] Feat (act_eq): extended functionalities --- src/brevitas/graph/equalize.py | 66 +++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 91c7c9c65..875a54d14 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -74,6 +74,21 @@ _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) +# To be moved in Pytorch utils +class SmartHook(nn.Module): + + def __init__(self, module, hook_fn): + super().__init__() + self.module = module + self.hook_fn = hook_fn + + def forward(self, *args, **kwargs): + out = self.module(*args, **kwargs) + args = args + (out,) + self.hook_fn(*args, **kwargs) + return out + + # Required for being hashable @dataclass(eq=True, frozen=True) class WeightBiasTuple: @@ -576,8 +591,6 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, find_sinks(graph_model, node, state) elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): - if _is_scale_invariant_activation(graph_model, node): - state.acts.add(node.target) find_srcs(graph_model, node, state) find_sinks(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or @@ -608,8 +621,6 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, state.sinks.add(node.target) elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): - if _is_scale_invariant_activation(graph_model, node): - state.acts.add(node.target) find_sinks(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or node.op == 'call_function' and node.target in _residual_fns): @@ -826,35 +837,45 @@ def setup(self): # because we can't propagate scaling factors regions_to_drop = [] for region in self.regions: - if len(region.acts) == 0: - regions_to_drop.append(region) # This condition is for redudancy, since # a region with two scale-varying activations cannot be detected in the first place - elif len(region.acts) > 1 and any([isinstance(name_to_module[act_name], - _scale_varying_activations) - for act_name in region.acts]): + if len(region.acts) > 1 and any([isinstance(name_to_module[act_name], + _scale_varying_activations) + for act_name in region.acts]): regions_to_drop.append(region) else: # We assume that the entire region has a unique batch_dim batch_dim = 0 + region_to_search = region.sinks if len(region.acts) == 0 else region.acts for name in region.srcs + region.sinks: module = name_to_module[name] if hasattr(module, 'batch_first'): batch_dim = 0 if module.batch_first == True else 1 - - for act_name in region.acts: - act_module = name_to_module[act_name] - hook_fn = partial(self.forward_stats_hook, name=act_name, batch_dim=batch_dim) - self.hooks.append(act_module.register_forward_hook(hook_fn)) + for name in region_to_search: + act_module = name_to_module[name] + kwarg_name = 'query' if isinstance( + act_module, torch.nn.MultiheadAttention) else None + use_inp = True if region_to_search == region.sinks else False + hook_fn = partial( + self.forward_stats_hook, + name=name, + batch_dim=batch_dim, + kwarg_name=kwarg_name, + use_inp=use_inp) + new_instance = SmartHook(act_module, hook_fn) + ModuleInstanceToModuleInstance(act_module, 
new_instance).apply(self.graph_model) + self.hooks.append(new_instance) self.regions = [x for x in self.regions if x not in regions_to_drop] def apply(self, alpha): scale_factors = [] + self.remove_hooks() name_to_module = dict_name_to_module(self.graph_model, self.regions) for region in self.regions: + region_to_search = region.sinks if len(region.acts) == 0 else region.acts act_module = [name_to_module[act_name] for act_name in region.acts] - list_of_act_val = [self.float_act_map[act_name] for act_name in region.acts] + list_of_act_val = [self.float_act_map[name] for name in region_to_search] sinks = [name_to_module[sink] for sink in region.sinks] # Filter out scale_varying activations from the srcs srcs = [ @@ -888,11 +909,16 @@ def apply(self, alpha): def remove_hooks(self): for hook in self.hooks: - hook.remove() + ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model) - def forward_stats_hook(self, module, inp, out, name, batch_dim=0): - inp = inp[0] + def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + if use_inp and len(args) > 1: + inp = args[0] + elif not use_inp: + inp = args[-1] + elif len(kwargs) > 0: + inp = kwargs[kwarg_name] # Extra check for batch_dim if hasattr(inp, 'names') and 'N' in inp.names: batch_dim = inp.names.index('N') @@ -901,9 +927,9 @@ def forward_stats_hook(self, module, inp, out, name, batch_dim=0): self.batch_dim_act_map[name] = batch_dim if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(out, dim=batch_dim) + self.float_act_map[name] = self.scale_fn(inp, dim=batch_dim) else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), out], + batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), inp], dim=batch_dim) self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) From db48ece8ff489ab5297141202a9dc2cefde8f0bd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 27 Jun 2023 10:14:37 +0100 Subject: [PATCH 2/8] Update for layerwise --- src/brevitas/graph/equalize.py | 40 ++++++++++++++++-------------- src/brevitas/nn/equalized_layer.py | 23 +++++++++++++---- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 875a54d14..319bb4dff 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -734,11 +734,22 @@ def setup(self): if hasattr(region, 'batch_first'): batch_dim = 0 if region.batch_first == True else 1 - hook_fn = partial(self.forward_stats_hook, name=region, batch_dim=batch_dim) - self.hooks.append(region.register_forward_pre_hook(hook_fn)) + # hook_fn = partial(self.forward_stats_hook, name=region, batch_dim=batch_dim) + # self.hooks.append(region.register_forward_pre_hook(hook_fn)) + kwarg_name = 'query' if isinstance(region, torch.nn.MultiheadAttention) else None + hook_fn = partial( + self.forward_stats_hook, + name=region, + batch_dim=batch_dim, + kwarg_name=kwarg_name, + use_inp=True) + new_instance = SmartHook(region, hook_fn) + ModuleInstanceToModuleInstance(region, new_instance).apply(self.model) + self.hooks.append(new_instance) def apply(self, alpha): scale_factors = [] + self.remove_hooks() for region in self.regions: if self.float_act_map[region] == None: continue @@ -756,30 +767,21 @@ def apply(self, alpha): def remove_hooks(self): for hook in self.hooks: - hook.remove() + ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model) - def forward_stats_hook(self, module, inp, name, 
batch_dim=0): - if len(inp) == 0: - warnings.warn( - "Cannot perform layerwise activation equalization with only kwargs as input") - self.float_act_map[name] = None - return + def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + if use_inp and len(args) > 1: + inp = args[0] + elif not use_inp: + inp = args[-1] + elif len(kwargs) > 0: + inp = kwargs[kwarg_name] # Extra check for batch_dim if hasattr(inp, 'names') and 'N' in inp.names: batch_dim = inp.names.index('N') inp = inp.transpose(0, batch_dim) - if len(inp) > 1: - # check that they are all equal for self-attention MHA - self_attention = all([inp[0].data_ptr() == i.data_ptr() for i in inp]) - if not self_attention: - self.float_act_map[name] = None - return - inp = inp[0] - else: - inp = inp[0] - self.batch_dim_act_map[name] = batch_dim if name not in self.float_act_map: diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index b7a732d0e..bdc5e9246 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -10,8 +10,17 @@ def __init__(self, scale_module, layer) -> None: self.scale = scale_module self.layer = layer - def forward(self, x, *args, **kwargs): + def forward(self, *args, **kwargs): args = list(args) + + if len(args) > 0: + x = args[0] + args.pop(0) + elif len(kwargs) > 0 and 'query' in kwargs: + x = kwargs['query'] + else: + raise ValueError("Unsupported input type") + out = x if 'key' in kwargs: if kwargs['key'].data_ptr() != out.data_ptr(): @@ -25,16 +34,20 @@ def forward(self, x, *args, **kwargs): # We need to preserve the correctness of the forward even after # quantization has been applied if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)): - if 'key' not in kwargs.items(): + if 'query' not in kwargs.keys(): + pos_inputs.append(out) + args.pop(0) + else: + kwargs['query'] = out + if 'key' not in kwargs.keys(): pos_inputs.append(out) args.pop(0) else: kwargs['key'] = out - if 'value' not in kwargs.items(): + if 'value' not in kwargs.keys(): pos_inputs.append(out) args.pop(0) else: kwargs['value'] = out - - out = self.layer(*pos_inputs, *args, **kwargs) + out = self.layer(*(pos_inputs + args), **kwargs) return out From 22b81a9af083100c01be2f7bab613c155d7d5603 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 27 Jun 2023 13:54:03 +0100 Subject: [PATCH 3/8] Fix for MHA --- src/brevitas/graph/equalize.py | 2 -- src/brevitas/nn/equalized_layer.py | 8 +++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 319bb4dff..c8543a9ea 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,8 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -from collections import namedtuple -from copy import deepcopy from dataclasses import dataclass from dataclasses import field from functools import partial diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index bdc5e9246..a0708c842 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -15,9 +15,12 @@ def forward(self, *args, **kwargs): if len(args) > 0: x = args[0] + # We delete it since it will updated and passed as first arg args.pop(0) elif len(kwargs) > 0 and 'query' in kwargs: x = kwargs['query'] + # We delete it since it will updated and passed as first arg + del kwargs['query'] else: raise ValueError("Unsupported input type") @@ -34,11 +37,6 @@ def forward(self, *args, **kwargs): # We need to preserve the correctness of the forward even after # quantization has been applied if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)): - if 'query' not in kwargs.keys(): - pos_inputs.append(out) - args.pop(0) - else: - kwargs['query'] = out if 'key' not in kwargs.keys(): pos_inputs.append(out) args.pop(0) From c2ace61ef6250b8ece4231c8c2f578366da5a8ad Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Jun 2023 19:49:17 +0100 Subject: [PATCH 4/8] Clean up --- src/brevitas/graph/equalize.py | 49 +++++++++++++++++-------------- src/brevitas/utils/torch_utils.py | 15 ++++++++++ 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index c8543a9ea..d8f738adb 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from copy import deepcopy from dataclasses import dataclass from dataclasses import field from functools import partial @@ -20,6 +21,7 @@ from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.quant_scale_bias import ScaleBias +from brevitas.utils.torch_utils import KwargsForwardHook from .base import GraphTransform from .base import InsertModuleCallAfter @@ -72,21 +74,6 @@ _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) -# To be moved in Pytorch utils -class SmartHook(nn.Module): - - def __init__(self, module, hook_fn): - super().__init__() - self.module = module - self.hook_fn = hook_fn - - def forward(self, *args, **kwargs): - out = self.module(*args, **kwargs) - args = args + (out,) - self.hook_fn(*args, **kwargs) - return out - - # Required for being hashable @dataclass(eq=True, frozen=True) class WeightBiasTuple: @@ -732,8 +719,6 @@ def setup(self): if hasattr(region, 'batch_first'): batch_dim = 0 if region.batch_first == True else 1 - # hook_fn = partial(self.forward_stats_hook, name=region, batch_dim=batch_dim) - # self.hooks.append(region.register_forward_pre_hook(hook_fn)) kwarg_name = 'query' if isinstance(region, torch.nn.MultiheadAttention) else None hook_fn = partial( self.forward_stats_hook, @@ -741,7 +726,7 @@ def setup(self): batch_dim=batch_dim, kwarg_name=kwarg_name, use_inp=True) - new_instance = SmartHook(region, hook_fn) + new_instance = KwargsForwardHook(region, hook_fn) ModuleInstanceToModuleInstance(region, new_instance).apply(self.model) self.hooks.append(new_instance) @@ -767,7 +752,16 @@ def remove_hooks(self): for hook in self.hooks: ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model) - def 
forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + def forward_stats_hook( + self, module, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + # Check for MHA Cross attention, and if found, skip it + kwargs_to_check = deepcopy(kwargs) + kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) + if 'query' in kwargs_to_check and 'key' in kwargs_to_check and 'value' in kwargs_to_check: + if kwargs_to_check['query'].data_ptr() != kwargs_to_check['key'].data_ptr( + ) != kwargs_to_check['value'].data_ptr(): + self.float_act_map[name] = None + return if use_inp and len(args) > 1: inp = args[0] @@ -775,6 +769,7 @@ def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp= inp = args[-1] elif len(kwargs) > 0: inp = kwargs[kwarg_name] + # Extra check for batch_dim if hasattr(inp, 'names') and 'N' in inp.names: batch_dim = inp.names.index('N') @@ -832,7 +827,6 @@ def __init__( def setup(self): name_to_module = dict_name_to_module(self.graph_model, self.regions) # Select only regions with activation to equalize through. - # If a region has no activations, it is dropped. # If a region has multiple scale varying activation, must also be dropped # because we can't propagate scaling factors regions_to_drop = [] @@ -862,7 +856,7 @@ def setup(self): batch_dim=batch_dim, kwarg_name=kwarg_name, use_inp=use_inp) - new_instance = SmartHook(act_module, hook_fn) + new_instance = KwargsForwardHook(act_module, hook_fn) ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model) self.hooks.append(new_instance) @@ -874,6 +868,8 @@ def apply(self, alpha): name_to_module = dict_name_to_module(self.graph_model, self.regions) for region in self.regions: region_to_search = region.sinks if len(region.acts) == 0 else region.acts + if any([self.float_act_map[name] is None for name in region_to_search]): + continue act_module = [name_to_module[act_name] for act_name in region.acts] list_of_act_val = [self.float_act_map[name] for name in region_to_search] sinks = [name_to_module[sink] for sink in region.sinks] @@ -911,7 +907,16 @@ def remove_hooks(self): for hook in self.hooks: ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model) - def forward_stats_hook(self, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + def forward_stats_hook( + self, module, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + # Check for MHA Cross attention, and if found, skip it + kwargs_to_check = deepcopy(kwargs) + kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args)) + if 'query' in kwargs_to_check and 'key' in kwargs_to_check and 'value' in kwargs_to_check: + if kwargs_to_check['query'].data_ptr() != kwargs_to_check['key'].data_ptr( + ) != kwargs_to_check['value'].data_ptr(): + self.float_act_map[name] = None + return if use_inp and len(args) > 1: inp = args[0] diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 2927b92f2..7105ea874 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -3,6 +3,7 @@ import copy +import torch from torch.nn import Sequential @@ -22,6 +23,20 @@ def forward(self, *input): return out +class KwargsForwardHook(torch.nn.Module): + + def __init__(self, module, hook_fn): + super().__init__() + self.module = module + self.hook_fn = hook_fn + + def forward(self, *args, **kwargs): + out = self.module(*args, **kwargs) + args = args + (out,) + 
self.hook_fn(self.module, *args, **kwargs) + return out + + def torch_partial_deepcopy(model): """ Performs a deepcopy of a torch.nn.Module, except for all the parameters that are instead passed by reference From ef30c088300b2d1ef1c1c44a2f4f38518c99e70a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Jun 2023 22:56:55 +0100 Subject: [PATCH 5/8] Fix --- src/brevitas/graph/equalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index d8f738adb..c089d60ac 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -911,7 +911,7 @@ def forward_stats_hook( self, module, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): # Check for MHA Cross attention, and if found, skip it kwargs_to_check = deepcopy(kwargs) - kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args)) + kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) if 'query' in kwargs_to_check and 'key' in kwargs_to_check and 'value' in kwargs_to_check: if kwargs_to_check['query'].data_ptr() != kwargs_to_check['key'].data_ptr( ) != kwargs_to_check['value'].data_ptr(): From 46bd62d34262b8ac376f43028b536b801c887980 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 29 Jun 2023 11:12:24 +0100 Subject: [PATCH 6/8] Cleanup --- src/brevitas/graph/equalize.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index c089d60ac..865e295c3 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -138,8 +138,7 @@ def __enter__(self): def __exit__(self, type, value, traceback): if self.enabled: self.scale_factors = self.graph_act_eq.apply(self.alpha) - self.graph_act_eq.remove_hooks() - return True # To propagate exceptions + return True # To propagate exceptions def dict_name_to_module(model, regions): @@ -919,22 +918,22 @@ def forward_stats_hook( return if use_inp and len(args) > 1: - inp = args[0] + x = args[0] elif not use_inp: - inp = args[-1] + x = args[-1] elif len(kwargs) > 0: - inp = kwargs[kwarg_name] + x = kwargs[kwarg_name] # Extra check for batch_dim - if hasattr(inp, 'names') and 'N' in inp.names: - batch_dim = inp.names.index('N') - inp = inp.transpose(0, batch_dim) + if hasattr(x, 'names') and 'N' in x.names: + batch_dim = x.names.index('N') + x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(inp, dim=batch_dim) + self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), inp], + batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], dim=batch_dim) self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) From 358ee4c03f823ad7ecce589007163b19d0af3445 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 29 Jun 2023 12:33:10 +0100 Subject: [PATCH 7/8] Fix for args/kwargs in equalized module --- src/brevitas/nn/equalized_layer.py | 32 ++++++++---------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index a0708c842..d20e78537 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -11,19 +11,11 @@ def __init__(self, scale_module, layer) -> None: self.layer = layer def forward(self, 
*args, **kwargs): - args = list(args) - - if len(args) > 0: - x = args[0] - # We delete it since it will updated and passed as first arg - args.pop(0) - elif len(kwargs) > 0 and 'query' in kwargs: - x = kwargs['query'] - # We delete it since it will updated and passed as first arg - del kwargs['query'] - else: - raise ValueError("Unsupported input type") + kwargs.update(zip(self.layer.forward.__code__.co_varnames[1:], args)) + possible_input_kwargs = ['input', 'inp', 'query'] + input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] + x = kwargs[input_kwarg] out = x if 'key' in kwargs: if kwargs['key'].data_ptr() != out.data_ptr(): @@ -32,20 +24,12 @@ def forward(self, *args, **kwargs): "Replace kwargs with positional args to avoid this exception.") out = self.scale(out) - pos_inputs = [out] + kwargs[input_kwarg] = out # QuantMultiheadAttention is not a subclass of MultiheadAttention # We need to preserve the correctness of the forward even after # quantization has been applied if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)): - if 'key' not in kwargs.keys(): - pos_inputs.append(out) - args.pop(0) - else: - kwargs['key'] = out - if 'value' not in kwargs.keys(): - pos_inputs.append(out) - args.pop(0) - else: - kwargs['value'] = out - out = self.layer(*(pos_inputs + args), **kwargs) + kwargs['key'] = out + kwargs['value'] = out + out = self.layer(**kwargs) return out From 91efbd079379a973772dd5f1bd0ddd5d81933e5b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 29 Jun 2023 12:56:26 +0100 Subject: [PATCH 8/8] args/kwargs update in hook --- src/brevitas/graph/equalize.py | 66 +++++++++++++--------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 865e295c3..a9ceb4601 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -718,13 +718,8 @@ def setup(self): if hasattr(region, 'batch_first'): batch_dim = 0 if region.batch_first == True else 1 - kwarg_name = 'query' if isinstance(region, torch.nn.MultiheadAttention) else None hook_fn = partial( - self.forward_stats_hook, - name=region, - batch_dim=batch_dim, - kwarg_name=kwarg_name, - use_inp=True) + self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True) new_instance = KwargsForwardHook(region, hook_fn) ModuleInstanceToModuleInstance(region, new_instance).apply(self.model) self.hooks.append(new_instance) @@ -751,35 +746,32 @@ def remove_hooks(self): for hook in self.hooks: ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model) - def forward_stats_hook( - self, module, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): # Check for MHA Cross attention, and if found, skip it - kwargs_to_check = deepcopy(kwargs) - kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) - if 'query' in kwargs_to_check and 'key' in kwargs_to_check and 'value' in kwargs_to_check: - if kwargs_to_check['query'].data_ptr() != kwargs_to_check['key'].data_ptr( - ) != kwargs_to_check['value'].data_ptr(): + kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) + if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: + if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): self.float_act_map[name] = None return - if use_inp and len(args) > 1: - inp = args[0] + 
possible_input_kwargs = ['input', 'inp', 'query'] + input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] + if use_inp: + x = kwargs[input_kwarg] elif not use_inp: - inp = args[-1] - elif len(kwargs) > 0: - inp = kwargs[kwarg_name] + x = args[-1] # Extra check for batch_dim - if hasattr(inp, 'names') and 'N' in inp.names: - batch_dim = inp.names.index('N') - inp = inp.transpose(0, batch_dim) + if hasattr(x, 'names') and 'N' in x.names: + batch_dim = x.names.index('N') + x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(inp, dim=batch_dim) + self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), inp], + batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], dim=batch_dim) self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) @@ -846,15 +838,9 @@ def setup(self): batch_dim = 0 if module.batch_first == True else 1 for name in region_to_search: act_module = name_to_module[name] - kwarg_name = 'query' if isinstance( - act_module, torch.nn.MultiheadAttention) else None use_inp = True if region_to_search == region.sinks else False hook_fn = partial( - self.forward_stats_hook, - name=name, - batch_dim=batch_dim, - kwarg_name=kwarg_name, - use_inp=use_inp) + self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) new_instance = KwargsForwardHook(act_module, hook_fn) ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model) self.hooks.append(new_instance) @@ -906,23 +892,21 @@ def remove_hooks(self): for hook in self.hooks: ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model) - def forward_stats_hook( - self, module, *args, name, batch_dim=0, kwarg_name=None, use_inp=True, **kwargs): + def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): # Check for MHA Cross attention, and if found, skip it - kwargs_to_check = deepcopy(kwargs) - kwargs_to_check.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) - if 'query' in kwargs_to_check and 'key' in kwargs_to_check and 'value' in kwargs_to_check: - if kwargs_to_check['query'].data_ptr() != kwargs_to_check['key'].data_ptr( - ) != kwargs_to_check['value'].data_ptr(): + kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) + if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: + if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): self.float_act_map[name] = None return - if use_inp and len(args) > 1: - x = args[0] + possible_input_kwargs = ['input', 'inp', 'query'] + input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] + if use_inp: + x = kwargs[input_kwarg] elif not use_inp: x = args[-1] - elif len(kwargs) > 0: - x = kwargs[kwarg_name] + # Extra check for batch_dim if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N')
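
A minimal, self-contained sketch of the wrapper-hook pattern the commits above converge on: a wrapper module forwards *args/**kwargs to the wrapped layer, appends the output to the positional arguments, and calls a stats function; the stats function maps positional arguments onto the parameter names of the wrapped forward() via co_varnames, so the activation can be found whether it was passed positionally or as 'input'/'inp'/'query'. Class and function names below are illustrative only, not Brevitas APIs.

    import torch
    import torch.nn as nn


    class KwargsForwardHookSketch(nn.Module):
        """Wraps a module and feeds its inputs, kwargs and output to hook_fn."""

        def __init__(self, module, hook_fn):
            super().__init__()
            self.module = module
            self.hook_fn = hook_fn

        def forward(self, *args, **kwargs):
            out = self.module(*args, **kwargs)
            # Append the output so the hook can read either the inputs or the output
            args = args + (out,)
            self.hook_fn(self.module, *args, **kwargs)
            return out


    def print_abs_max_hook(module, *args, **kwargs):
        # Map positional args (excluding the appended output) onto the parameter
        # names of the wrapped module's forward(), mirroring the co_varnames trick
        kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
        x = next(kwargs[k] for k in ('input', 'inp', 'query') if k in kwargs)
        print(module.__class__.__name__, x.abs().max().item())


    layer = nn.Linear(8, 8)
    wrapped = KwargsForwardHookSketch(layer, print_abs_max_hook)
    wrapped(torch.randn(4, 8))        # input passed positionally
    wrapped(input=torch.randn(4, 8))  # input passed as a keyword resolves the same way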