16 changes: 13 additions & 3 deletions neural_compressor/common/base_config.py
@@ -198,7 +198,7 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config

-    def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
+    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
         if operator_name in self.local_config:
             logger.warning("The configuration for %s has already been set, update it.", operator_name)
         self.local_config[operator_name] = config
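With this change, `set_local` accepts either an operator-name string (regex-matched against op names) or an operator type such as a module class. A minimal usage sketch, assuming the torch 3x import path used in the tests below:

import torch
from neural_compressor.torch.quantization import RTNConfig  # import path assumed

quant_config = RTNConfig(bits=4, dtype="nf4")
quant_config.set_local("fc1", RTNConfig(bits=6))            # by op name (str)
quant_config.set_local(torch.nn.Linear, RTNConfig(bits=8))  # by op type (Callable)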
@@ -392,14 +392,16 @@ def _get_op_name_op_type_config(self):
         op_name_config_dict = dict()
         for name, config in self.local_config.items():
             if self._is_op_type(name):
-                op_type_config_dict[name] = config
+                # Convert the Callable to String.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
             else:
                 op_name_config_dict[name] = config
         return op_type_config_dict, op_name_config_dict

     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-    ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
+    ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
         config_mapping = OrderedDict()
         if config_list is None:
             config_list = [self]
@@ -416,6 +418,14 @@
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping

+    @staticmethod
+    def _op_type_to_str(op_type: Callable) -> str:
+        # * Ort and TF may override this method.
+        op_type_name = getattr(op_type, "__name__", "")
+        if op_type_name == "":
+            logger.warning("The op_type %s has no attribute __name__.", op_type)
+        return op_type_name
+
     @staticmethod
     def _is_op_type(name: str) -> bool:
         # * Ort and TF may override this method.
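For context, the new `_op_type_to_str` helper relies on `__name__`, which classes and functions carry but instances generally do not; the empty-string fallback triggers the warning. A quick illustration of the underlying Python behavior (not part of the diff):

import torch

getattr(torch.nn.Linear, "__name__", "")        # "Linear": classes have __name__
getattr(torch.nn.Linear(2, 2), "__name__", "")  # "": instances do not, warning path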
4 changes: 2 additions & 2 deletions neural_compressor/torch/utils/utility.py
@@ -101,13 +101,13 @@ def set_module(model, op_name, new_module):
     setattr(second_last_module, name_list[-1], new_module)


-def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, Callable]]:
+def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
     module_dict = dict(model.named_modules())
     filter_result = []
     filter_result_set = set()
     for op_name, module in module_dict.items():
         if isinstance(module, tuple(white_module_list)):
-            pair = (op_name, type(module))
+            pair = (op_name, type(module).__name__)
             if pair not in filter_result_set:
                 filter_result_set.add(pair)
                 filter_result.append(pair)
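`get_model_info` now records the type name rather than the class object, so each entry becomes a `(op_name, type_name)` pair of strings. A sketch of the effect on a toy model (model definition is illustrative; import path taken from the file above):

import torch
from neural_compressor.torch.utils.utility import get_model_info

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

model_info = get_model_info(Net(), white_module_list=[torch.nn.Linear])
# before this PR: [("fc1", <class 'torch.nn.modules.linear.Linear'>), ("fc2", ...)]
# after this PR:  [("fc1", "Linear"), ("fc2", "Linear")]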
49 changes: 49 additions & 0 deletions test/3x/common/test_common.py
@@ -75,6 +75,29 @@ def __repr__(self) -> str:
         return "FakeModel"


+class FakeOpType:
+    def __init__(self) -> None:
+        self.name = "fake_module"
+
+    def __call__(self, x) -> Any:
+        return x
+
+    def __repr__(self) -> str:
+        return "FakeModule"
+
+
+class OP_TYPE1(FakeOpType):
+    pass
+
+
+class OP_TYPE2(FakeOpType):
+    pass
+
+
 def build_simple_fake_model():
     return FakeModel()


 @register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
 class FakeAlgoConfig(BaseConfig):
     """Config class for fake algo."""
@@ -257,6 +280,32 @@ def test_mixed_two_algos(self):
         self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
         self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])

+    def test_set_local_op_name(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set `OP1_NAME`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local("OP1_NAME", fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 4)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+
+    def test_set_local_op_type(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set all `OP_TYPE1`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local(OP_TYPE1, fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+

 class TestConfigSet(unittest.TestCase):
     def setUp(self):
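In short, these tests exercise the normalization step: a Callable passed to `set_local` is stored under the class object, then converted to its `__name__` when the per-type config dict is built. A minimal sketch, assuming the fake framework's `_is_op_type` recognizes the class (`_get_op_name_op_type_config` is internal and used here only for illustration):

quant_config = FakeAlgoConfig(weight_bits=4)
quant_config.set_local(OP_TYPE1, FakeAlgoConfig(weight_bits=6))  # key is the class object
op_type_cfgs, op_name_cfgs = quant_config._get_op_name_op_type_config()
assert "OP_TYPE1" in op_type_cfgs  # key normalized via _op_type_to_str / __name__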
29 changes: 22 additions & 7 deletions test/3x/torch/test_config.py
@@ -147,8 +147,8 @@ def test_config_white_lst2(self):
         logger.info(quant_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)

     def test_config_from_dict(self):
         quant_config = {
@@ -253,16 +253,31 @@ def test_config_mapping(self):
         logger.info(quant_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)
         # test regular matching
         fc_config = RTNConfig(bits=5, dtype="int8")
         quant_config.set_local("fc", fc_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 5)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 5)
-        self.assertTrue(configs_mapping[("fc3", torch.nn.Linear)].bits == 5)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 5)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 5)
+        self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 5)
+
+    def test_set_local_op_type(self):
+        quant_config = RTNConfig(bits=4, dtype="nf4")
+        # set all `Linear`
+        fc1_config = RTNConfig(bits=6, dtype="int8")
+        quant_config.set_local(torch.nn.Linear, fc1_config)
+        # get model and quantize
+        fp32_model = build_simple_torch_model()
+        model_info = get_model_info(fp32_model, white_module_list=[torch.nn.Linear])
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 6)

     def test_gptq_config(self):
         gptq_config1 = GPTQConfig(bits=8, act_order=True)