diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 6c8ff875ba4..d3c38eb682a 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -171,7 +171,7 @@ def __init__(self, model, config_list, optimizer=None): self.bound_model.to(self.device) def validate_config(self, model, config_list): - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]), Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({ Optional('weight'): And(int, lambda n: n == 8),