Commit 847b7e4 (1 parent: a833eef)

Refactor Sequences and Generator pipeline into single LLM pipeline, closes #494

Showing 7 changed files with 179 additions and 95 deletions.
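For orientation, a hedged migration sketch follows, showing how calls to the previous Sequences and Generator pipelines map onto the new LLM pipeline. It assumes Sequences and Generator remain importable from txtai.pipeline after this refactor; treat it as an illustration of the consolidation, not as code from this commit.

from txtai.pipeline import Generator, LLM, Sequences

# Before this commit: the pipeline choice depended on the model architecture
sequences = Sequences("google/flan-t5-base")  # sequence to sequence models
generator = Generator("gpt2")                 # text generation models

# After this commit: a single LLM pipeline autodetects the model type from the path
llm = LLM("google/flan-t5-base")
print(llm("Translate English to German: Hello"))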
@@ -0,0 +1,93 @@
"""
LLM Module
"""

from ...models import Models

from ..hfpipeline import HFPipeline


class LLM(HFPipeline):
    """
    Runs prompt text through a large language model (LLM). This pipeline autodetects if the input path is a text generation or
    sequence to sequence model.
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None, task=None):
        super().__init__(self.task(path, task), path if path else "google/flan-t5-base", quantize, gpu, model)

    def __call__(self, text, prefix=None, maxlength=512, workers=0, **kwargs):
        """
        Generates text using input text.

        Args:
            text: text|list
            prefix: optional prefix to prepend to text elements
            maxlength: maximum sequence length
            workers: number of concurrent workers to use for processing data, defaults to 0
            kwargs: additional generation keyword arguments

        Returns:
            generated text
        """

        # List of texts
        texts = text if isinstance(text, list) else [text]

        # Add prefix, if necessary
        if prefix:
            texts = [f"{prefix}{x}" for x in texts]

        # Run pipeline
        results = self.pipeline(texts, max_length=maxlength, num_workers=workers, **kwargs)

        # Get generated text
        results = [self.clean(texts[x], result) for x, result in enumerate(results)]

        return results[0] if isinstance(text, str) else results

    def clean(self, prompt, result):
        """
        Applies a series of rules to clean generated text.

        Args:
            prompt: original input prompt
            result: input result

        Returns:
            clean text
        """

        # Extract output from list, if necessary
        result = result[0] if isinstance(result, list) else result

        # Get generated text field
        text = result["generated_text"]

        # Replace input prompt
        text = text.replace(prompt, "")

        # Apply text cleaning rules
        return text.replace("$=", "<=").strip()

    def task(self, path, task):
        """
        Get the pipeline task name.

        Args:
            path: model path input
            task: task name

        Returns:
            pipeline task name
        """

        # Mapping from txtai to Hugging Face pipeline tasks
        mapping = {"language-generation": "text-generation", "sequence-sequence": "text2text-generation"}

        # Attempt to resolve task
        if path and not task:
            task = Models.task(path)

        # Map to Hugging Face task. Default to text2text-generation pipeline when task not resolved.
        return mapping.get(task, "text2text-generation")
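A minimal usage sketch of the new pipeline follows. The model paths below are illustrative assumptions; any Hugging Face text generation or sequence to sequence model path should work, with the task resolved through Models.task as shown above.

from txtai.pipeline import LLM

# Sequence to sequence model, mapped to the text2text-generation task
llm = LLM("google/flan-t5-base")
print(llm("Translate English to German: Hello"))

# Text generation (causal) model, mapped to the text-generation task
llm = LLM("gpt2")
print(llm("Hello, how are", maxlength=64))

# prefix is prepended to each input before generation
print(llm(["Hello", "Goodbye"], prefix="Translate English to German: "))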
@@ -0,0 +1,29 @@
"""
LLM module tests
"""

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer

from txtai.pipeline import LLM


class TestLLM(unittest.TestCase):
    """
    LLM tests.
    """

    def testExternal(self):
        """
        Test externally loaded model
        """

        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

        model = LLM((model, tokenizer))
        start = "Hello, how are"

        # Test that text is generated
        self.assertGreater(len(model(start)), len(start))
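The test passes an externally loaded (model, tokenizer) tuple. For comparison, a short sketch loading the same tiny test model by path instead, letting the pipeline autodetect the text-generation task:

from txtai.pipeline import LLM

# Load by model path; task autodetected as text-generation
llm = LLM("hf-internal-testing/tiny-random-gpt2")
print(llm("Hello, how are"))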