forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
75 lines (56 loc) · 2.82 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import fire
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
from typing import List, Tuple
from enum import Enum
class AgentType(Enum):
    """Conversation roles used by this sample script.

    Each sample prompt in ``main`` is paired with one of these members, and
    the member is forwarded to ``build_default_prompt`` together with the
    conversation.
    """

    # NOTE(review): the string values are presumably interpolated into the
    # Llama Guard prompt template by build_default_prompt — confirm there
    # before changing them.
    AGENT = "Agent"
    USER = "User"
def main(
    model_id: str = "meta-llama/LlamaGuard-7b",
    llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_1
):
    """
    Entry point for Llama Guard inference sample script.

    This function loads Llama Guard from Hugging Face or a local model and
    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.

    Args:
        model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
            or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'.
        llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts.
            Accepts either a LlamaGuardVersion member or the name of one as a string (e.g. "LLAMA_GUARD_1",
            which is what arrives when the value is given on the command line). Defaults to LLAMA_GUARD_1.

    Raises:
        ValueError: If llama_guard_version is a string that is not the name of a LlamaGuardVersion member.
    """
    # The default argument is already an enum member, and looking a member up
    # by name with a member as the key fails, so only resolve by name when we
    # were handed something else (typically a string from the CLI).
    if not isinstance(llama_guard_version, LlamaGuardVersion):
        try:
            llama_guard_version = LlamaGuardVersion[llama_guard_version]
        except KeyError as e:
            raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e

    # Each entry is (conversation turns, role attached to that conversation).
    prompts: List[Tuple[List[str], AgentType]] = [
        (["<Sample user prompt>"], AgentType.USER),

        (["<Sample user prompt>",
          "<Sample agent response>"], AgentType.AGENT),

        (["<Sample user prompt>",
          "<Sample agent response>",
          "<Sample user reply>",
          "<Sample agent response>",], AgentType.AGENT),
    ]

    # Load the model with 8-bit quantization and let device_map="auto" place it.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")

    for conversation, agent_type in prompts:
        formatted_prompt = build_default_prompt(
            agent_type,
            create_conversation(conversation),
            llama_guard_version)

        # NOTE(review): the inputs are moved to "cuda" explicitly while the
        # model placement is automatic — confirm this matches where the model
        # ends up on multi-device setups.
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
        prompt_len = inputs["input_ids"].shape[-1]
        output = model.generate(**inputs, max_new_tokens=100, pad_token_id=0)
        # Skip the first prompt_len tokens so only the text generated after
        # the prompt is decoded.
        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        print(conversation)
        print(f"> {results}")
        print("\n==================================\n")
if __name__ == "__main__":
    # fire.Fire exposes main()'s parameters as command-line flags.
    try:
        fire.Fire(main)
    except Exception as e:
        # Exit through SystemExit with the error text: the interpreter prints
        # it to stderr and returns a non-zero status, so shells and CI can
        # tell the run failed. Merely printing the error would exit with 0.
        raise SystemExit(str(e)) from e