Add local agent #23438
@@ -24,6 +24,8 @@
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..generation import StoppingCriteria, StoppingCriteriaList
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
@@ -492,3 +494,114 @@ def generate_one(self, prompt, stop):
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result


class LocalAgent(Agent):
    """
    Agent that uses a local model and tokenizer to generate code.

    Args:
        model ([`PreTrainedModel`]):
            The model to use for the agent.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer to use for the agent.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent

    checkpoint = "bigcode/starcoder"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    agent = LocalAgent(model, tokenizer)
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    """
    def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        self.model = model
        self.tokenizer = tokenizer
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs:
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(model, tokenizer)
    @property
    def _model_device(self):
        if hasattr(self.model, "hf_device_map"):
            return list(self.model.hf_device_map.values())[0]
        for param in self.model.parameters():
            return param.device
    def generate_one(self, prompt, stop):
        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
        )

        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result


class StopSequenceCriteria(StoppingCriteria):

Review thread on this class:

- Can we:
- Maybe you can move it as you add support for the batched case? I don't need it for the agents and it's not an obvious thing to do (that's the reason I didn't put this in the stopping criteria file, by the way).
- Sure, I can take care of it afterwards 👍
- Thanks!
""" | ||
This class can be used to stop generation whenever a sequence of tokens is encountered. | ||
|
||
Args: | ||
stop_sequences (`str` or `List[str]`): | ||
The sequence (or list of sequences) on which to stop execution. | ||
tokenizer: | ||
The tokenizer used to decode the model outputs. | ||
""" | ||
|
||
def __init__(self, stop_sequences, tokenizer): | ||
if isinstance(stop_sequences, str): | ||
stop_sequences = [stop_sequences] | ||
self.stop_sequences = stop_sequences | ||
self.tokenizer = tokenizer | ||
|
||
def __call__(self, input_ids, scores, **kwargs) -> bool: | ||
decoded_output = self.tokenizer.decode(input_ids.tolist()[0]) | ||
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences) |
Review comment: Nice shortcut!
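Following up on the review thread about the batched case: as written, `StopSequenceCriteria.__call__` decodes only the first sequence of the batch (`input_ids.tolist()[0]`), so it assumes generation with batch size 1. Below is a minimal sketch of how a stop-sequence criterion can be plugged into `generate` directly, together with a hypothetical batched variant. The `BatchedStopSequenceCriteria` class, the `gpt2` checkpoint, and the prompt are illustrative assumptions only and are not part of this PR.

```py
# Sketch only, not part of this PR. Uses the public transformers API
# (StoppingCriteria, StoppingCriteriaList, generate) and a small checkpoint
# ("gpt2") chosen purely for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


class BatchedStopSequenceCriteria(StoppingCriteria):
    """Hypothetical batched counterpart of StopSequenceCriteria (not in this PR)."""

    def __init__(self, stop_sequences, tokenizer):
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Decode every sequence in the batch and only stop once *all* of them
        # end with one of the stop strings.
        decoded = self.tokenizer.batch_decode(input_ids)
        return all(
            any(text.endswith(stop) for stop in self.stop_sequences) for text in decoded
        )


prompt = "Task: draw me a picture of rivers and lakes.\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")
criteria = StoppingCriteriaList([BatchedStopSequenceCriteria("Task:", tokenizer)])
outputs = model.generate(inputs["input_ids"], max_new_tokens=50, stopping_criteria=criteria)
print(tokenizer.decode(outputs[0]))
```

The `all(...)` aggregation reflects how a criterion returning a single boolean is interpreted by `generate` at this point: the whole batch stops together, so a batched criterion has to wait for the slowest sequence before signalling completion.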