diff --git a/basaran/model.py b/basaran/model.py index bfc7be0a..b681cf30 100644 --- a/basaran/model.py +++ b/basaran/model.py @@ -323,8 +323,8 @@ def load_model( kwargs["device_map"] = "auto" kwargs["load_in_8bit"] = load_in_8bit - # Override the dtype to float16 as required by bitsandbytes. - if load_in_8bit: + # Cast all parameters to float16 if quantization is enabled. + if half_precision or load_in_8bit: kwargs["torch_dtype"] = torch.float16 # Support both decoder-only and encoder-decoder models. @@ -333,10 +333,6 @@ def load_model( except ValueError: model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path, **kwargs) - # Cast all parameters to half-precision if required. - if half_precision: - model = model.half() - # Check if the model has text generation capabilities. if not model.can_generate(): raise TypeError(f"{name_or_path} is not a text generation model")