diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index acc6f41fa..fd474561f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,6 +28,27 @@ LOG = logging.getLogger("axolotl") +def check_model_config(cfg: DictDefault, model_config: AutoConfig): + quant_config_exists = hasattr(model_config, "quantization_config") + quant_config_method_is_gptq = ( + quant_config_exists + and "quant_method" in model_config.quantization_config + and model_config.quantization_config["quant_method"] == "gptq" + ) + + if cfg.gptq and not quant_config_method_is_gptq: + raise ValueError( + "model_config.quantization_config is not set or quant_method is not set to gptq. " + "Please make sure to point to a GPTQ model." + ) + + if not cfg.gptq and quant_config_exists: + raise ValueError( + "model_config.quantization_config is set but `gptq` flag is not. " + "Please use the `gptq` flag to train quantized model or point to a non-quantized model." + ) + + def load_model_config(cfg): model_config_name = cfg.base_model_config or cfg.base_model trust_remote_code = cfg.trust_remote_code is True @@ -38,6 +59,8 @@ def load_model_config(cfg): for key, val in cfg.model_config.items(): setattr(model_config, key, val) + check_model_config(cfg, model_config) + return model_config