add audio #540

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@

## News 🚀🚀🚀

- `2024/08/28`: 🚀 We release [InternVL Omni](https://huggingface.co/OpenGVLab/InternOmni), a model that accepts audio input. It maintains good VQA performance while also performing strongly on audio question answering.
- `2024/08/01`: The [Chartmimic](https://chartmimic.github.io/) team evaluated the InternVL2 series models on their benchmark. The InternVL2-26B and 76B models achieved the top two performances among open-source models, with the InternVL2 76B model surpassing GeminiProVision and exhibiting comparable results to Claude-3-opus.
- `2024/08/01`: InternVL2-Pro achieved the SOTA performance among open-source models on the [CharXiv](https://charxiv.github.io/#leaderboard) dataset, surpassing some well-known closed-source models such as GPT-4V, Gemini 1.5 Flash, and Claude 3 Sonnet.
- `2024/07/24`: The [MLVU](https://github.com/JUNJIE99/MLVU) team evaluated InternVL-1.5 on their benchmark. The average performance on the multiple-choice task was 50.4%, while the performance on the generative tasks was 4.02. The performance on the multiple-choice task ranked #1 among all open-source MLLMs.
48 changes: 48 additions & 0 deletions internvl_chat_audio/README.md
@@ -0,0 +1,48 @@
# InternVL-Chat

This folder contains the implementation of InternVL-Chat.

## 🛠️ Installation

See [INSTALLATION.md](../INSTALLATION.md)

In addition, using this codebase requires executing the following steps:

- Install other requirements:

```bash
pip install --upgrade pip # enable PEP 660 support
pip install -e .
```

## 📖 Documents

- InternVL 2.0

- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/quick_start.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html)
- Deployment [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/deployment.html)

- InternVL 1.5

- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/quick_start.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/evaluation.html)
- Deployment [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/deployment.html)

- InternVL 1.2

- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/quick_start.html)
- Reproduce [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/reproduce.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/evaluation.html)

- InternVL 1.1

- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/quick_start.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/evaluation.html)
272 changes: 272 additions & 0 deletions internvl_chat_audio/eval/caption/evaluate_caption.py
@@ -0,0 +1,272 @@
import argparse
import itertools
import json
import os
import random
import time
from functools import partial

import torch
from internvl.model import load_model_and_tokenizer
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from tqdm import tqdm

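# Per-dataset evaluation settings: image root, annotation file(s), and caption generation length limits.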
ds_collections = {
'flickr30k': {
'root': 'data/flickr30k/',
'annotation': 'data/flickr30k/flickr30k_test_karpathy.json',
'max_new_tokens': 30,
'min_new_tokens': 8,
},
'coco': {
'root': 'data/coco/',
'annotation': ['data/coco/annotations/coco_karpathy_test.json',
'data/coco/annotations/coco_karpathy_test_gt.json'],
'max_new_tokens': 30,
'min_new_tokens': 8,
},
'nocaps': {
'root': 'data/nocaps/images',
'annotation': 'data/nocaps/nocaps_val_4500_captions.json',
'max_new_tokens': 30,
'min_new_tokens': 8,
},
}


class CaptionDataset(torch.utils.data.Dataset):
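"""Caption dataset for COCO / Flickr30k / NoCaps; yields the image tensor(s), image_id, and caption prompt."""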

def __init__(self, name, root, annotation, prompt, input_size=224, dynamic_image_size=False,
use_thumbnail=False, max_num=6):
if name == 'coco':
self.images = json.load(open(annotation))
else:
self.images = json.load(open(annotation))['images']
self.name = name
self.prompt = prompt
self.root = root
self.input_size = input_size
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.max_num = max_num
self.transform = build_transform(is_train=False, input_size=input_size)

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
if self.name == 'coco':
filename = self.images[idx]['image']
image_id = int(filename.split('_')[-1].replace('.jpg', ''))
image_path = os.path.join(self.root, filename)
else:
image_id = self.images[idx]['id']
if 'file_name' in self.images[idx]:
image_path = os.path.join(self.root, self.images[idx]['file_name'])
else:
image_path = os.path.join(self.root, self.images[idx]['image'])

image = Image.open(image_path)
if self.dynamic_image_size:
images = dynamic_preprocess(image, image_size=self.input_size,
use_thumbnail=self.use_thumbnail,
max_num=self.max_num)
else:
images = [image]
pixel_values = [self.transform(image) for image in images]
pixel_values = torch.stack(pixel_values)

return {
'image_id': image_id,
'input_text': self.prompt,
'pixel_values': pixel_values
}


def collate_fn(inputs, tokenizer):
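# Concatenate image tiles from the batch along dim 0 and tokenize the shared caption prompt.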
pixel_values = torch.cat([_['pixel_values'] for _ in inputs], dim=0)
image_ids = [_['image_id'] for _ in inputs]
input_texts = [_['input_text'] for _ in inputs]
input_tokens = tokenizer(input_texts, return_tensors='pt')

return pixel_values, image_ids, input_tokens.input_ids, input_tokens.attention_mask


class InferenceSampler(torch.utils.data.sampler.Sampler):
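"""Shards dataset indices contiguously across distributed ranks so each process evaluates a disjoint slice."""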

def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)

def __iter__(self):
yield from self._local_indices

def __len__(self):
return len(self._local_indices)


def evaluate_chat_model():
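# Generate captions on every rank, gather them with all_gather_object, and score with COCOEvalCap on rank 0.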
prompt = 'Provide a one-sentence caption for the provided image.'
print('prompt:', prompt)
random.seed(args.seed)
summaries = []

for ds_name in args.datasets:
annotation = ds_collections[ds_name]['annotation']
if isinstance(annotation, list):
annotation = annotation[0]
dataset = CaptionDataset(
name=ds_name,
root=ds_collections[ds_name]['root'],
annotation=annotation,
prompt=prompt,
input_size=image_size,
dynamic_image_size=args.dynamic,
use_thumbnail=use_thumbnail,
max_num=args.max_num
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, tokenizer=tokenizer),
)

image_ids, captions = [], []
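# Batch size is fixed to 1, so each iteration produces a single caption via model.chat.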
for _, (pixel_values, ids, _, _) in tqdm(enumerate(dataloader)):
pixel_values = pixel_values.to(torch.bfloat16).cuda()
generation_config = dict(
num_beams=args.num_beams,
max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
)
pred = model.chat(
tokenizer=tokenizer,
pixel_values=pixel_values,
question=prompt,
generation_config=generation_config,
verbose=True
)
image_ids.extend(ids)
captions.append(pred)

torch.distributed.barrier()

world_size = torch.distributed.get_world_size()
merged_ids = [None for _ in range(world_size)]
merged_captions = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_ids, image_ids)
torch.distributed.all_gather_object(merged_captions, captions)

merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
merged_captions = [_ for _ in itertools.chain.from_iterable(merged_captions)]
average_length = sum(len(x.split()) for x in merged_captions) / len(merged_captions)
print(f'Average caption length: {average_length}')

if torch.distributed.get_rank() == 0:
print(f'Evaluating {ds_name} ...')

results = []
for image_id, caption in zip(merged_ids, merged_captions):
results.append({
'image_id': int(image_id),
'caption': caption,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{ds_name}_{time_prefix}.json'
results_file = os.path.join(args.out_dir, results_file)
json.dump(results, open(results_file, 'w'))

annotation = ds_collections[ds_name]['annotation']
if isinstance(annotation, list):
annotation = annotation[-1]
coco = COCO(annotation)
coco_result = coco.loadRes(results_file)
coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()
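# coco_eval.eval maps metric names (e.g. Bleu, METEOR, ROUGE_L, CIDEr) to their scores.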

summary = coco_eval.eval.items()
print(summary)
summaries.append([args.checkpoint, ds_name, average_length, summary])

torch.distributed.barrier()

out_path = '_'.join(args.checkpoint.split('/')[-2:])
writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a')
print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}")
for summary in summaries:
print(summary)
writer.write(f'{summary}\n')
writer.close()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--datasets', type=str, default='coco,flickr30k,nocaps')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--num-beams', type=int, default=5)
parser.add_argument('--temperature', type=float, default=0.0)
parser.add_argument('--out-dir', type=str, default='results')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--dynamic', action='store_true')
parser.add_argument('--max-num', type=int, default=6)
parser.add_argument('--load-in-8bit', action='store_true')
parser.add_argument('--load-in-4bit', action='store_true')
parser.add_argument('--auto', action='store_true')
args = parser.parse_args()

if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)

args.datasets = args.datasets.split(',')
print('datasets:', args.datasets)
assert args.batch_size == 1, 'Only batch size 1 is supported'

torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
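# The distributed launcher (e.g. torchrun) is expected to provide WORLD_SIZE, RANK, and LOCAL_RANK.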

torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

model, tokenizer = load_model_and_tokenizer(args)
image_size = model.config.force_image_size or model.config.vision_config.image_size
use_thumbnail = model.config.use_thumbnail

total_params = sum(p.numel() for p in model.parameters()) / 1e9
if total_params > 20 or args.dynamic:
args.num_beams = 1
print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
else:
print(f'[test] total_params: {total_params}B')
print(f'[test] image_size: {image_size}')
print(f'[test] template: {model.config.template}')
print(f'[test] dynamic_image_size: {args.dynamic}')
print(f'[test] use_thumbnail: {use_thumbnail}')
print(f'[test] max_num: {args.max_num}')

evaluate_chat_model()
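The script reads `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` from the environment, so it is meant to be started by a distributed launcher. Below is a minimal launch sketch; the GPU count and checkpoint path are placeholders, and it assumes the command is run from the `internvl_chat_audio/` directory with a torchrun-style launcher.

```bash
# Hypothetical invocation: caption evaluation on COCO with 8 GPUs and dynamic tiling.
torchrun --nproc_per_node=8 eval/caption/evaluate_caption.py \
  --checkpoint /path/to/checkpoint \
  --datasets coco \
  --dynamic \
  --out-dir results
```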