diff --git a/demo_trt_llm/build_visual_engine.py b/demo_trt_llm/build_visual_engine.py index a6f2543..cd5bc04 100644 --- a/demo_trt_llm/build_visual_engine.py +++ b/demo_trt_llm/build_visual_engine.py @@ -240,10 +240,10 @@ def build_vila_engine(args): vision_tower = model.get_vision_tower() image_processor = vision_tower.image_processor raw_image = Image.new('RGB', [10, 10]) # dummy image - image = image_processor(images=raw_image, - return_tensors="pt")['pixel_values'].to( - args.device, torch.float16) - + image = image_processor(images=raw_image,return_tensors="pt")['pixel_values'] + if isinstance(image, list): + image = image[0].unsqueeze(0) + image = image.to(args.device, torch.float16) class VilaVisionWrapper(torch.nn.Module): def __init__(self, tower, projector):