From 25db06016643b02d430c76e01c750f38f300cb67 Mon Sep 17 00:00:00 2001 From: weimingc Date: Fri, 17 May 2024 00:57:22 +0000 Subject: [PATCH] fix --- demo_trt_llm/build_visual_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/demo_trt_llm/build_visual_engine.py b/demo_trt_llm/build_visual_engine.py index a6f2543..cd5bc04 100644 --- a/demo_trt_llm/build_visual_engine.py +++ b/demo_trt_llm/build_visual_engine.py @@ -240,10 +240,10 @@ def build_vila_engine(args): vision_tower = model.get_vision_tower() image_processor = vision_tower.image_processor raw_image = Image.new('RGB', [10, 10]) # dummy image - image = image_processor(images=raw_image, - return_tensors="pt")['pixel_values'].to( - args.device, torch.float16) - + image = image_processor(images=raw_image,return_tensors="pt")['pixel_values'] + if isinstance(image, list): + image = image[0].unsqueeze(0) + image = image.to(args.device, torch.float16) class VilaVisionWrapper(torch.nn.Module): def __init__(self, tower, projector):