Skip to content

Commit

Permalink
feat: add brevitas channel-wise support
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Jul 29, 2024
1 parent 123d619 commit a9ac260
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 44 deletions.
5 changes: 4 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def pytest_addoption(parser):
)

parser.addoption(
"--no-flaky", action="store_true", default=False, help="Don't run known flaky tests."
"--no-flaky",
action="store_true",
default=False,
help="Don't run known flaky tests.",
)


Expand Down
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[
matmul_node.input[0],
matmul_node.input[1],
bias_other_input_node_name,
], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
Expand Down Expand Up @@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to("cpu")
else:
dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
torch_module,
torch_module.to("cpu"),
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
Expand Down
212 changes: 201 additions & 11 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
import brevitas.nn as qnn
import numpy
import torch
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias
from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType
from brevitas.quant import (
Int8AccumulatorAwareWeightQuant,
Int8AccumulatorAwareZeroCenterWeightQuant,
Int8ActPerTensorFloat,
Int8WeightPerTensorFloat,
IntBias,
Uint8ActPerTensorFloat,
)
from torch import nn
from torch.nn.utils import prune

Expand Down Expand Up @@ -38,7 +46,7 @@ def forward(self, x, y):
return x + y + self.value, (x - y) ** 2


class SimpleNet(torch.nn.Module):
class SimpleNet(nn.Module):
"""Fake torch model used to generate some onnx."""

def __init__(self) -> None:
Expand Down Expand Up @@ -292,7 +300,7 @@ def forward(self, x):
return x


class NetWithLoops(torch.nn.Module):
class NetWithLoops(nn.Module):
"""Torch model, where we reuse some elements in a loop.
Torch model, where we reuse some elements in a loop in the forward and don't expect the
Expand Down Expand Up @@ -538,7 +546,7 @@ def step(x, bias):
return x


class NetWithConcatUnsqueeze(torch.nn.Module):
class NetWithConcatUnsqueeze(nn.Module):
"""Torch model to test the concat and unsqueeze operators."""

def __init__(self, activation_function, input_output, n_fc_layers):
Expand Down Expand Up @@ -1004,6 +1012,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits):
layer_obj = self.mixing_layer

layer_obj.weight.data = torch.from_numpy(np_weights).float()
assert layer_obj.bias is not None
layer_obj.bias.data = torch.rand(size=(1,))

def forward(self, x):
Expand Down Expand Up @@ -1216,12 +1225,12 @@ def forward(self, x):
# for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded
# along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W)
# and 2 cells at the top and 3 at the bottom (dimension 3: H)
x = torch.nn.functional.pad(x, (3, 2))
x = torch.nn.functional.pad(x, (1, 2, 3, 4))
x = nn.functional.pad(x, (3, 2))
x = nn.functional.pad(x, (1, 2, 3, 4))

# Concrete ML only supports padding on the last two dimensions as this is the
# most common setting
x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
return x


Expand Down Expand Up @@ -1340,7 +1349,12 @@ class ConcatFancyIndexing(nn.Module):
"""Concat with fancy indexing."""

def __init__(
self, input_shape, hidden_shape, output_shape, n_bits: int = 4, n_blocks: int = 3
self,
input_shape,
hidden_shape,
output_shape,
n_bits: int = 4,
n_blocks: int = 3,
) -> None:
"""Torch Model.
Expand All @@ -1361,7 +1375,10 @@ def __init__(

self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
self.fc2 = qnn.QuantLinear(
hidden_shape * self.n_blocks, hidden_shape, bias=True, weight_bit_width=n_bits
hidden_shape * self.n_blocks,
hidden_shape,
bias=True,
weight_bit_width=n_bits,
)

self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
Expand Down Expand Up @@ -1393,7 +1410,7 @@ def forward(self, x):
return x


class PartialQATModel(torch.nn.Module):
class PartialQATModel(nn.Module):
"""A model with a QAT Module."""

def __init__(self, input_shape: int, output_shape: int, n_bits: int):
Expand Down Expand Up @@ -1442,7 +1459,7 @@ def forward(self, input1):
return output


class ManualLogisticRegressionTraining(torch.nn.Module):
class ManualLogisticRegressionTraining(nn.Module):
"""PyTorch module for performing SGD training."""

def __init__(self, learning_rate=0.1):
Expand Down Expand Up @@ -1665,3 +1682,176 @@ def forward(self, x):
x = self.relu(x)
x = self.linear(x)
return x


# pylint: disable-next=too-many-ancestors
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
"""CommonIntWeightPerChannelQuant."""

scaling_per_output_channel = True


# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
"""CommonIntAccumulatorAwareWeightQuant."""

restrict_scaling_impl = FloatRestrictValue # backwards compatibility
bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
"""CommonIntAccumulatorAwareZeroCenterWeightQuant."""

bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonUintActQuant(Uint8ActPerTensorFloat):
"""CommonUintActQuant."""

bit_width = None
restrict_scaling_type = RestrictValueType.LOG_FP


def weight_init(layer: nn.Module):
"""Initialize layer weights.
Arguments:
layer (nn.Module): a conv2d layer
"""

if isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu"))
if layer.bias is not None:
layer.bias.data.zero_()


# pylint: disable-next=too-many-instance-attributes
class FloatLeNet(nn.Module):
"""Floating point LeNet."""

def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu2 = nn.ReLU(inplace=True)

self.fc1 = nn.Linear(400, 120, bias=True)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84, bias=True)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10, bias=True)

