diff --git a/rigging/generator/transformers_.py b/rigging/generator/transformers_.py
index af3ef44..480a82f 100644
--- a/rigging/generator/transformers_.py
+++ b/rigging/generator/transformers_.py
@@ -44,6 +44,10 @@ class TransformersGenerator(Generator):
     """Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
     trust_remote_code: bool = False
     """Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
+    load_in_8bit: bool = False
+    """Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
+    load_in_4bit: bool = False
+    """Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
 
     _llm: AutoModelForCausalLM | None = None
     _tokenizer: AutoTokenizer | None = None
@@ -54,12 +58,11 @@ def llm(self) -> AutoModelForCausalLM:
         """The underlying `AutoModelForCausalLM` instance."""
         # Lazy initialization
         if self._llm is None:
-            self._llm = AutoModelForCausalLM.from_pretrained(
-                self.model,
-                device_map=self.device_map,
-                torch_dtype=self.torch_dtype,
-                trust_remote_code=self.trust_remote_code,
+            llm_kwargs = self.model_dump(
+                exclude_unset=True,
+                include={"torch_dtype", "device_map", "trust_remote_code", "load_in_8bit", "load_in_4bit"},
             )
+            self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs)
         return self._llm
 
     @property
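
The refactored property leans on how Pydantic v2's model_dump() combines include with exclude_unset=True: only fields the caller explicitly set are forwarded to from_pretrained, so callers who never touch the new flags see no behavior change. A minimal standalone sketch of that pattern (the LoaderOptions model below is hypothetical, for illustration only, and assumes Pydantic v2):

from pydantic import BaseModel

class LoaderOptions(BaseModel):
    device_map: str | None = None
    load_in_8bit: bool = False
    load_in_4bit: bool = False

opts = LoaderOptions(load_in_4bit=True)

# exclude_unset=True drops fields left at their defaults, while include
# restricts the dump to the named fields; the result holds only the
# options that were explicitly set by the caller.
kwargs = opts.model_dump(
    exclude_unset=True,
    include={"device_map", "load_in_8bit", "load_in_4bit"},
)
print(kwargs)  # {'load_in_4bit': True}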