Support winoground (#116)
* Support winoground; re-use image_caption_selection for both SugarCrepe and winoground.

* minor
mehdidc authored Jan 12, 2024
1 parent 3f1d126 commit 5f23a76
Showing 4 changed files with 84 additions and 25 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -240,6 +240,12 @@ To evaluate on all the tasks together, you can do:

`clip_benchmark eval --model ViT-B-32 --pretrained laion400m_e32 --dataset=sugar_crepe --output=result.json`

For [winoground](https://huggingface.co/datasets/facebook/winoground/):

`clip_benchmark eval --model ViT-B-32 --pretrained laion400m_e32 --dataset=winoground --output=result.json`

NB: `pip install datasets` is required for winoground.
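
As a quick sanity check, the underlying data can also be previewed directly with the Hugging Face `datasets` library. A minimal sketch, assuming `datasets` is installed and that you have accepted the Winoground access terms on the Hub (e.g. via `huggingface-cli login`):

```python
from datasets import load_dataset

# The call below mirrors what clip_benchmark does internally (see winoground.py further down);
# the dataset is gated on the Hub, so an authenticated session may be required.
ds = load_dataset("facebook/winoground")["test"]

example = ds[0]
# Each instance pairs two images with two captions that use the same words in a different order.
print(example["caption_0"])
print(example["caption_1"])
print(example["image_0"], example["image_1"])  # PIL images
```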

### Webdataset example

Here is an example on how to run it on [webdatasets](https://github.com/webdataset/webdataset).
8 changes: 5 additions & 3 deletions clip_benchmark/datasets/builder.py
@@ -13,7 +13,7 @@
RenderedSST2, StanfordCars)

from . import (babel_imagenet, caltech101, flickr, imagenetv2, objectnet,
sugar_crepe, voc2007)
sugar_crepe, voc2007, winoground)


def build_dataset(dataset_name, root="root", transform=None, split="test", download=True, annotation_file=None, language="en", task="zeroshot_classification", wds_cache_dir=None, custom_classname_file=None, custom_template_file=None, **kwargs):
@@ -242,6 +242,8 @@ def download_imagenet(r):
url = f"https://raw.githubusercontent.com/RAIVNLab/sugar-crepe/main/data/{task}.json"
call(f"wget {url} --output-document={ann}", shell=True)
ds = sugar_crepe.SugarCrepe(root=os.path.join(root, "val2017"), ann_file=ann, transform=transform, **kwargs)
elif dataset_name == "winoground":
ds = winoground.WinoGround(root=root, transform=transform)
elif dataset_name == "mscoco_captions":
# https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
if split == "train":
@@ -523,13 +525,13 @@ def __len__(self):
def get_dataset_default_task(dataset):
if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd200"):
return "zeroshot_retrieval"
elif dataset.startswith("sugar_crepe"):
elif dataset.startswith("sugar_crepe") or dataset == "winoground":
return "image_caption_selection"
else:
return "zeroshot_classification"

def get_dataset_collate_fn(dataset_name):
if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200") or dataset_name.startswith("sugar_crepe"):
if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200", "winoground") or dataset_name.startswith("sugar_crepe"):
return image_captions_collate_fn
else:
return default_collate
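
Taken together with the new `winoground` branch in `build_dataset`, these defaults let the dataset be constructed and batched like the other image-caption-selection benchmarks. A rough usage sketch (the model choice and the `data/winoground` cache path are illustrative, not part of this commit):

```python
import open_clip
from torch.utils.data import DataLoader

from clip_benchmark.datasets.builder import (build_dataset,
                                             get_dataset_collate_fn,
                                             get_dataset_default_task)

# Any transform mapping a PIL image to a tensor works; here we reuse open_clip's preprocessing.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion400m_e32")

ds = build_dataset("winoground", root="data/winoground", transform=preprocess)
print(get_dataset_default_task("winoground"))  # -> "image_caption_selection"
loader = DataLoader(ds, batch_size=8,
                    collate_fn=get_dataset_collate_fn("winoground"))
```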
30 changes: 30 additions & 0 deletions clip_benchmark/datasets/winoground.py
@@ -0,0 +1,30 @@
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
import json

class WinoGround(Dataset):

def __init__(self, root=".", transform=None):
from datasets import load_dataset
self.ds = load_dataset("facebook/winoground", cache_dir=root)["test"]
self.transform = transform

def __getitem__(self, idx):
data = self.ds[idx]
img0 = data["image_0"]
img1 = data["image_1"]
cap0 = data["caption_0"]
cap1 = data["caption_1"]
if self.transform is not None:
img0 = self.transform(img0)
img1 = self.transform(img1)
imgs = torch.stack([img0, img1])
else:
imgs = [img0, img1]
caps = [cap0, cap1]
return imgs, caps

def __len__(self):
return len(self.ds)
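
For reference, with a tensor-producing transform each item comes back as the two images stacked along a new leading dimension, plus the two captions. A small sketch with a plain torchvision transform (the 224x224 size and the `data/winoground` path are arbitrary choices for illustration):

```python
from torchvision import transforms

from clip_benchmark.datasets.winoground import WinoGround

preprocess = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),  # guard against non-RGB images
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

ds = WinoGround(root="data/winoground", transform=preprocess)
imgs, caps = ds[0]
print(imgs.shape)  # torch.Size([2, 3, 224, 224]): image_0 and image_1 stacked
print(caps)        # [caption_0, caption_1]
```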
65 changes: 43 additions & 22 deletions clip_benchmark/metrics/image_caption_selection.py
@@ -5,9 +5,13 @@
import torch.nn.functional as F
from tqdm import tqdm

def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5]):
def evaluate(model, dataloader, tokenizer, device, amp=True):
"""
Evaluate the model on the given dataset
Evaluate the model on the given dataset.
The task has N instances, each instance has I images and C captions.
For each instance, the goal is to find the correct image for each caption and the correct caption for each image.
This is done by computing the similarities between each image and each caption.
This procedure is used to evaluate the models on Winoground and SugarCrepe.
Parameters
----------
@@ -28,32 +32,49 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
Returns
-------
dict of accuracy metric
dict of accuracy metrics
"""
autocast = torch.cuda.amp.autocast if amp else suppress
preds = []
image_score = []
text_score = []
score = []
for batch_images, batch_texts in tqdm(dataloader):
if len(batch_images.shape) == 4:
B, C, H, W = batch_images.shape
batch_images = batch_images.view(B, 1, C, H, W)
# batch_images: B, nb_images_per_instance, C, H, W
# batch_texts: B, nb_captions_per_instance

B, nim, C, H, W = batch_images.shape
nt = len(batch_texts[0])
batch_images = batch_images.to(device)
batch_images_ = batch_images.view(B*nim, C, H, W) # B*nim, C, H, W
# tokenize all texts in the batch
batch_texts_tok = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
nb_texts_for_each_image = [len(texts) for texts in batch_texts]

batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
# compute the embedding of images and texts
with torch.no_grad(), autocast():
batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1).cpu()
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1).cpu()
start = 0
for i, nb in enumerate(nb_texts_for_each_image):
end = start + nb
image_emb = batch_images_emb[i:i+1]
texts_emb = batch_texts_emb[start:end]
scores = image_emb @ texts_emb.t()
scores = scores[0]
pred = scores.argmax().item()
start = end
preds.append(pred)
pred = torch.Tensor(preds).long()
acc = (pred==0).float().mean().item() # 0 is the index of the caption, the rest (>0) are considered negative captions
batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
gt = torch.arange(min(nim, nt)).to(device)
for i in range(B):
# iterate over instances

# compute similarities between each image and each text
images_emb = batch_images_emb[i]
texts_emb = batch_texts_emb[i]
scores = images_emb @ texts_emb.t()

# i-th image should be matched to the i-th text
image_closest_text = scores.argmax(dim=1)[:len(gt)]
text_closest_image = scores.argmax(dim=0)[:len(gt)]
pred_text_is_correct = (image_closest_text==gt).all().item()
pred_image_is_correct = (text_closest_image==gt).all().item()
all_correct = pred_text_is_correct and pred_image_is_correct
image_score.append(pred_image_is_correct)
text_score.append(pred_text_is_correct)
score.append(all_correct)
metrics = {}
metrics[f"acc"] = acc
metrics["image_acc"] = torch.Tensor(image_score).float().mean().item()
metrics["text_acc"] = torch.Tensor(text_score).float().mean().item()
metrics["acc"] = torch.Tensor(score).float().mean().item()
return metrics
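
To make the reported numbers concrete: for each instance, `text_acc` counts it as correct when every image retrieves its own caption, `image_acc` when every caption retrieves its own image, and `acc` only when both hold; each is then averaged over instances. A tiny worked example of that per-instance logic on a hand-made 2x2 similarity matrix (model-free, for illustration only):

```python
import torch

# Rows are images, columns are captions: scores[i, j] = sim(image_i, caption_j).
scores = torch.tensor([[0.9, 0.2],
                       [0.4, 0.7]])
gt = torch.arange(2)

image_closest_text = scores.argmax(dim=1)  # best caption for each image -> [0, 1]
text_closest_image = scores.argmax(dim=0)  # best image for each caption -> [0, 1]

text_correct = (image_closest_text == gt).all().item()   # contributes to "text_acc"
image_correct = (text_closest_image == gt).all().item()  # contributes to "image_acc"
print(text_correct, image_correct, text_correct and image_correct)  # True True True
```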
