
Commit

Merge branch 'dev'
monoxgas committed May 21, 2024
2 parents eed942f + 6257b82 commit aae3055
Showing 1 changed file with 8 additions and 5 deletions.
rigging/generator/transformers_.py: 13 changes (8 additions, 5 deletions)
@@ -44,6 +44,10 @@ class TransformersGenerator(Generator):
     """Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
     trust_remote_code: bool = False
     """Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
+    load_in_8bit: bool = False
+    """Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
+    load_in_4bit: bool = False
+    """Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""

     _llm: AutoModelForCausalLM | None = None
     _tokenizer: AutoTokenizer | None = None
@@ -54,12 +58,11 @@ def llm(self) -> AutoModelForCausalLM:
         """The underlying `AutoModelForCausalLM` instance."""
         # Lazy initialization
         if self._llm is None:
-            self._llm = AutoModelForCausalLM.from_pretrained(
-                self.model,
-                device_map=self.device_map,
-                torch_dtype=self.torch_dtype,
-                trust_remote_code=self.trust_remote_code,
+            llm_kwargs = self.model_dump(
+                exclude_unset=True,
+                include={"torch_dtype", "device_map", "trust_remote_code", "load_in_8bit", "load_in_4bit"},
             )
+            self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs)
         return self._llm

     @property
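The refactored llm property replaces hard-coded keyword arguments with a model_dump(exclude_unset=True, include=...) call, so only options the caller actually set are forwarded to AutoModelForCausalLM.from_pretrained. The standalone sketch below illustrates that forwarding pattern with plain Pydantic v2; DemoSettings and its fields are hypothetical stand-ins mirroring the included field names, not part of rigging or transformers.

from pydantic import BaseModel


class DemoSettings(BaseModel):
    # Hypothetical stand-in mirroring the fields included in llm_kwargs above.
    torch_dtype: str | None = None
    device_map: str | None = None
    trust_remote_code: bool = False
    load_in_8bit: bool = False
    load_in_4bit: bool = False


settings = DemoSettings(device_map="auto", load_in_4bit=True)

# exclude_unset=True drops every field left at its default, and include= limits
# which fields are eligible to be forwarded at all, so only explicitly configured
# options reach the downstream call (from_pretrained in the diff above).
kwargs = settings.model_dump(
    exclude_unset=True,
    include={"torch_dtype", "device_map", "trust_remote_code", "load_in_8bit", "load_in_4bit"},
)
print(kwargs)  # {'device_map': 'auto', 'load_in_4bit': True}

One practical consequence of exclude_unset is that load_in_8bit and load_in_4bit left at their False defaults are not passed to from_pretrained at all, which presumably avoids touching quantization-related code paths for callers who never requested them.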