self.apply(weight_init)

def forward(self, x: torch.Tensor):
"""Forward function.
Arguments:
x (torch.Tensor): input image
Returns:
Neural network prediction
"""
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = torch.flatten(x, 1)
x = self.relu3(self.fc1(x))
x = self.relu4(self.fc2(x))
x = self.fc3(x)
return x


# pylint: disable-next=too-many-instance-attributes
class QuantLeNet(FloatLeNet):
"""Quantized LeNet with per-channel quantization."""

def __init__(
self,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=32,
weight_quant=CommonIntAccumulatorAwareWeightQuant,
):
super().__init__()

self.conv1 = qnn.QuantConv2d(
bias=False,
in_channels=1,
out_channels=6,
kernel_size=5,
stride=1,
padding=0,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu1 = qnn.QuantReLU(
inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
)

self.conv2 = qnn.QuantConv2d(
bias=False,
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu2 = qnn.QuantReLU(
inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
)

self.fc1 = qnn.QuantLinear(
400,
120,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
self.fc2 = qnn.QuantLinear(
120,
84,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
self.fc3 = qnn.QuantLinear(
84,
10,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)

self.apply(weight_init)
9 changes: 7 additions & 2 deletions src/concrete/ml/quantization/base_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast

import numpy
import numpy.typing as npt

from concrete import fhe

Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
input_quant_opts: Optional[QuantizationOptions] = None,
**attrs,
) -> None:

self.n_bits = n_bits_output

if input_quant_opts is not None:
Expand Down Expand Up @@ -916,7 +918,7 @@ def can_fuse(self) -> bool:
def make_output_quant_parameters(
self,
q_values: Union[numpy.ndarray, Any],
scale: numpy.float64,
scale: npt.NDArray[numpy.float64],
zero_point: Union[int, float, numpy.ndarray],
) -> QuantizedArray:
"""Build a quantized array from quantized integer results of the op and quantization params.
Expand Down Expand Up @@ -1016,6 +1018,9 @@ def cnp_round(
# Rounding to low bit-width with approximate can cause issues with overflow protection
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345
x = fhe.round_bit_pattern(
x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False
x,
lsbs_to_remove=lsbs_value,
exactness=exactness,
overflow_protection=False,
)
return x
Loading

0 comments on commit a9ac260

Please sign in to comment.