Skip to content

Commit

Permalink
Feat (mx): automatic group_dim in layerwise quant (#1012)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Sep 4, 2024
1 parent dfd9e9f commit 0dfed16
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
18 changes: 8 additions & 10 deletions notebooks/minifloat_mx_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
Expand Down Expand Up @@ -221,9 +221,8 @@
" group_size = 8\n",
"\n",
"class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" group_size = 8\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModelNoPadding(nn.Module):\n",
Expand Down Expand Up @@ -277,8 +276,8 @@
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
Expand Down Expand Up @@ -314,12 +313,11 @@
"class MXInt8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
" pass\n",
"\n",
"class MXInt8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
Expand Down
29 changes: 29 additions & 0 deletions src/brevitas/quant/solver/act.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from warnings import warn

import torch
from torch import nn
from torch import Tensor
Expand Down Expand Up @@ -111,6 +113,33 @@ def scaling_shape(scaling_per_output):
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE

@value
def group_dim(module=None, group_size=None):
# Avoid circular import
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer

if group_size is not None and module is not None:
if isinstance(module, QuantWeightBiasInputOutputLayer):
if isinstance(module, nn.Linear):
return -1
elif isinstance(module,
(nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d)):
warn(
"Group dim is being selected assuming batched input. Using unbatched input will fail and requires manually specification of group_dim"
)
# We are assuming batched input
return 1
else:
raise RuntimeError("Cannot determine automatically group_dim. Please specify")
else:
raise RuntimeError(
f"Cannot determine automatically group_dim for {type(module)}. Please specify")


class SolveActScalingPerOutputChannelShape(ExtendedInjector):

Expand Down

0 comments on commit 0dfed16

Please sign in to comment.