PyTorch code and models for the DINOv2 self-supervised learning method.

🆕 [2023-10-26] Added DINOv2 backbones with registers, following Vision Transformers Need Registers.

DINOv2: Learning Robust Visual Features without Supervision

Meta AI Research, FAIR

Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy V. Vo, Marc Szafraniec, Vasil Khalidov, Patrick Labatut, Armand Joulin, Piotr Bojanowski

[Paper #1] [Paper #2] [Blog] [Demo] [BibTeX]

PyTorch implementation and pretrained models for DINOv2. For details, see the papers: DINOv2: Learning Robust Visual Features without Supervision and Vision Transformers Need Registers.

DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations.

[Video: video-reference+dinov2.mp4]
Visualization of the first three principal components of the patch features of all frames, mapped to RGB values.

Pretrained models

| model | # of params | with registers | ImageNet k-NN | ImageNet linear | download |
|---|---|---|---|---|---|
| ViT-S/14 distilled | 21 M | ❌ | 79.0% | 81.1% | backbone only |
| ViT-S/14 distilled | 21 M | ✅ | 79.1% | 80.9% | backbone only |
| ViT-B/14 distilled | 86 M | ❌ | 82.1% | 84.5% | backbone only |
| ViT-B/14 distilled | 86 M | ✅ | 82.0% | 84.6% | backbone only |
| ViT-L/14 distilled | 300 M | ❌ | 83.5% | 86.3% | backbone only |
| ViT-L/14 distilled | 300 M | ✅ | 83.8% | 86.7% | backbone only |
| ViT-g/14 | 1,100 M | ❌ | 83.5% | 86.5% | backbone only |
| ViT-g/14 | 1,100 M | ✅ | 83.7% | 87.1% | backbone only |

Pretrained backbones (via PyTorch Hub)

Please follow the instructions here to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.

A corresponding model card is included in the repository.

import torch

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
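
The "/14" in each model name is the ViT patch size, so an input image is split into 14×14-pixel patches. The sketch below is a small, hypothetical helper (`patch_grid` is not part of the repository) that rounds an image size down to the nearest multiple of 14 and reports how many patch tokens result. It assumes the backbone expects a height and width divisible by the patch size, which is worth confirming against the model code before relying on it.

```python
PATCH_SIZE = 14  # patch side length implied by the "/14" in the model names


def patch_grid(height: int, width: int, patch_size: int = PATCH_SIZE):
    """Return (cropped_h, cropped_w, n_patches) for an input of the given size.

    Hypothetical helper: rounds each side down to a multiple of the patch
    size and counts the resulting patch tokens (class and register tokens
    are not included in the count).
    """
    if height < patch_size or width < patch_size:
        raise ValueError("image is smaller than a single patch")
    rows, cols = height // patch_size, width // patch_size
    return rows * patch_size, cols * patch_size, rows * cols


print(patch_grid(224, 224))  # (224, 224, 256)
print(patch_grid(500, 375))  # (490, 364, 910)
```

A 224×224 input therefore yields a 16×16 grid of patch features per image.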

Pretrained heads - Image classification

| backbone | with registers | download (ImageNet) |
|---|---|---|
| ViT-S/14 distilled | ❌ | linear head (1 layer, 4 layers) |
| ViT-S/14 distilled | ✅ | linear head (1 layer, 4 layers) |
| ViT-B/14 distilled | ❌ | linear head (1 layer, 4 layers) |
| ViT-B/14 distilled | ✅ | linear head (1 layer, 4 layers) |
| ViT-L/14 distilled | ❌ | linear head (1 layer, 4 layers) |
| ViT-L/14 distilled | ✅ | linear head (1 layer, 4 layers) |
| ViT-g/14 | ❌ | linear head (1 layer, 4 layers) |
| ViT-g/14 | ✅ | linear head (1 layer, 4 layers) |

The (full) classifier models can be loaded via PyTorch Hub:

import torch

# DINOv2
dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')

# DINOv2 with registers
dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc')
dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc')
dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')
dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
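
The sixteen entry names above follow one pattern: `dinov2_vit<size>14`, an optional `_reg` suffix for registers, and an optional `_lc` suffix for the linear classifier. As a minimal sketch, the hypothetical helpers below (`hub_entry_name` and `load_dinov2` are not part of the repository) build the name from those three choices and pass it to `torch.hub.load`.

```python
def hub_entry_name(size: str, registers: bool = False, linear_classifier: bool = False) -> str:
    """Build a PyTorch Hub entry name such as 'dinov2_vitl14_reg_lc'."""
    if size not in {"s", "b", "l", "g"}:
        raise ValueError(f"unknown model size: {size!r}")
    name = f"dinov2_vit{size}14"
    if registers:
        name += "_reg"
    if linear_classifier:
        name += "_lc"
    return name


def load_dinov2(size: str, registers: bool = False, linear_classifier: bool = False):
    """Load a model through torch.hub (fetches code and weights on first use)."""
    import torch  # imported lazily so the name helper works without PyTorch

    return torch.hub.load(
        "facebookresearch/dinov2",
        hub_entry_name(size, registers, linear_classifier),
    )


print(hub_entry_name("s"))                                          # dinov2_vits14
print(hub_entry_name("l", registers=True, linear_classifier=True))  # dinov2_vitl14_reg_lc
```

Calling `load_dinov2("b", registers=True)` would then be equivalent to the `dinov2_vitb14_reg` line in the listing above.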

Pretrained heads - Depth estimation

| backbone | download head (NYUd) | download head (KITTI) |
|---|---|---|
| ViT-S/14 distilled | linear (1 layer, 4 layers), DPT | linear (1 layer, 4 layers), DPT |
| ViT-B/14 distilled | linear (1 layer, 4 layers), DPT | linear (1 layer, 4 layers), DPT |
| ViT-L/14 distilled | linear (1 layer, 4 layers), DPT | linear (1 layer, 4 layers), DPT |
| ViT-g/14 | linear (1 layer, 4 layers), DPT | linear (1 layer, 4 layers), DPT |

Pretrained heads - Semantic segmentation

| backbone | download model (ADE20K) | download head (ADE20K) | download head (VOC2012) |
|---|---|---|---|
| ViT-S/14 distilled | | linear, multi-scale | linear, multi-scale |
| ViT-B/14 distilled | | linear, multi-scale | linear, multi-scale |
| ViT-L/14 distilled | | linear, multi-scale | linear, multi-scale |
| ViT-g/14 | Mask2Former | linear, multi-scale | linear, multi-scale |

Installation

The training and evaluation code requires PyTorch 2.0 and xFormers 0.0.18 as well as a number of other third-party packages. Note that the code has only been tested with the specified versions and expects a Linux environment. To set up all the required dependencies for training and evaluation, please follow the instructions below:

conda (Recommended) - Clone the repository and then create and activate a dinov2 conda environment using the provided environment definition:

conda env create -f conda.yaml
conda activate dinov2

pip - Clone the repository and then use the provided requirements.txt to install the dependencies:

pip install -r requirements.txt

For dense tasks (depth estimation and semantic segmentation), there are additional dependencies (specific versions of mmcv and mmsegmentation) which are captured in the extras dependency specifications:

conda (Recommended):

conda env create -f conda-extras.yaml
conda activate dinov2-extras

pip:

pip install -r requirements.txt -r requirements-extras.txt
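
Because the code has only been tested with specific versions, it can help to check the active environment before launching a job. The following is a minimal sketch using only the standard library; `check_versions` and `matches` are hypothetical helpers, and `torch` and `xformers` are the PyPI distribution names for PyTorch and xFormers.

```python
from importlib import metadata

# Versions quoted in the Installation section, keyed by PyPI distribution name.
REQUIRED = {"torch": "2.0", "xformers": "0.0.18"}


def matches(installed: str, wanted: str) -> bool:
    """True if `installed` starts with the same dotted components as `wanted`."""
    have = installed.split("+")[0].split(".")  # drop local tags such as "+cu118"
    want = wanted.split(".")
    return have[: len(want)] == want


def check_versions(required=REQUIRED):
    """Return {package: message} for packages that are missing or mismatched."""
    problems = {}
    for package, wanted in required.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems[package] = "not installed"
            continue
        if not matches(installed, wanted):
            problems[package] = f"found {installed}, expected {wanted}"
    return problems


for package, message in check_versions().items():
    print(f"{package}: {message}")
```

An empty result means both packages are present at the expected versions; it says nothing about the other dependencies in requirements.txt.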

Data preparation

ImageNet-1k

The root directory of the dataset should hold the following contents:

  • <ROOT>/test/ILSVRC2012_test_00000001.JPEG
  • <ROOT>/test/[...]
  • <ROOT>/test/ILSVRC2012_test_00100000.JPEG
  • <ROOT>/train/n01440764/n01440764_10026.JPEG
  • <ROOT>/train/[...]
  • <ROOT>/train/n15075141/n15075141_9993.JPEG
  • <ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG
  • <ROOT>/val/[...]
  • <ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG
  • <ROOT>/labels.txt

The provided dataset implementation expects a few additional metadata files to be present under the extra directory:

  • <EXTRA>/class-ids-TRAIN.npy
  • <EXTRA>/class-ids-VAL.npy
  • <EXTRA>/class-names-TRAIN.npy
  • <EXTRA>/class-names-VAL.npy
  • <EXTRA>/entries-TEST.npy
  • <EXTRA>/entries-TRAIN.npy
  • <EXTRA>/entries-VAL.npy

These metadata files can be generated (once) with the following lines of Python code:

from dinov2.data.datasets import ImageNet

for split in ImageNet.Split:
    dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
    dataset.dump_extra()

Note that the root and extra directories do not have to be distinct directories.
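
Before training, it may be worth confirming that both directories match the layout listed above. The sketch below is a hypothetical standalone check (`missing_paths` and `expected_extra_files` are not part of the dinov2 package). It only tests that the top-level entries and the seven metadata files exist, not that their contents are valid.

```python
import tempfile
from pathlib import Path

# Top-level entries expected under <ROOT>, per the list above.
ROOT_ENTRIES = ("test", "train", "val", "labels.txt")


def expected_extra_files():
    """Names of the seven metadata files listed for the extra directory."""
    names = [f"entries-{split}.npy" for split in ("TEST", "TRAIN", "VAL")]
    for split in ("TRAIN", "VAL"):  # no class metadata is listed for TEST
        names += [f"class-ids-{split}.npy", f"class-names-{split}.npy"]
    return sorted(names)


def missing_paths(root, extra):
    """Return the expected paths that do not exist under `root` and `extra`."""
    wanted = [Path(root) / name for name in ROOT_ENTRIES]
    wanted += [Path(extra) / name for name in expected_extra_files()]
    return [str(path) for path in wanted if not path.exists()]


# Demonstration on a throwaway directory that holds only part of the layout.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "train").mkdir()
    (Path(tmp) / "entries-TRAIN.npy").touch()
    missing = missing_paths(tmp, tmp)
    print(len(missing), "paths missing")  # 9 paths missing
```

Passing the same path for both arguments, as in the demonstration, corresponds to using a single directory for root and extra.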

ImageNet-22k

Please adapt the dataset class to match your local setup.


⚠️ To execute the commands provided in the next sections for training and evaluation, the dinov2 package should be included in the Python module search path, i.e. simply prefix the command to run with PYTHONPATH=. (a single dot, meaning the current directory).
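
When the commands are launched from a Python script rather than a shell, the same effect can be obtained by setting the variable in the child process environment. This is a minimal sketch, assuming the script is started from the repository root; `run_from_repo` is a hypothetical helper, not part of the repository.

```python
import os
import subprocess
import sys


def run_from_repo(args, repo_root="."):
    """Run `python <args>` with `repo_root` prepended to PYTHONPATH."""
    env = dict(os.environ)
    existing = env.get("PYTHONPATH")
    root = os.path.abspath(repo_root)
    env["PYTHONPATH"] = root if not existing else root + os.pathsep + existing
    return subprocess.run(
        [sys.executable, *args],
        cwd=root,
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )


# Demonstration: the child interpreter reports the PYTHONPATH it received.
result = run_from_repo(["-c", "import os; print(os.environ['PYTHONPATH'])"])
print(result.stdout.strip())
```

Replacing the `-c` demonstration with, for example, `["dinov2/run/eval/knn.py", ...]` would run that script with the package importable.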

Training

Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k

Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:

python dinov2/run/train/train.py \
    --nodes 4 \
    --config-file dinov2/configs/train/vitl16_short.yaml \
    --output-dir <PATH/TO/OUTPUT/DIR> \
    train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>

Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval.

The training code saves the weights of the teacher in the eval folder every 12500 iterations for evaluation.
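
The `train.dataset_path` value in the command above is a colon-separated string: a dataset name followed by `key=value` options. The sketch below shows one way to build and split such strings. The format is inferred from the commands in this README only, `dataset_spec` and `parse_dataset_spec` are hypothetical helpers, and the repository's own parser may behave differently (for instance, paths containing a colon would not survive this simple split).

```python
def dataset_spec(name: str, **options: str) -> str:
    """Build a dataset string such as 'ImageNet:split=TRAIN:root=/d:extra=/d'."""
    parts = [name] + [f"{key}={value}" for key, value in options.items()]
    return ":".join(parts)


def parse_dataset_spec(spec: str):
    """Split a dataset string back into (name, {key: value})."""
    name, *pairs = spec.split(":")
    options = dict(pair.split("=", 1) for pair in pairs)
    return name, options


train = dataset_spec("ImageNet", split="TRAIN", root="/data/in1k", extra="/data/in1k")
print(train)  # ImageNet:split=TRAIN:root=/data/in1k:extra=/data/in1k
print(parse_dataset_spec(train))
```

The `/data/in1k` path is a placeholder for the dataset location.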

Long setup: training DINOv2 ViT-L/14 on ImageNet-22k

Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit:

python dinov2/run/train/train.py \
    --nodes 12 \
    --config-file dinov2/configs/train/vitl14.yaml \
    --output-dir <PATH/TO/OUTPUT/DIR> \
    train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>

Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval.

The training code saves the weights of the teacher in the eval folder every 12500 iterations for evaluation.
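
For budgeting, the two setups can be compared in GPU-hours. The arithmetic below uses only the figures quoted in this section (node counts, GPU counts, and approximate training times); the 8 GPUs per node follows from 32 GPUs on 4 nodes and 96 GPUs on 12 nodes. The results are rough estimates, since the stated training times are themselves approximate.

```python
GPUS_PER_NODE = 8  # implied by "4 nodes (32 GPUs)" and "12 nodes (96 GPUs)"


def gpu_hours(nodes: int, days: float, gpus_per_node: int = GPUS_PER_NODE) -> float:
    """Approximate GPU-hours for a run: GPUs x days x 24."""
    return nodes * gpus_per_node * days * 24


runs = {
    "ViT-L/16, ImageNet-1k": (4, 1.0),
    "ViT-L/14, ImageNet-22k": (12, 3.3),
}
for name, (nodes, days) in runs.items():
    total_gpus = nodes * GPUS_PER_NODE
    print(f"{name}: {total_gpus} GPUs, about {gpu_hours(nodes, days):.0f} GPU-hours")
```

This gives about 768 GPU-hours for the fast setup and about 7,603 for the long one, roughly a tenfold difference.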

Evaluation

The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:

k-NN classification on ImageNet-1k

python dinov2/run/eval/knn.py \
    --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
    --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
    --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/knn \
    --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
    --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
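
For readers unfamiliar with the protocol, k-NN evaluation classifies each validation feature by looking at the labels of its nearest training features. The toy sketch below illustrates the idea with cosine similarity and a plain majority vote on hand-written 2-D vectors. It is not the repository's implementation, which operates on backbone features for the full dataset and may weight votes differently.

```python
import math
from collections import Counter


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def knn_predict(query, train_features, train_labels, k=3):
    """Majority vote among the k training features most similar to `query`."""
    ranked = sorted(
        zip(train_features, train_labels),
        key=lambda pair: cosine(query, pair[0]),
        reverse=True,
    )
    votes = Counter(label for _, label in ranked[:k])
    return votes.most_common(1)[0][0]


# Toy "features": two clusters pointing along different axes.
train_features = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
train_labels = ["cat", "cat", "dog", "dog"]
print(knn_predict([0.8, 0.3], train_features, train_labels))  # cat
```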

Logistic regression classification on ImageNet-1k

python dinov2/run/eval/log_regression.py \
    --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
    --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
    --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/logreg \
    --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
    --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>

Linear classification with data augmentation on ImageNet-1k

python dinov2/run/eval/linear.py \
    --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
    --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
    --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/linear \
    --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
    --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
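
The three commands above differ only in the script name and the output sub-folder, so they can be generated from one template. The sketch below is a hypothetical convenience (`checkpoint_dir` and `eval_command` are not part of the repository). It assumes checkpoint folders are named as in the `training_24999` example, i.e. the second save at a period of 12500 iterations, and that root and extra share one dataset path.

```python
SAVE_PERIOD = 12500  # teacher weights are saved every 12500 iterations


def checkpoint_dir(output_dir: str, n_saves: int) -> str:
    """Folder of the n-th saved teacher checkpoint, e.g. .../eval/training_24999."""
    # Assumption: names follow the training_24999 example (n * period - 1).
    return f"{output_dir}/eval/training_{n_saves * SAVE_PERIOD - 1}"


def eval_command(script, output_dir, n_saves, dataset_root, result_name):
    """Argument list for one of the evaluation scripts shown above."""
    ckpt = checkpoint_dir(output_dir, n_saves)
    dataset = f"root={dataset_root}:extra={dataset_root}"
    return [
        "python", f"dinov2/run/eval/{script}.py",
        "--config-file", f"{output_dir}/config.yaml",
        "--pretrained-weights", f"{ckpt}/teacher_checkpoint.pth",
        "--output-dir", f"{ckpt}/{result_name}",
        "--train-dataset", f"ImageNet:split=TRAIN:{dataset}",
        "--val-dataset", f"ImageNet:split=VAL:{dataset}",
    ]


# (script name, output sub-folder) pairs taken from the three commands above.
EVALS = [("knn", "knn"), ("log_regression", "logreg"), ("linear", "linear")]
for script, result_name in EVALS:
    print(" ".join(eval_command(script, "out", 2, "/data/in1k", result_name)))
```

The `out` and `/data/in1k` values are placeholders for the training output directory and the dataset location.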

We release the weights from evaluating the different models:

| model | with registers | ImageNet top-1 | linear evaluation |
|---|---|---|---|
| ViT-S/14 distilled | ❌ | 81.1% | linear head weights |
| ViT-S/14 distilled | ✅ | 80.8% | linear head weights |
| ViT-B/14 distilled | ❌ | 84.5% | linear head weights |
| ViT-B/14 distilled | ✅ | 84.4% | linear head weights |
| ViT-L/14 distilled | ❌ | 86.3% | linear head weights |
| ViT-L/14 distilled | ✅ | 86.5% | linear head weights |
| ViT-g/14 | ❌ | 86.5% | linear head weights |
| ViT-g/14 | ✅ | 87.0% | linear head weights |

The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:

python dinov2/run/eval/linear.py \
    --config-file dinov2/configs/eval/vitg14_pretrain.yaml \
    --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \
    --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
    --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>

Notebooks

A few notebooks are provided to help the community leverage the models and code:

  • Depth estimation - How to load and use the depth heads in combination with a matching backbone via mmcv
  • Semantic segmentation - How to load and use the segmentation heads in combination with a matching backbone via mmcv, and also how to load and use the Mask2Former-based segmentation model trained on ADE20K

License

DINOv2 code and model weights are released under the Apache License 2.0. See LICENSE for additional details.

Contributing

See contributing and the code of conduct.

Citing DINOv2

If you find this repository useful, please consider giving a star ⭐ and citation 🦖:

@misc{oquab2023dinov2,
  title={DINOv2: Learning Robust Visual Features without Supervision},
  author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal={arXiv:2304.07193},
  year={2023}
}
@misc{darcet2023vitneedreg,
  title={Vision Transformers Need Registers},
  author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
  journal={arXiv:2309.16588},
  year={2023}
}
