Add early transformers generator
monoxgas committed May 20, 2024
1 parent 50314f6 commit f440c2d
Showing 4 changed files with 142 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -19,9 +19,11 @@ litellm = "1.35.21"
pandas = "^2.2.2"

vllm = { version = "0.4.2", optional = true }
transformers = { version = "^4.41.0", optional = true }
accelerate = { version = "^0.30.1", optional = true }

[tool.poetry.extras]
all = ["vllm"]
all = ["vllm", "transformers", "accelerate"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.27.1"
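These extras keep the Transformers backend opt-in: a plain `pip install rigging` stays lightweight, while `pip install "rigging[all]"` also pulls in `transformers` and `accelerate`. A small sketch of probing for the optional dependencies at runtime (the helper below is illustrative, not part of this commit):

import importlib.util

# Illustrative check: confirm the optional extras are importable before
# constructing a TransformersGenerator.
for package in ("transformers", "accelerate"):
    if importlib.util.find_spec(package) is None:
        raise RuntimeError(f"{package} is missing; install with: pip install 'rigging[all]'")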
6 changes: 6 additions & 0 deletions rigging/generator/__init__.py
@@ -18,6 +18,11 @@
except ImportError:
pass

try:
from rigging.generator.transformers_ import TransformersGenerator # noqa: F401
except ImportError:
pass

__all__ = [
"get_generator",
"Generator",
@@ -29,4 +34,5 @@
"get_identifier",
"LiteLLMGenerator",
"VLLMGenerator",
"TransformersGenerator",
]
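The guarded import mirrors the existing vLLM pattern: when the `transformers` extra is absent, the name is simply not exported, and importing it raises `ImportError`. A minimal sketch of detecting this from user code (illustrative only):

# Illustrative only: probe whether the optional backend is available.
try:
    from rigging.generator import TransformersGenerator
except ImportError:
    TransformersGenerator = None  # `transformers` extra not installed

if TransformersGenerator is None:
    print("Install rigging[all] to enable the local Transformers backend")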
3 changes: 3 additions & 0 deletions rigging/generator/base.py
@@ -41,6 +41,9 @@ class GenerateParams(BaseModel):
max_tokens: int | None = None
"""The maximum number of tokens to generate."""

top_k: int | None = None
"""The top-k sampling parameter."""

top_p: float | None = None
"""The nucleus sampling probability."""

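With `top_k` in place, `GenerateParams` covers the usual sampling knobs alongside `top_p` and `max_tokens`. A minimal sketch using only fields visible in this diff (the values are arbitrary):

from rigging.generator.base import GenerateParams

# Arbitrary example values; fields left as None are simply not forwarded.
params = GenerateParams(max_tokens=256, top_k=50, top_p=0.9)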
130 changes: 130 additions & 0 deletions rigging/generator/transformers_.py
@@ -0,0 +1,130 @@
import typing as t

from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline, pipeline

from rigging.generator.base import GenerateParams, Generator, register_generator, trace_messages, trace_str
from rigging.message import Message


class TransformersGenerator(Generator):
"""
Generator backed by the Transformers library for local model loading.
Warning:
The use of Transformers requires the `transformers` package to be installed directly or by
installing rigging as `rigging[all]`.
Note:
The async methods currently just call synchronous variants for compatibility.
Note:
The model load into memory will occur lazily when the first generation is requested.
If you'd want to force this to happen earlier, you can access the
[`.llm`][rigging.generator.vllm_.TransformersGenerator.llm] property on this class before first use.
"""

torch_dtype: str = "auto"
"""Torch dtype passed to [`AutoModelForCausalLM.from_pretrained`]()"""
device_map: str = "auto"
"""Device map passed to [`AutoModelForCausalLM.from_pretrained`]()"""
trust_remote_code: bool = False
"""Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`]()"""

_llm: AutoModelForCausalLM | None = None
_tokenizer: AutoTokenizer | None = None
_pipeline: TextGenerationPipeline | None = None

@property
def llm(self) -> AutoModelForCausalLM:
"""The underlying [`AutoModelForCausalLM`]() instance."""
# Lazy initialization
if self._llm is None:
self._llm = AutoModelForCausalLM.from_pretrained(
self.model,
device_map=self.device_map,
torch_dtype=self.torch_dtype,
trust_remote_code=self.trust_remote_code,
)
return self._llm

@property
def tokenizer(self) -> AutoTokenizer:
if self._tokenizer is None:
self._tokenizer = AutoTokenizer.from_pretrained(self.model)
return self._tokenizer

@property
def pipeline(self) -> TextGenerationPipeline:
if self._pipeline is None:
self._pipeline = pipeline(
"text-generation",
return_full_text=False,
model=self.llm,
tokenizer=self.tokenizer,
)
return self._pipeline

def _generate(
self,
inputs: t.Sequence[str] | t.Sequence[Message],
params: t.Sequence[GenerateParams],
) -> list[str]:
param_set = {p.model_dump_json() for p in params}
if len(param_set) != 1:
raise ValueError("All GenerateParams must be identical for this generator")

# Generation Args + Fixups

kwargs = self.params.merge_with(params[0]).to_dict()
if "max_tokens" in kwargs:
kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
if any(k in kwargs for k in ["temperature", "top_k", "top_p"]):
kwargs["do_sample"] = True

outputs = self.pipeline(inputs, **kwargs)
return [output[-1]["generated_text"].strip() for output in outputs]

def generate_messages(
self,
messages: t.Sequence[t.Sequence[Message]],
params: t.Sequence[GenerateParams],
) -> t.Sequence[Message]:
message_dicts = [[m.model_dump(include={"role", "content"}) for m in _messages] for _messages in messages]
outputs = self._generate(message_dicts, params)
generated = [Message(role="assistant", content=output) for output in outputs]

for i, (in_messages, out_message) in enumerate(zip(messages, generated, strict=True)):
trace_messages(in_messages, f"Messages {i+1}/{len(in_messages)}")
trace_messages([out_message], f"Response {i+1}/{len(in_messages)}")

return generated

async def agenerate_messages(
self,
messages: t.Sequence[t.Sequence[Message]],
params: t.Sequence[GenerateParams],
) -> t.Sequence[Message]:
return self.generate_messages(messages, params)

def generate_texts(
self,
texts: t.Sequence[str],
params: t.Sequence[GenerateParams],
) -> t.Sequence[str]:
generated = self._generate(texts=list(texts), params=params)

for i, (text, response) in enumerate(zip(texts, generated, strict=True)):
trace_str(text, f"Text {i+1}/{len(texts)}")
trace_str(response, f"Generated {i+1}/{len(texts)}")

return generated

async def agenerate_texts(
self,
texts: t.Sequence[str],
params: t.Sequence[GenerateParams],
) -> t.Sequence[str]:
return self.generate_texts(texts, params)


register_generator("transformers", TransformersGenerator)
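With the registration in place, the backend can be constructed through `get_generator`. A rough usage sketch, assuming rigging's `provider!model` identifier convention; the model id is illustrative, and any causal LM on the Hugging Face Hub should work:

from rigging.generator import get_generator
from rigging.generator.base import GenerateParams

# "gpt2" is just a small illustrative model; substitute any local or Hub model id.
generator = get_generator("transformers!gpt2")
outputs = generator.generate_texts(
    ["Write a haiku about generators."],
    [GenerateParams(max_tokens=64, top_k=40)],
)
print(outputs[0])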
