Different handling of state dict
Giuseppe5 committed Jul 6, 2023
1 parent 2855743 commit ada8404
Showing 3 changed files with 15 additions and 5 deletions.
14 changes: 13 additions & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -10,6 +10,7 @@
 from typing_extensions import Protocol
 from typing_extensions import runtime_checkable
 
+from brevitas import config
 from brevitas.function import max_int
 from brevitas.quant_tensor import QuantTensor
 
@@ -49,12 +50,23 @@ class ParameterQuantProxyFromInjector(QuantProxyFromInjector):
     def tracked_parameter_list(self):
         pass
 
-    def init_tensor_quant(self):
+    def init_tensor_quant(self, preserve_state_dict=False):
         param_list = self.tracked_parameter_list
         # params might not be there yet, e.g. bias before merging
         if param_list:
+            if preserve_state_dict:
+                reinit_on_state_dict = config.REINIT_ON_STATE_DICT_LOAD
+                ignore_missing_key = config.IGNORE_MISSING_KEYS
+                config.REINIT_ON_STATE_DICT_LOAD = False
+                config.IGNORE_MISSING_KEYS = True
+                state_dict = self.state_dict()
+
             self.quant_injector = self.quant_injector.let(tracked_parameter_list=param_list)
             super(ParameterQuantProxyFromInjector, self).init_tensor_quant()
+
+            if preserve_state_dict:
+                self.load_state_dict(state_dict)
+                config.IGNORE_MISSING_KEYS = ignore_missing_key
+                config.REINIT_ON_STATE_DICT_LOAD = reinit_on_state_dict
 
     def max_uint_value(self, bit_width):
         return max_int(False, self.is_narrow_range, bit_width)
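With `preserve_state_dict=True`, the proxy snapshots its own state dict before the tensor quantizer is rebuilt and reloads it afterwards. The two config flags are temporarily flipped so that reloading neither retriggers re-initialization (`REINIT_ON_STATE_DICT_LOAD`) nor errors out on keys that only exist in the rebuilt quantizer (`IGNORE_MISSING_KEYS`); their previous values are restored at the end. A minimal sketch of the intended call pattern (the layer shape and the `scaling_min_val` override are illustrative assumptions, not part of this commit):

import brevitas.nn as qnn

layer = qnn.QuantLinear(16, 8, bias=False)
proxy = layer.weight_quant  # a ParameterQuantProxyFromInjector

# Rebind the injector with an extra attribute (illustrative), then rebuild
# the tensor quantizer while carrying over any already-trained state.
proxy.quant_injector = proxy.quant_injector.let(scaling_min_val=1e-8)
proxy.init_tensor_quant(preserve_state_dict=True)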
@@ -197,15 +197,13 @@ def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_
     delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
     value = -torch.log((learned_round_zeta - learned_round_gamma) /
                        (delta - learned_round_gamma) - 1)
-    state_dict = layer.weight_quant.state_dict()
     layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
         float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
         learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID,
         learned_round_gamma=learned_round_gamma,
         learned_round_zeta=learned_round_zeta,
         learned_round_init=value)
-    layer.weight_quant.init_tensor_quant()
-    layer.weight_quant.load_state_dict(state_dict)
+    layer.weight_quant.init_tensor_quant(preserve_state_dict=True)
 
 
 def split_layers(model, blocks):
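The `value` computed above inverts the hard-sigmoid rectification used by the learned-round implementation, so that at initialization the rectified offset reproduces `delta` and quantization starts from the weight's actual fractional part. A quick sanity check of that identity (gamma = -0.1 is an assumed value matching the usual AdaRound defaults; the real default is cut off in the hunk header above):

import torch

zeta, gamma = 1.1, -0.1  # assumed defaults
delta = torch.tensor([0.25, 0.5, 0.75])  # fractional parts of some weights
value = -torch.log((zeta - gamma) / (delta - gamma) - 1)
# Hard sigmoid h(v) = clip(sigmoid(v) * (zeta - gamma) + gamma, 0, 1)
recovered = torch.clamp(torch.sigmoid(value) * (zeta - gamma) + gamma, 0., 1.)
assert torch.allclose(recovered, delta)  # h(value) == delta at init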
@@ -282,7 +282,7 @@ def apply_learned_round_learning(
     print(len(blocks))
 
     for block, get_inp_out, layer_loss, learned_round_module in block_wise_learned_round_iterator(model, blocks, iters=iters):
-        optimizer = optimizer_class(list(learned_round_module.parameters()), **optimizer_kwargs)
+        optimizer = optimizer_class(learned_round_module.parameters(), **optimizer_kwargs)
         pbar = tqdm(range(epochs), desc='')
         for e in pbar:
             for i, (img, t) in enumerate(dataloader):
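For reference, the dropped `list(...)` wrapper is redundant: torch optimizers materialize the parameter iterable themselves, so passing the generator from `parameters()` directly is equivalent. A small illustration (the `Linear` module is just a stand-in):

import torch

module = torch.nn.Linear(4, 4)
opt_a = torch.optim.SGD(list(module.parameters()), lr=1e-3)
opt_b = torch.optim.SGD(module.parameters(), lr=1e-3)
# Both optimizers end up with the same materialized parameter list.
assert len(opt_a.param_groups[0]['params']) == len(opt_b.param_groups[0]['params'])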
