Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical scales #1038

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/base.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/base_reduced.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/develop_install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/end_to_end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples_llm_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/finn_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ort_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_develop_install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_end_to_end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_examples_llm_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_examples_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_finn_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_ort_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
1 change: 1 addition & 0 deletions src/brevitas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def env_to_bool(name, default):
_FULL_STATE_DICT = False
_IS_INSIDE_QUANT_LAYER = None
_ONGOING_EXPORT = None
_RETROCOMPATIBLE_SCALING = False
5 changes: 3 additions & 2 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ def __init__(

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor):
scale = self.scaling_impl(x)

if self.float_scaling_impl is not None:
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale / float_scaling_impl_value
else:
float_scaling_impl_value = None
scale = self.scaling_impl(x, float_scaling_impl_value)
x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
Expand Down
17 changes: 10 additions & 7 deletions src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn import Module

import brevitas
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import round_ste
Expand Down Expand Up @@ -138,20 +139,24 @@ def __init__(
scaling_impl: Module,
int_scaling_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module):
bit_width_impl: Module,
scaling_int_quant: Optional[Module] = None):
super(RescalingIntQuant, self).__init__()
self.int_quant = int_quant
self.scaling_impl = scaling_impl
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
if scaling_int_quant is None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused

self.scaling_int_quant = Identity()
else:
self.scaling_int_quant = scaling_int_quant

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
bit_width = self.msb_clamp_bit_width_impl()
threshold = self.scaling_impl(x)
int_threshold = self.int_scaling_impl(bit_width)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width
Expand Down Expand Up @@ -184,8 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
pre_threshold = self.pre_scaling_impl(x)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
Expand Down Expand Up @@ -250,8 +254,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
39 changes: 39 additions & 0 deletions src/brevitas/core/restrict_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def retrocompatibility_op(self, x):
return x

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> Tensor:
return x
Expand All @@ -113,6 +116,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def retrocompatibility_op(self, x):
return self.power_of_two(x)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.power_of_two(x)
Expand All @@ -137,6 +143,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def retrocompatibility_op(self, x):
return x

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.float_to_int_impl(x)
Expand All @@ -162,8 +171,38 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def retrocompatibility_op(self, x):
return self.power_of_two(x)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.float_to_int_impl(x)
x = self.power_of_two(x)
return x


class QuantRestrictValue(brevitas.jit.ScriptModule):

def __init__(self, restrict_value_float_to_int_impl: Module):
super(QuantRestrictValue, self).__init__()
self.float_to_int_impl = restrict_value_float_to_int_impl

def restrict_init_float(self, x: float):
return Identity()

def restrict_init_tensor(self, x: torch.Tensor):
return Identity()

def restrict_init_module(self):
return Identity()

def restrict_init_inplace_module(self):
return Identity()

def retrocompatibility_op(self, x):
return Identity()

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
o, *_ = self.float_to_int_impl(x)
return o
29 changes: 19 additions & 10 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ def __init__(
dtype,
device)

@brevitas.jit.script_method
def forward(self, ignored: torch.Tensor) -> torch.Tensor:
stats = self.parameter_list_stats()
return self.stats_scaling_impl(stats)
def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.parameter_list_stats(x)
if threshold is None:
threshold = torch.ones(1).type_as(stats)
return self.stats_scaling_impl(stats, threshold)


class _StatsScaling(brevitas.jit.ScriptModule):
Expand All @@ -80,8 +81,11 @@ def __init__(
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()

@brevitas.jit.script_method
def forward(self, stats: torch.Tensor) -> torch.Tensor:
stats = self.restrict_scaling_pre(stats)
def forward(
self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats)
stats = self.restrict_scaling_pre(stats / threshold)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
return stats
Expand Down Expand Up @@ -120,9 +124,9 @@ def __init__(
device)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.runtime_stats(x)
return self.stats_scaling_impl(stats)
return self.stats_scaling_impl(stats, threshold)


class _AffineRescaling(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -179,9 +183,14 @@ def __init__(
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
def forward(
self,
stats_input: torch.Tensor,
threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
out = self.scaling_stats_impl(stats_input_reshaped) / threshold
# Scaling min val
out = self.restrict_clamp_scaling(out)
return out
Loading
Loading