Skip to content

Commit

Permalink
ImageNet HF support and fixes (#30)
Browse files Browse the repository at this point in the history
Summary:
This PR fixes some of the bugs in the original code that allows us to
match the FLAVA numbers with the original codebase. This was separately
tested on WinoGround dataset for zero-shot and ITM evaluation to match
with the original code.

Also, adds support for ImageNet dataset from HF which simplifies some of
the launch process.

Fixes #27

Pull Request resolved: #30

Reviewed By: ebsmothers

Differential Revision: D35970492

Pulled By: ankitade

fbshipit-source-id: a33af4c009751e46459b86a2cfb623dcd6c3e70e
  • Loading branch information
apsdehal authored and facebook-github-bot committed May 9, 2022
1 parent 3b537be commit a69f32f
Show file tree
Hide file tree
Showing 14 changed files with 502 additions and 1,237 deletions.
8 changes: 5 additions & 3 deletions examples/flava/callbacks/multimodal_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from data import default_text_transform, VL_MAX_LENGTH_DEFAULT
from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
from data.imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
from pytorch_lightning import Callback, LightningDataModule
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
Expand Down Expand Up @@ -50,8 +50,10 @@ def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **k
classifier = _zero_shot_classifier(model, device, text_transform)
logger.info("Classifier built")
top1, top5, n = 0.0, 0.0, 0.0
for images, target in tqdm(dataloader):
images = images["image"].to(device)
for sample in tqdm(dataloader):
images = sample["image"]
target = sample["label"]
images = images.to(device)
target = target.to(device)

# predict
Expand Down
Loading

0 comments on commit a69f32f

Please sign in to comment.