
eminorhan/silicon-menagerie


A menagerie of models trained on SAYCam and other reference datasets

This is a stand-alone repository to facilitate the use of all models I have trained on SAYCam (and more!). The models are all hosted on Hugging Face. Please see our preprint for a detailed description of the models and their capabilities:

Orhan AE, Lake BM (2023) Learning high-level visual representations from a child's perspective without strong inductive biases. arXiv:2305.15372.

What you need:

  • A reasonably recent version of PyTorch and torchvision. The code was tested with pytorch==1.11.0 and torchvision==0.12.0.
  • The huggingface_hub library to download the models from the Hugging Face Hub. The code was tested with huggingface-hub==0.13.4.
  • For the attention visualizations, you will need the Pillow (PIL) library. The code was tested with pillow==9.4.0.
  • For the generative models (VQGAN-GPT models), you will need the PyTorch Lightning library. The code was tested with pytorch-lightning==2.0.1. This is a dependency of the Taming Transformers library.
  • You do not need a GPU to load and use these models (although, of course, things will run faster on a GPU).
  • If you're only doing inference and you're not feeding the model very large batches of input, you should be able to easily fit even the largest models here on a single V100 GPU with 32GB memory.
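To compare your environment against the versions listed above, a small stdlib-only check like the following may help. It is a sketch, not part of this repository: `report_versions` and `TESTED_VERSIONS` are hypothetical names, and the lookup uses PyPI distribution names (e.g. `torch` rather than `pytorch`), which assumes the packages were installed with pip or a compatible installer.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions this README reports as tested (newer versions will likely also work).
TESTED_VERSIONS = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "huggingface-hub": "0.13.4",
    "pillow": "9.4.0",
    "pytorch-lightning": "2.0.1",
}


def report_versions(tested=TESTED_VERSIONS):
    """Return {package: (installed_version_or_None, tested_version)}."""
    report = {}
    for pkg, tested_ver in tested.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None  # not installed in this environment
        report[pkg] = (installed, tested_ver)
    return report


if __name__ == "__main__":
    for pkg, (installed, tested_ver) in report_versions().items():
        print(f"{pkg:20s} installed={installed or 'missing':12s} tested={tested_ver}")
```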

Image embedding models

Loading the models

Model names are specified in the format 'x_y_z', where x is the self-supervised learning (SSL) algorithm used to train the model, y is the data used for training the model, and z is the model architecture:

  • x can be one of dino, mugs, or mae
  • y can be one of say, s, a, y, imagenet100, imagenet10, imagenet1, kinetics-200h, ego4d-200h
  • z can be one of resnext50, vitb14, vitl16, vitb16, vits16

Here, say, s, a, y represent the full SAYCam dataset and the individual babies S, A, and Y, respectively; imagenet100, imagenet10, imagenet1 are the full ImageNet training set and its subsets (10% and 1% of the full training set, respectively); kinetics-200h and ego4d-200h are 200-hour subsets of the Kinetics-700 and Ego4D video datasets. Please note that not all possible combinations of x_y_z are available; you can see a list of all available models by running:

>>> import utils
>>> print(utils.get_available_models())

You will get an error if you try to load an unavailable model.
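The naming scheme can be made concrete with a small parser. This is a hypothetical helper, not a function in utils: it only checks that each part of 'x_y_z' is one of the values listed above, so a name that parses may still be unavailable (get_available_models() remains the authoritative list). Splitting on '_' works because the dataset names use hyphens, not underscores.

```python
from typing import Tuple

# Allowed values, copied from the lists above.
SSL_ALGOS = {"dino", "mugs", "mae"}
DATASETS = {"say", "s", "a", "y", "imagenet100", "imagenet10", "imagenet1",
            "kinetics-200h", "ego4d-200h"}
ARCHS = {"resnext50", "vitb14", "vitl16", "vitb16", "vits16"}


def parse_model_name(name: str) -> Tuple[str, str, str]:
    """Split an 'x_y_z' model name into (algorithm, data, architecture)."""
    parts = name.split("_")
    if len(parts) != 3:
        raise ValueError(f"expected 'x_y_z', got {name!r}")
    algo, data, arch = parts
    if algo not in SSL_ALGOS:
        raise ValueError(f"unknown SSL algorithm {algo!r}")
    if data not in DATASETS:
        raise ValueError(f"unknown training data {data!r}")
    if arch not in ARCHS:
        raise ValueError(f"unknown architecture {arch!r}")
    return algo, data, arch


print(parse_model_name("dino_say_vitb14"))  # ('dino', 'say', 'vitb14')
```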

Loading a pretrained model is then as easy as:

from utils import load_model

model = load_model('dino_say_vitb14')

This will download the corresponding pretrained checkpoint (i.e. a ViT-B/14 model trained with DINO on the full SAYCam dataset), cache it locally, build the right model architecture, and load the pretrained weights onto the model, all in one go! When you load a model, you will get a warning message that says something like _IncompatibleKeys(missing_keys=[], unexpected_keys=...). This is expected: the checkpoint also contains the projection head used during DINO or Mugs pretraining (or the decoder used during MAE pretraining), which we deliberately do not load. We're only interested in the encoder backbone.
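To see where that message comes from: PyTorch's `load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` (parameters the model expects but the checkpoint lacks) and `unexpected_keys` (checkpoint entries the model has no parameter for). The pure-Python sketch below reproduces that bookkeeping with set differences; `compare_keys` is a hypothetical helper and the parameter names are made up for illustration, not taken from the actual checkpoints.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Mimic the two lists reported by load_state_dict(strict=False).

    missing_keys:    parameters the model expects but the checkpoint lacks
    unexpected_keys: checkpoint entries the model has no parameter for
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names: an encoder-only model vs. a checkpoint that also
# stores a projection head used only during pretraining.
encoder_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight", "norm.weight"]
checkpoint_keys = encoder_keys + ["head.mlp.0.weight", "head.last_layer.weight"]

missing, unexpected = compare_keys(encoder_keys, checkpoint_keys)
print(missing)     # []
print(unexpected)  # ['head.last_layer.weight', 'head.mlp.0.weight']
```

An empty `missing` list is the part that matters: every encoder parameter was found in the checkpoint.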

Training and evaluation

Models were trained and evaluated using separate repositories. DINO models were trained and evaluated with code from this repository, Mugs models were trained and evaluated with code from this repository, and the MAE models were trained and evaluated with code from this repository. Training logs for all models can be found inside the logs folder here.

Attaching a linear classifier head

Please see the example code here to find out how to attach a linear classifier head to a pretrained model (either a resnext or a ViT model), and train or finetune it on a downstream classification task.

Visualizing the attention heads

I also include here some bare bones functionality to visualize the attention heads of the transformer models given an image. All you need to do is something along the lines of:

import torch
from utils import load_model, preprocess_image, visualize_attentions

model = load_model('dino_say_vitb14')

img = preprocess_image(img_path="imgs/img_0.jpg", img_size=1400)
with torch.no_grad():
    visualize_attentions(model, img, patch_size=14)

The file test_emb_model.py contains a more fleshed-out usage example. This will produce images like the following (with the original image at the top left, followed by the attention map for each head).

[Figure: attention maps for dino_say_vitb14]

[Figure: attention maps for dino_imagenet100_vitb14]

You can find more examples in the assets/atts folder.
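With img_size=1400 and patch_size=14, the ViT sees a 100 x 100 grid of patches, so each per-head attention map has 100 x 100 = 10,000 cells (the model also carries the usual [CLS] token, which is assumed here not to be drawn as a cell). The arithmetic is sketched below with a hypothetical helper, `vit_token_grid`; it assumes square inputs and, as a design choice of this sketch, rejects image sizes that are not multiples of the patch size.

```python
def vit_token_grid(img_size: int, patch_size: int):
    """Return (patches_per_side, num_patch_tokens) for a square input image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size should be a multiple of patch_size")
    side = img_size // patch_size
    return side, side * side


side, n_tokens = vit_token_grid(img_size=1400, patch_size=14)
print(side, n_tokens)  # 100 10000
```

Larger input sizes give finer attention maps at the cost of quadratically more patch tokens.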

Additional assets

assets/cams contains additional examples of class activation maps (CAMs) for the dino_s_resnext50 model (similar to Figure 3a in the paper). assets/tsnes contains additional examples of t-SNE maps for three more datasets and multiple different models (similar to Figure 4 in the paper).
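This README does not describe how the CAMs were computed. As an assumption, the sketch below shows the standard class activation mapping recipe for a CNN such as ResNeXt-50 that ends in global average pooling followed by a linear classifier: the map for a class is the sum of the last convolutional feature maps, each weighted by that class's classifier weight. `class_activation_map` is a hypothetical name, and nested lists stand in for tensors.

```python
def class_activation_map(feature_maps, class_weights):
    """CAM(i, j) = sum_k w_k * f_k(i, j).

    feature_maps:  K grids (each H x W) from the last convolutional layer
    class_weights: K linear-classifier weights for the target class
    """
    if len(feature_maps) != len(class_weights):
        raise ValueError("need one weight per feature map")
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for fmap, w in zip(feature_maps, class_weights):
        for i in range(height):
            for j in range(width):
                cam[i][j] += w * fmap[i][j]
    return cam


# Two toy 2 x 2 feature maps and the class weights that combine them.
maps = [[[1, 0], [0, 0]],
        [[0, 0], [0, 1]]]
print(class_activation_map(maps, [2.0, -1.0]))  # [[2.0, 0.0], [0.0, -1.0]]
```

In practice the low-resolution map is usually upsampled to the input size and normalized before being overlaid on the image.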

Generative image models (VQGAN-GPT)

These models can generate images. For these models, we first learn a discrete codebook with a VQGAN model and then encode each video frame as a 32x32 grid of integer indices into this codebook. These discretized, spatially downsampled frames are then fed into a GPT model, which learns a prior over the frames.
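The encoding step can be sketched as a nearest-neighbour lookup: every latent vector produced by the VQGAN encoder is replaced by the index of its closest codebook entry, and the resulting grid is flattened into a sequence for the GPT (a 32x32 grid gives 32 x 32 = 1024 tokens per frame). This is a toy, pure-Python illustration: the codebook and latents are made-up values, the helper names are hypothetical, and the row-major (raster) flattening order is an assumption rather than something this README specifies.

```python
def nearest_code(vector, codebook):
    """Return the index of the codebook entry closest to `vector` (squared L2)."""
    return min(
        range(len(codebook)),
        key=lambda idx: sum((v - c) ** 2 for v, c in zip(vector, codebook[idx])),
    )


def quantize_and_flatten(latent_grid, codebook):
    """Map an H x W grid of latent vectors to a row-major list of code indices."""
    return [nearest_code(vec, codebook) for row in latent_grid for vec in row]


# Toy 3-entry codebook of 2-D vectors and a 2 x 2 grid of encoder outputs.
codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
latents = [[(0.1, 0.0), (0.9, 0.2)],
           [(0.2, 0.8), (0.0, 0.1)]]
print(quantize_and_flatten(latents, codebook))  # [0, 1, 2, 0]
```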

Training

The VQGAN codebooks were trained with code from this repository, which is my personal copy of the Taming Transformers repository, using the config files here. The GPT models were trained with code from this repository.

Loading the models

Loading a pretrained VQGAN-GPT model is extremely easy:

from gpt_utils import load_model

gpt_model, vq_model = load_model('say_gimel')

Here, 'say_gimel' is the model identifier, vq_model is the VQ codebook part of the model that is used to encode images into latents and decode latents back into images, and gpt_model is the pretrained GPT model. Model identifiers are specified in the format 'x_y', where x is the data used for training the model and y is the name of the GPT model configuration:

  • x can be one of say, s, a, y, imagenet100, imagenet10, imagenet1, konkleiid, konklenonvehicle
  • y can be one of alef, bet, gimel, dalet

Here, alef, bet, gimel, dalet refer to model configurations with different sizes, dalet being the largest model (1.5B params) and alef being the smallest one (110M params). You can find the detailed model specifications in this file. konkleiid and konklenonvehicle are subsets of the Konkle objects dataset (the first is an iid half split of the dataset, the second one includes all nonvehicle classes in the dataset).

In addition, you can also load models finetuned on these two Konkle datasets by concatenating their names to the x string with a + sign: e.g. 'say+konkleiid_gimel' is a gimel sized model (~730M parameters) pretrained on say (all of SAYCam) and then finetuned on konkleiid.
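The identifier grammar, including the optional '+' finetuning suffix, can be sketched as a parser. `parse_gpt_model_id` is hypothetical and not part of gpt_utils; it only checks the naming rules described above, so gpt_utils.get_available_models() is still the authority on which models actually exist.

```python
GPT_DATASETS = {"say", "s", "a", "y", "imagenet100", "imagenet10", "imagenet1",
                "konkleiid", "konklenonvehicle"}
FINETUNE_DATASETS = {"konkleiid", "konklenonvehicle"}
GPT_CONFIGS = {"alef", "bet", "gimel", "dalet"}


def parse_gpt_model_id(model_id):
    """Split 'x_y' or 'x+finetune_y' into (pretrain_data, finetune_data or None, config)."""
    data, sep, config = model_id.rpartition("_")
    if not sep or config not in GPT_CONFIGS:
        raise ValueError(f"missing or unknown model config in {model_id!r}")
    pretrain, plus, finetune = data.partition("+")
    if pretrain not in GPT_DATASETS:
        raise ValueError(f"unknown training data {pretrain!r}")
    if plus and finetune not in FINETUNE_DATASETS:
        raise ValueError(f"unknown finetuning data {finetune!r}")
    return pretrain, (finetune if plus else None), config


print(parse_gpt_model_id("say+konkleiid_gimel"))  # ('say', 'konkleiid', 'gimel')
```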

Please note that not all possible x_y combinations are available; you can see a list of all available models by running:

>>> import gpt_utils
>>> print(gpt_utils.get_available_models())

You will get an error if you try to load an unavailable model.

Generating images from the pretrained models

I also provide two utility functions in gpt_utils to generate images from the pretrained VQGAN-GPT models: generate_images_freely generates unconditional samples from the model and generate_images_from_half generates conditional samples conditioned on the upper halves of a set of images. You can use these functions as follows:

from gpt_utils import load_model, generate_images_freely, generate_images_from_half

# load model
gpt_model, vq_model = load_model('say_gimel')

# generate unconditional samples from the loaded model
x = generate_images_freely(gpt_model, vq_model, n_samples=36)

# generate conditional samples from the model
img_dir = 'path/to/conditioning/images'  # placeholder: point this at your own image directory
y = generate_images_from_half(gpt_model, vq_model, img_dir, n_imgs=6, n_samples_per_img=6)

where img_dir is the directory containing the conditioning images. We randomly sample n_imgs images from this directory to condition on and then generate n_samples_per_img conditional samples per image. The file test_gpt_model.py contains a more fleshed-out usage example. You can generate images like the following with these functions:

[Figure: unconditional samples]

[Figure: conditional samples]

In the conditional examples, the upper half of each image is given as context and the lower half is generated by the model. The top row shows the ground-truth images, and the remaining rows show samples whose bottom halves were generated by the model. You can find more examples in the assets/gpt_samples folder.
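In token space, conditioning on the upper half amounts to keeping the first half of each frame's code sequence as a prompt and letting the GPT sample the rest, one token at a time. The sketch below illustrates that loop under stated assumptions: it is not taken from generate_images_from_half, the tokens are assumed to be in row-major order on a 32x32 grid, and `toy_sampler` (a uniform random draw from a hypothetical 1024-entry codebook) stands in for the trained GPT. Decoding the completed sequence back into pixels with vq_model is omitted.

```python
import random

GRID = 32              # each frame is a GRID x GRID map of codebook indices
CODEBOOK_SIZE = 1024   # hypothetical codebook size, for illustration only


def upper_half(tokens, grid=GRID):
    """Return the row-major tokens that make up the upper half of a frame."""
    if len(tokens) != grid * grid:
        raise ValueError(f"expected {grid * grid} tokens, got {len(tokens)}")
    return tokens[: (grid // 2) * grid]


def complete_from_half(tokens, sample_next, grid=GRID):
    """Keep the upper half as context and autoregressively sample the lower half.

    `sample_next(prefix)` stands in for the GPT: given every token so far,
    it returns the next code index.
    """
    seq = list(upper_half(tokens, grid))
    while len(seq) < grid * grid:
        seq.append(sample_next(seq))
    return seq


rng = random.Random(0)


def toy_sampler(prefix):
    # Placeholder for a trained GPT: ignores the prefix and draws a random code.
    return rng.randrange(CODEBOOK_SIZE)


# One conditioning frame (random codes here), completed 6 times.
frame = [rng.randrange(CODEBOOK_SIZE) for _ in range(GRID * GRID)]
samples = [complete_from_half(frame, toy_sampler) for _ in range(6)]
```

Every sample shares the same first 512 tokens (the context); only the last 512 differ.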
