Skip to content

Commit

Permalink
Add: gen model demo
Browse files Browse the repository at this point in the history
  • Loading branch information
kimihailv committed Nov 30, 2023
1 parent 8e73b01 commit 4b0fe4e
Show file tree
Hide file tree
Showing 2 changed files with 586 additions and 0 deletions.
56 changes: 56 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch

from transformers import TextStreamer
from .src.gen_model import VLMForCausalLM, VLMProcessor
from PIL import Image

EOS_TOKEN = 32001


if __name__ == "__main__":
print(
"1) For setting an image: [img] path/to/the/image",
"2) For captioning: [cap] describe the image / give a detailed description etc",
"3) For VQA: [vqa] question",
"4) For only-text prompts: [txt] prompt",
sep="\n",
)
image = None

print("\nLoading model")
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = VLMForCausalLM.from_pretrained("unum-cloud/uform-gen").eval().to(device)
processor = VLMProcessor.from_pretrained("unum-cloud/uform-gen")
streamer = TextStreamer(
processor.tokenizer, skip_prompt=True, skip_special_tokens=True
)

while True:
print("> ", end="")
prompt = input()

if prompt.startswith("[img]"):
image_path = prompt.split("[img]")[-1].strip()
image = Image.open(image_path)
print("Image is set!")
continue

is_text_only = prompt.startswith("[txt]")

input_data = processor(text=prompt, images=image, return_tensors="pt").to(
device
)

with torch.inference_mode():
response = model.generate(
input_ids=input_data["input_ids"],
attention_mask=None if is_text_only else input_data["attention_mask"],
images=None if is_text_only else input_data["images"],
use_cache=True,
do_sample=False,
max_new_tokens=1024,
eos_token_id=EOS_TOKEN,
pad_token_id=processor.tokenizer.pad_token_id,
streamer=streamer,
)
Loading

0 comments on commit 4b0fe4e

Please sign in to comment.