From 49fd7f2572a0a00c8f5cfbdd2ce6184269843f01 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 4 Dec 2023 23:09:28 +0900 Subject: [PATCH 1/3] feat: add check for quantized model --- src/axolotl/utils/models.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index acc6f41fa..564c4c033 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -38,6 +38,25 @@ def load_model_config(cfg): for key, val in cfg.model_config.items(): setattr(model_config, key, val) + if ( + hasattr(model_config, "quantization_config") + and model_config.quantization_config + ): + if not cfg.gptq: + raise ValueError( + "model_config.quantization_config is set but gptq is not. " + "Please use the gptq flag to train quantized model or point to a non-quantized model." + ) + + if ( + hasattr(model_config.quantization_config, "quant_method") + and model_config.quantization_config.quant_method != "gptq" + ): + raise ValueError( + "model_config.quantization_config.quant_method is not set to gptq." + "Please make sure to point to a GPTQ model." + ) + return model_config From 0009f5099f6c47bf30386a54ee2cbda9714fedca Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 4 Dec 2023 23:20:07 +0900 Subject: [PATCH 2/3] chore: refactor and add another check --- src/axolotl/utils/models.py | 40 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 564c4c033..7bb0a079a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,6 +28,27 @@ LOG = logging.getLogger("axolotl") +def check_model_config(cfg: DictDefault, model_config: AutoConfig): + quant_config_exists = hasattr(model_config, "quantization_config") + quant_config_method_is_gptq = ( + quant_config_exists + and hasattr(model_config.quantization_config, "quant_method") + and model_config.quantization_config.quant_method == "gptq" + ) + + if cfg.gptq and not quant_config_method_is_gptq: + raise ValueError( + "model_config.quantization_config is not set or quant_method is not set to gptq. " + "Please make sure to point to a GPTQ model." + ) + + if not cfg.gptq and quant_config_exists: + raise ValueError( + "model_config.quantization_config is set but `gptq` flag is not. " + "Please use the `gptq` flag to train quantized model or point to a non-quantized model." + ) + + def load_model_config(cfg): model_config_name = cfg.base_model_config or cfg.base_model trust_remote_code = cfg.trust_remote_code is True @@ -38,24 +59,7 @@ def load_model_config(cfg): for key, val in cfg.model_config.items(): setattr(model_config, key, val) - if ( - hasattr(model_config, "quantization_config") - and model_config.quantization_config - ): - if not cfg.gptq: - raise ValueError( - "model_config.quantization_config is set but gptq is not. " - "Please use the gptq flag to train quantized model or point to a non-quantized model." - ) - - if ( - hasattr(model_config.quantization_config, "quant_method") - and model_config.quantization_config.quant_method != "gptq" - ): - raise ValueError( - "model_config.quantization_config.quant_method is not set to gptq." - "Please make sure to point to a GPTQ model." - ) + check_model_config(cfg, model_config) return model_config From 85a43d5ef1dd72d12db61d75f318d66aba387778 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Dec 2023 10:06:08 -0500 Subject: [PATCH 3/3] Update src/axolotl/utils/models.py --- src/axolotl/utils/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 7bb0a079a..fd474561f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -32,8 +32,8 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig): quant_config_exists = hasattr(model_config, "quantization_config") quant_config_method_is_gptq = ( quant_config_exists - and hasattr(model_config.quantization_config, "quant_method") - and model_config.quantization_config.quant_method == "gptq" + and "quant_method" in model_config.quantization_config + and model_config.quantization_config["quant_method"] == "gptq" ) if cfg.gptq and not quant_config_method_is_gptq: