Fix (core): po2 for float quantization (#1033)
Giuseppe5 authored Oct 8, 2024
1 parent 4d8b153 commit 33de963
Showing 8 changed files with 153 additions and 85 deletions.
11 changes: 6 additions & 5 deletions src/brevitas/core/quant/float.py
@@ -68,10 +68,6 @@ def __init__(

     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.float_scaling_impl is not None:
-            float_scaling_impl_value = self.float_scaling_impl(
-                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-            scale = scale / float_scaling_impl_value
         x = self.input_view_impl(x)
         scaled_x = x / scale
         internal_scale = float_internal_scale(
@@ -85,7 +81,12 @@ def dequantize(self, y, scale):

     @brevitas.jit.script_method
     def forward(self, x):
-        scale = self.scaling_impl(x)
+        if self.float_scaling_impl is not None:
+            float_scaling_impl_value = self.float_scaling_impl(
+                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+        else:
+            float_scaling_impl_value = None
+        scale = self.scaling_impl(x, float_scaling_impl_value)
         if self.observer_only:
             y = x
         saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
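Note on the fix: previously quantize() divided an already-restricted scale by the float-format correction factor, so under a power-of-two (po2) scale restriction the result was generally no longer a power of two. forward() now passes the correction factor into scaling_impl as a threshold, so the combination happens before the po2 rounding. A minimal standalone sketch of why the order matters (illustrative values, not Brevitas code):

import torch

threshold = torch.tensor(3.0)   # statistics-derived scale before restriction
correction = torch.tensor(1.5)  # stand-in for the float_scaling_impl value

# Old order: restrict to po2 first, divide afterwards -> 2.6667, not a power of two.
old_scale = 2 ** torch.round(torch.log2(threshold)) / correction

# New order: combine in the log2 domain, then round -> 2.0, po2 by construction.
new_scale = 2 ** torch.round(torch.log2(threshold) - torch.log2(correction))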
9 changes: 3 additions & 6 deletions src/brevitas/core/quant/int.py
@@ -150,9 +150,8 @@ def __init__(
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         bit_width = self.msb_clamp_bit_width_impl()
-        threshold = self.scaling_impl(x)
         int_threshold = self.int_scaling_impl(bit_width)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
@@ -189,8 +188,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
         pre_threshold = self.pre_scaling_impl(x)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
@@ -258,8 +256,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
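The integer path gets the same treatment: instead of computing threshold = self.scaling_impl(x) and dividing by int_threshold afterwards, the division is delegated to scaling_impl, which can apply it in the domain of its value restriction. A runnable mock (hypothetical class, not the Brevitas implementation) of what a po2-restricted scaling_impl does with the extra argument:

import torch
from typing import Optional

class MockPo2Scaling:
    """Stand-in for a po2-restricted scaling_impl after this change."""

    def __call__(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        if threshold is None:
            threshold = torch.ones(1)
        stats = x.abs().max()
        # Fold the threshold in before rounding, in the log2 domain,
        # so the returned scale stays a power of two.
        return 2 ** torch.round(torch.log2(stats) - torch.log2(threshold))

scale = MockPo2Scaling()(torch.randn(16), torch.tensor(127.))  # e.g. a signed int8 threshold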
32 changes: 22 additions & 10 deletions src/brevitas/core/restrict_val.py
@@ -36,7 +36,7 @@ def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         x = self.clamp_min_ste(x)
         return x
@@ -52,7 +52,7 @@ def __init__(self, restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         return x

@@ -68,7 +68,7 @@ def __init__(self, scaling_min_val: Optional[float]):
         self.min_val = scaling_min_val

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.clamp_min_ste(x)
         return x

@@ -90,8 +90,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x / threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         return x


@@ -104,7 +107,7 @@ def __init__(self):
     def restrict_init_float(self, x: float):
         return math.log2(x)

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)

     def restrict_init_module(self):
@@ -113,8 +116,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x - threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.power_of_two(x)
         return x

@@ -128,7 +134,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return x

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return x

     def restrict_init_module(self):
@@ -137,8 +143,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x / threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         return x

@@ -153,7 +162,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return math.log2(x)

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)

     def restrict_init_module(self):
@@ -162,8 +171,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x - threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
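The new combine_scale_threshold hook is what makes the delegation work: each restrictor folds the threshold into the pre-restriction value in its own domain, division for the linear-domain restrictors and subtraction for the log2-domain ones, since log2(a / b) = log2(a) - log2(b). A self-contained consistency check of the two domains (illustrative values, not Brevitas code):

import torch

a, b = torch.tensor(6.0), torch.tensor(1.5)

linear = a / b                              # FloatRestrictValue / IntRestrictValue
log_domain = torch.log2(a) - torch.log2(b)  # PowerOfTwo variants, after log2 pre-restriction

assert torch.allclose(torch.log2(linear), log_domain)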
47 changes: 31 additions & 16 deletions src/brevitas/core/scaling/runtime.py
@@ -11,6 +11,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.restrict_val import _RestrictClampValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _RuntimeStats
 from brevitas.core.stats import DEFAULT_MOMENTUM
@@ -27,8 +28,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
@@ ... @@
             device)

     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.parameter_list_stats()
-        return self.stats_scaling_impl(stats)
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        return self.stats_scaling_impl(stats, threshold)


class _StatsScaling(brevitas.jit.ScriptModule):
@@ -78,10 +82,16 @@ def __init__(
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+        self.restrict_scaling_impl = restrict_scaling_impl

     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        threshold = self.restrict_scaling_pre(threshold)
         stats = self.restrict_scaling_pre(stats)
+        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -93,10 +103,10 @@ def __init__(
             self,
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module,
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ ... @@
             device)

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)


class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -163,13 +173,13 @@ def _load_from_state_dict(

 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

     def __init__(
-            self,
-            group_size: int,
-            group_dim: int,
-            input_view_impl: torch.nn.Module,
-            scaling_stats_impl: torch.nn.Module,
-            scaling_min_val: Optional[float],
-            restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
         self.group_size = group_size
         self.group_dim = group_dim
@@ ... @@
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

     @brevitas.jit.script_method
-    def forward(self, stats_input) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped)
+        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
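All the new forward signatures make threshold optional and default it to ones, which keeps existing callers working unchanged: dividing by 1 is a no-op in the linear domain, and after restrict_scaling_pre maps it with log2, subtracting log2(1) = 0 is a no-op in the log domain. A minimal sketch of the pattern (not the full module):

import torch
from typing import Optional

def scale_from_stats(stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
    if threshold is None:
        threshold = torch.ones(1).type_as(stats)  # neutral: x / 1 == x
    return stats / threshold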