Commit
Merge pull request #778 from frankchieng/patch-1

Create caption_with_internvl.py

Showing 1 changed file with 315 additions and 0 deletions.
@@ -0,0 +1,315 @@
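# Caption a directory of images with InternVL2 and store the results in a
# Parquet dataset. An illustrative invocation (paths and prompt are examples,
# not values shipped with the script):
#
#   python caption_with_internvl.py \
#       --input_dir /data/images \
#       --output_parquet captions \
#       --query_str "Describe this image in detail."
#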
import os
from pathlib import Path
import logging
import re
import random
import argparse
import base64
import torch
from PIL import Image
from tqdm import tqdm
import requests
import io
import pandas as pd
import torch.nn as nn
import numpy as np
import glob
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import (
    AutoModel,
    AutoTokenizer,
)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

logger = logging.getLogger("Captioner")

# Load the existing parquet file if it exists.
def load_input_parquet(parquet_path: str):
    df = pd.read_parquet(path=parquet_path)
    return df

# Load the InternVL2-8B model; it needs about 24 GB of VRAM. Larger variants
# such as 26B or 72B require one to three 80 GB A100s.
def load_model(model_name_or_path="OpenGVLab/InternVL2-8B", precision="bf16"):
    torch_dtype = torch.bfloat16 if precision == "bf16" else torch.float16
    model = AutoModel.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False)

    return (
        model,
        tokenizer
    )

def parse_args():
    parser = argparse.ArgumentParser(
        description="Process images and generate captions."
    )
    parser.add_argument(
        "--input_dir", type=str, required=True, help="Directory containing the images."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="OpenGVLab/InternVL2-8B",
        help="Model name to use for captioning.",
    )
    parser.add_argument(
        "--output_parquet",
        type=str,
        required=True,
        help="Path to the output Parquet dataset.",
    )
    parser.add_argument(
        "--query_str",
        type=str,
        default="",
        help="The query string to use for captioning. This instructs the model how to behave. Not normally needed for InternVL.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Precision for loading the model. Default: bf16",
    )
    parser.add_argument(
        "--input_parquet",
        type=str,
        default=None,
        help="Path to the input Parquet dataset, which will be adjusted to have the new column.",
    )
    parser.add_argument(
        "--input_parquet_hint_column",
        type=str,
        default="title",
        help="When set, the column to use as a hint for the input query string placeholder value. Default: title",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="The maximum number of tokens to generate. Default: 1024",
    )
    parser.add_argument(
        "--do_sample",
        action="store_true",
        default=False,
        help=(
            "Whether to use sampling for generation. Makes the model more responsive to input prompts."
            " If not set, greedy decoding is used. Default: False"
        ),
    )
    args = parser.parse_args()
    return args

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

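# Illustrative use of the transform above: a single PIL image becomes a
# normalized float tensor of shape (3, input_size, input_size).
#
#   tensor = build_transform(448)(Image.new("RGB", (800, 600)))  # -> torch.Size([3, 448, 448])
#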
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

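# Worked example for the ratio search above (values are illustrative): a
# 1024x512 image has aspect ratio 2.0, so with image_size=448 the closest
# tiling grid is (2, 1), i.e. two 448x448 tiles side by side.
#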
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate the candidate tiling grids (i columns x j rows) whose tile
    # count falls within [min_num, max_num]
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

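# Continuing the example above: a 1024x512 input is resized to 896x448 and
# split into two 448x448 tiles; with use_thumbnail=True a third 448x448
# whole-image thumbnail is appended, so three PIL images are returned.
#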
def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

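# load_image therefore yields a (num_tiles [+ 1 thumbnail], 3, 448, 448)
# tensor; for the 1024x512 example above that is torch.Size([3, 3, 448, 448]).
#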
def process_directory(
    args,
    image_dir,
    output_parquet,
    model,
    tokenizer,
    max_new_tokens,
    input_parquet=None,
    original_query_str=None,
    total_to_process: int = None,
):
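    """Caption every supported image under image_dir, recursing into
    subdirectories, and write the results to
    "<output_parquet>.<directory name>.parquet".

    When input_parquet is given, files that already have a non-null caption
    are skipped, and the value of args.input_parquet_hint_column is
    substituted for the "%s" placeholder in the query string.
    """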
    records = []
    directory_path = Path(image_dir)
    parquet_path = f"{output_parquet}.{directory_path.name}.parquet"
    print(f"Parquet: {parquet_path}")
    total_processed = 0
    for filename in tqdm(os.listdir(image_dir), desc="Processing Images"):
        if input_parquet is not None:
            # if the caption column at the filename position is non-empty, skip
            current_caption = input_parquet[input_parquet["filename"] == filename]
            hint_column = args.input_parquet_hint_column
            hint_value = None
            if hint_column is not None and hint_column != "":
                try:
                    hint_value = current_caption[hint_column].values[0]
                except (KeyError, IndexError):
                    hint_value = None
                if hint_value is not None and not hint_value == "":
                    if original_query_str is not None:
                        args.query_str = original_query_str
                    args.query_str = args.query_str.replace("%s", hint_value)
                    logger.info(
                        f"Using query string: {args.query_str} for hint value: {hint_value}"
                    )
            try:
                if (
                    not current_caption.empty
                    and not current_caption["caption"].isnull().values[0]
                ):
                    logger.debug(f"Already has caption: {current_caption['caption']}")
                    continue
            except (KeyError, IndexError):
                logger.debug(f"Error checking for existing caption: {current_caption}")
        full_filepath = os.path.join(image_dir, filename)
        if os.path.isdir(full_filepath):
            logger.info(f"Found directory to traverse: {full_filepath}")
            process_directory(
                args,
                full_filepath,
                output_parquet,
                model,
                tokenizer,
                max_new_tokens,
                input_parquet=input_parquet,
                original_query_str=original_query_str,
            )
            # restore the query string in case the recursive call substituted a hint
            args.query_str = original_query_str
        elif filename.lower().endswith((".jpg", ".png", ".jpeg")):
            try:
                logger.debug(f"Processing image: {filename}")
                # set the max number of tiles in `max_num`
                torch_dtype = torch.bfloat16 if args.precision == "bf16" else torch.float16
                pixel_values = load_image(full_filepath, max_num=12).to(torch_dtype).to(device)
                generation_config = dict(max_new_tokens=max_new_tokens, do_sample=args.do_sample)

                question = '<image>\n' + args.query_str
                response = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=False)

                total_processed += 1
                logger.debug(f"Caption for {filename}: {response}")

                records.append({"filename": filename, "caption": response})
                if (
                    total_to_process is not None
                    and total_processed >= total_to_process
                ):
                    break

            except Exception as e:
                import traceback

                logger.error(
                    f"Error processing {filename}: {str(e)}, traceback: {traceback.format_exc()}"
                )
                if "CUDA error" in str(e):
                    import sys

                    sys.exit(1)
    new_df = pd.DataFrame(records)
    if input_parquet is not None:
        # merge new_df with input_parquet without mutating the caller's frame
        combined_df = (
            input_parquet.set_index("filename")
            .combine_first(new_df.set_index("filename"))
            .reset_index()
        )
    else:
        combined_df = new_df
    # reduce duplicates by "filename" contents
    combined_df = combined_df.drop_duplicates(subset=["filename"])
    combined_df.to_parquet(parquet_path, engine="pyarrow")
    logger.info(f"Processed Parquet file saved to {parquet_path}")

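# A quick way to inspect the output afterwards (filenames are illustrative):
#
#   df = pd.read_parquet("captions.images.parquet")
#   print(df[["filename", "caption"]].head())
#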
def main():
    args = parse_args()
    logging.basicConfig(level=logging.INFO)
    input_database = None
    if args.input_parquet:
        if not os.path.exists(args.input_parquet):
            raise ValueError("The parquet file specified as input does not exist.")

        input_database = load_input_parquet(args.input_parquet)

    model, tokenizer = load_model(args.model_name, args.precision)
    process_directory(
        args,
        args.input_dir,
        args.output_parquet,
        model,
        tokenizer,
        max_new_tokens=args.max_new_tokens,
        input_parquet=input_database,
        original_query_str=str(args.query_str),
    )


if __name__ == "__main__":
    main()