-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
130 lines (105 loc) · 3.61 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright © 2024 Apple Inc.
import argparse
import codecs
from pathlib import Path
import mlx.core as mx
import requests
from PIL import Image
from transformers import AutoProcessor
from llava import LlavaModel
def parse_arguments():
    """Build the command-line parser and return the parsed arguments.

    Returns:
        argparse.Namespace with ``model``, ``image``, ``prompt``,
        ``max_tokens``, and ``temp`` attributes.
    """
    # Flag specs kept in one table so the parser wiring is a single loop.
    specs = [
        (
            "--model",
            dict(
                type=str,
                default="llava-hf/llava-1.5-7b-hf",
                help="The path to the local model directory or Hugging Face repo.",
            ),
        ),
        (
            "--image",
            dict(
                type=str,
                default="http://images.cocodataset.org/val2017/000000039769.jpg",
                help="URL or path of the image to process.",
            ),
        ),
        (
            "--prompt",
            dict(
                type=str,
                default="USER: <image>\nWhat are these?\nASSISTANT:",
                help="Message to be processed by the model.",
            ),
        ),
        (
            "--max-tokens",
            dict(
                type=int,
                default=100,
                help="Maximum number of tokens to generate.",
            ),
        ),
        (
            "--temp",
            dict(type=float, default=0.3, help="Temperature for sampling."),
        ),
    ]
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def load_image(image_source):
    """Load a PIL image from a URL or a local file path.

    Args:
        image_source: An http(s) URL or a filesystem path to an image.

    Returns:
        The opened PIL image.

    Raises:
        ValueError: If the source cannot be fetched or opened, or is
            neither a URL nor an existing file.
    """
    if image_source.startswith(("http://", "https://")):
        try:
            # Bound the request so a stalled server cannot hang generation.
            response = requests.get(image_source, stream=True, timeout=10)
            response.raise_for_status()
            # PIL reads the raw urllib3 stream, which does NOT decode
            # Content-Encoding by default; without this flag a gzip/deflate
            # response hands PIL compressed bytes and the open fails.
            response.raw.decode_content = True
            return Image.open(response.raw)
        except Exception as e:
            raise ValueError(
                f"Failed to load image from URL: {image_source} with error {e}"
            )
    elif Path(image_source).is_file():
        try:
            return Image.open(image_source)
        except IOError as e:
            raise ValueError(f"Failed to load image {image_source} with error: {e}")
    else:
        raise ValueError(
            f"The image {image_source} must be a valid URL or existing file."
        )
def prepare_inputs(processor, image, prompt):
    """Preprocess the prompt and image into mlx arrays for the model.

    Args:
        processor: The Hugging Face processor for this model.
        image: A PIL image, or a string (URL/path) to load one from.
        prompt: The text prompt to tokenize.

    Returns:
        Tuple of ``(input_ids, pixel_values)`` as ``mx.array``s.
    """
    img = load_image(image) if isinstance(image, str) else image
    encoded = processor(prompt, img, return_tensors="np")
    return mx.array(encoded["input_ids"]), mx.array(encoded["pixel_values"])
def load_model(model_path):
    """Load the processor and Llava weights from a local path or HF repo."""
    return (
        AutoProcessor.from_pretrained(model_path),
        LlavaModel.from_pretrained(model_path),
    )
def sample(logits, temperature=0.0):
    """Sample a token id from ``logits``.

    Greedy argmax when ``temperature`` is 0; otherwise categorical
    sampling of the temperature-scaled logits.
    """
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    scaled = logits * (1 / temperature)
    return mx.random.categorical(scaled)
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
    """Autoregressively generate up to ``max_tokens`` tokens for the prompt.

    Args:
        input_ids: Tokenized prompt, shape compatible with ``model``.
        pixel_values: Preprocessed image tensor.
        model: The multimodal model; its ``language_model`` continues
            decoding from the KV cache.
        processor: Provides the tokenizer for EOS detection and decoding.
        max_tokens: Maximum number of tokens to generate (0 yields "").
        temperature: Sampling temperature passed to ``sample``.

    Returns:
        The decoded generated text (without the prompt).
    """
    # Prime the cache with the full prompt + image.
    logits, cache = model(input_ids, pixel_values)
    tokens = []
    while len(tokens) < max_tokens:
        y = sample(logits[:, -1, :], temperature=temperature)
        token = y.item()
        # Check EOS for every token, including the first one — the
        # original appended the first token unconditionally.
        if token == processor.tokenizer.eos_token_id:
            break
        tokens.append(token)
        if len(tokens) == max_tokens:
            break  # Done; skip the now-useless forward pass.
        # Feed back only the new token; the KV cache carries the context.
        logits, cache = model.language_model(y[None], cache=cache)
    return processor.tokenizer.decode(tokens)
def main():
    """CLI entry point: load the model, prepare inputs, print the output."""
    args = parse_arguments()
    processor, model = load_model(args.model)
    # Turn literal "\n" etc. from the shell into real control characters.
    prompt = codecs.decode(args.prompt, "unicode_escape")
    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
    print(prompt)
    print(
        generate_text(
            input_ids, pixel_values, model, processor, args.max_tokens, args.temp
        )
    )
if __name__ == "__main__":
    main()