Feat (graph/equalize): improvements for llm equalization #784

Merged: 12 commits, Dec 15, 2023
Changes from 6 commits
9 changes: 9 additions & 0 deletions src/brevitas/graph/base.py
@@ -296,6 +296,15 @@ class FnToModule(CallableToModule):
def match_node(self, node: Node) -> bool:
return node.op == 'call_function' and node.target is self.old_callable

def move_node_args_to_kwargs(self, node: Node):
super().move_node_args_to_kwargs(node)
# When moving to stateful modules, we drop the 'training' argument if it was passed to the
# functional version of the layer, since it is no longer needed
kwargs = dict(node.kwargs)
if 'training' in kwargs:
del kwargs['training']
node.kwargs = immutable_dict(kwargs)
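
As a side note on why the 'training' kwarg has to go: F.dropout accepts a training flag, but nn.Dropout exposes no such constructor argument and relies on module.training instead, so a kwarg captured from the functional call must be dropped before the replacement module is built from the remaining kwargs. A minimal standalone sketch (not code from this PR; the kwargs are made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 4)
kwargs = {'p': 0.1, 'training': True}  # kwargs as they might appear on a traced F.dropout node

F.dropout(x, **kwargs)        # the functional form accepts 'training'
kwargs.pop('training', None)  # what move_node_args_to_kwargs now does
nn.Dropout(**kwargs)(x)       # nn.Dropout(p=0.1); passing 'training' here would raise a TypeError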


class MethodToModule(CallableToModule):

50 changes: 30 additions & 20 deletions src/brevitas/graph/equalize.py
@@ -27,7 +27,10 @@

__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

EPSILON = 1e-9
# TODO: if we are able to run activation equalization on GPU in float16, we could have two
# separate epsilon factors for float16 (2e-5) vs float32/bfloat16 (1e-9). At the moment we are
# tied to a single epsilon for both cases.
EPSILON = 2e-5
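
The TODO above could eventually be addressed with a dtype-dependent threshold along these lines (a hypothetical sketch, not part of this PR; _epsilon_for is a made-up name and the values mirror the ones in the comment):

import torch

def _epsilon_for(dtype: torch.dtype) -> float:
    # float16 has far less precision around zero, so it needs a larger threshold
    return 2e-5 if dtype == torch.float16 else 1e-9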

_supported_layers = (
nn.ConvTranspose1d,
@@ -73,6 +76,8 @@

_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

_ignore_ops = (getattr, 'size', 'contiguous')


# Required for being hashable
@dataclass(eq=True, frozen=True)
@@ -398,7 +403,6 @@ def _no_equalize():
scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
sinks_range = torch.clamp(sinks_range, EPSILON)

# Determine the srcs_range based on whether we are performing activation equalization or
# weight equalization
@@ -431,10 +435,18 @@ def _no_equalize():
"Detected source and sink with non compatible shapes, equalization is skipped")
return _no_equalize()

# Instead of clipping very low values, which would cause their reciprocals to be very large,
# thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
sinks_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
torch.tensor(1., dtype=dtype, device=device),
sinks_range)
srcs_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
torch.tensor(1., dtype=dtype, device=device),
srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
scaling_factors = torch.clamp(scaling_factors, EPSILON)
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
@@ -455,8 +467,8 @@ def _no_equalize():
torch.reshape(inverse_scaling_factors, src_broadcast_size)),
attr='weight')
for module, axis in sink_axes.items():
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
if isinstance(module, _batch_norm):
# We re-compute the bias as function of running_mean and running_var to adjust the
# additive factor for equalization.
@@ -466,7 +478,7 @@ def _no_equalize():
module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
_update_weights(
module,
module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size),
attr='weight')

return scaling_factors
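
Putting the pieces together: channels whose source or sink range collapses below EPSILON now end up with a scaling factor of exactly 1 (a no-op), while well-behaved channels keep the usual srcs_range ** alpha / sinks_range ** (1 - alpha) factor. A standalone numeric sketch of the intended behaviour, with made-up per-channel ranges (not code from the PR):

import torch

EPSILON = 2e-5
alpha = 0.5
dtype, device = torch.float32, 'cpu'
srcs_range = torch.tensor([4.0, 1e-7, 2.0])   # second channel is effectively dead
sinks_range = torch.tensor([1.0, 3.0, 8.0])

# Dead channels get both ranges set to one, so their scaling factor is exactly one
dead = (sinks_range < EPSILON) | (srcs_range < EPSILON)
sinks_range = torch.where(dead, torch.tensor(1., dtype=dtype, device=device), sinks_range)
srcs_range = torch.where(dead, torch.tensor(1., dtype=dtype, device=device), srcs_range)

scaling_factors = torch.pow(srcs_range, alpha) / torch.pow(sinks_range, 1 - alpha)
scaling_factors = torch.clamp(scaling_factors, EPSILON)
print(scaling_factors)  # tensor([2.0000, 1.0000, 0.5000])
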
@@ -575,6 +587,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_srcs(graph_model, node, state)
find_sinks(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.srcs.add(_UNSUPPORTED_OP)
@@ -606,6 +620,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_sinks(graph_model, node, state)
find_srcs(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.sinks.add(_UNSUPPORTED_OP)
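
The effect of _ignore_ops is that bookkeeping calls which commonly show up in traced LLM blocks no longer invalidate an otherwise valid region. A toy sketch (hypothetical module, assuming it is traced the usual way before equalization): without the new elif, the call_method nodes with targets 'size' and 'contiguous' would hit the else branch and mark the region between fc1 and fc2 as unsupported; now they are simply skipped.

import torch
import torch.nn as nn

class TinyBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        x = self.fc1(x)
        _ = x.size(0)        # traces to a call_method node with target 'size'
        x = x.contiguous()   # traces to a call_method node with target 'contiguous'
        return self.fc2(x)
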
@@ -713,7 +729,7 @@ def setup(self):
for region in self.regions:
batch_dim = 0
if hasattr(region, 'batch_first'):
batch_dim = 0 if region.batch_first == True else 1
batch_dim = 0 if region.batch_first else 1

hook_fn = partial(
self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True)
@@ -761,16 +777,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)
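
Both forward_stats_hook implementations now keep a running element-wise maximum of the per-batch scales instead of concatenating the stored statistics with the incoming activations and recomputing the scale. A standalone sketch of the idea, where absmax_scale is a made-up stand-in for self.scale_fn (the equivalence below holds for max-style scale functions):

import torch

def absmax_scale(x, dim=0):
    # Per-channel absolute max, reducing over the batch dimension
    return x.abs().amax(dim=dim)

batches = [torch.randn(16, 32) for _ in range(4)]

# Running max of per-batch scales...
running = None
for x in batches:
    scales = absmax_scale(x, dim=0)
    running = scales if running is None else torch.max(running, scales)

# ...matches the scale computed over all the data at once
reference = absmax_scale(torch.cat(batches, dim=0), dim=0)
assert torch.equal(running, reference)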

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
broadcastable_shape = [1] * len(shape)
@@ -832,7 +846,7 @@ def setup(self):
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first == True else 1
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
act_module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
@@ -907,16 +921,12 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
broadcastable_shape = [1] * len(shape)
3 changes: 2 additions & 1 deletion src/brevitas/graph/standardize.py
@@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform):
nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d),
(F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d),
(F.adaptive_avg_pool2d,
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d))
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d,
nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout))

def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP):
super().__init__()
13 changes: 11 additions & 2 deletions src/brevitas_examples/llm/llm_quant/equalize.py
@@ -10,6 +10,8 @@
from brevitas.fx.brevitas_tracer import value_trace
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32

@@ -26,6 +28,13 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
return outs


def trace_and_standardize(model, ref_kwargs):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = TorchFunctionalToModule().apply(graph_model)
graph_model = DuplicateSharedStatelessModule().apply(graph_model)
return graph_model


@torch.no_grad()
def apply_act_equalization(
model,
@@ -49,7 +58,7 @@ def apply_act_equalization(
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
# or an FX interpreter to run it on GPU
warnings.warn(
@@ -70,5 +79,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -303,9 +303,14 @@ def main():
quantize_embedding=args.quantize_embedding,
seqlen=args.seqlen)
# Tie back first/last layer weights in case they got untied
model.tie_weights()
print("Model quantization applied.")

# If any equalization has taken place, the embedding layer and the fully connected one are
# not tied anymore, and they need to be treated as standalone, separate layers.
# In all other cases we can tie them back so as to preserve memory.
if args.act_equalization is None and not args.weight_equalization:
model.tie_weights()
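
If needed, whether the two layers are actually shared can be sanity-checked by comparing the storage of the embedding and LM-head weights. A hypothetical helper (not part of this change) relying on the standard Hugging Face accessors:

def embeddings_are_tied(model) -> bool:
    # True when input embeddings and LM head point at the same underlying storage
    in_emb = model.get_input_embeddings()
    out_emb = model.get_output_embeddings()
    return (
        in_emb is not None and out_emb is not None and
        in_emb.weight.data_ptr() == out_emb.weight.data_ptr())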

if args.act_calibration:
print("Apply act calibration...")
apply_calibration(model, calibration_loader, args.nsamples)