Skip to content

Commit

Permalink
[Fixed] RuntimeError: probability tensor contains either inf, nan or element < 0 (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
init27 authored Oct 3, 2024
2 parents d9aab46 + 625860d commit ad5ce80
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions recipes/quickstart/inference/local_inference/multi_modal_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from accelerate import Accelerator

accelerator = Accelerator()

device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
Expand All @@ -14,8 +18,11 @@ def load_model_and_processor(model_name: str, hf_token: str):
"""
Load the model and processor based on the 11B or 90B model.
"""
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
token=hf_token)
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)

model, processor=accelerator.prepare(model, processor)
return model, processor


Expand All @@ -38,7 +45,7 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(image, prompt, return_tensors="pt").to(model.device)
inputs = processor(image, prompt, return_tensors="pt").to(device)
output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
return processor.decode(output[0])[len(prompt):]

Expand All @@ -63,4 +70,4 @@ def main(image_path: str, prompt_text: str, temperature: float, top_p: float, mo
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")

args = parser.parse_args()
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)

0 comments on commit ad5ce80

Please sign in to comment.