-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First release: Quadtree implementations and image tokenizers
- Loading branch information
1 parent
60ccbb4
commit 93d9f6c
Showing
42 changed files
with
2,042 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
cache_path.txt | ||
**__pycache__ | ||
examples/playground.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,57 @@ | ||
# Vision Transformers with Mixed-Resolution Tokenization | ||
Official repo. Code and models coming soon! | ||
Official repo for https://arxiv.org/abs/2304.00287 (published in CVPRW 2023). | ||
|
||
 | ||
|
||
**Current release:** Quadtree implementations and image tokenizers. | ||
|
||
**To be released:** inference code and trained models. | ||
|
||
Let me know if you're also interested in the Grad-CAM oracle code, or training code for reproducing the experiments in the paper (based on the timm library). | ||
|
||
</br> | ||
|
||
## Setup | ||
Install torch and torchvision, e.g. by following the [official instructions](https://pytorch.org/get-started/locally/). | ||
|
||
</br> | ||
|
||
## Examples | ||
See notebooks under `examples/`: | ||
* **01_vis_quadtree.ipynb** \ | ||
Quadtree tokenization examples with different patch scorers. | ||
|
||
* **02_tokenization.ipynb** \ | ||
Usage examples for the Quadtree image tokenizer and the vanilla ViT tokenizer. \ | ||
The tokenizers prepare input images to be used as input for a standard Transformer model: \ | ||
they pass the patch pixels through an encoding layer, add sinusoidal position embeddings \ | ||
based on patch locations, and prepend the cls_token. | ||
|
||
* **03_compare_implementations.ipynb** \ | ||
We provide 3 different Quadtree implementations. \ | ||
This notebook shows that they produce identical results, and compares runtime. | ||
|
||
## Patch scorers | ||
We provide implementations for several patch scorers: | ||
|
||
* **FeatureBasedPatchScorer:** the main scoring method presented in our paper, which uses a tiny feature extraction network to create patch representations and estimate semantic information loss. | ||
|
||
* **PixelBlurPatchScorer:** our baseline patch scorer which estimates pixel information loss. Often used in Quadtrees for image compression. | ||
|
||
* **PrecomputedPatchScorer:** useful for running Quadtrees with precomputed patch scores. We use it to visualize the tokenizations induced by our Grad-CAM oracle patch scorer, used for analysis in the paper. | ||
|
||
* **RandomPatchScorer:** fast and useful for sanity checks. Supports seeding. | ||
|
||
</br> | ||
|
||
## Quadtree implementations | ||
We provide 3 different GPU-friendly implementations of the Saliency-Based Quadtree algorithm. They produce identical results. They share the same batchified code for image patchifying and patch scoring, and differ in their implementation of the patch-splitting logic. | ||
|
||
1. **Dict-Lookup Quadtree** \ | ||
When splitting a patch, its children are retrieved from a dictionary by their coordinates. This is perhaps the easiest implementation to read (see `mixed_res.quadtree_impl.quadtree_dict_lookup.Quadtree.run`), but it's also the slowest one as the actual splitting logic isn't batchified. | ||
|
||
2. **Tensor-Lookup Quadtree** \ | ||
Similar to the dict-lookup Quadtree, but the box indices are kept in a tensor lookup table, where tensor indices correspond to patch location and scale. Much faster than the dict-lookup Quadtree. | ||
|
||
3. **Z-Curve Quadtree** \ | ||
Uses an indexing scheme based on z-order curves, where the children of a patch can be found via a simple arithmetic operation. This is our fastest implementation. Unfortunately, it only works for image sizes that are a power of 2, e.g. 256 pixels. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Image Tokenization\n", | ||
"Usage examples for the Quadtree image tokenizer and the vanilla ViT tokenizer.\n", | ||
"\n", | ||
"The tokenizers prepare input images to be used as input for a standard Transformer model: \\\n", | ||
"they pass the patch pixels through an encoding layer, add sinusoidal position embeddings \\\n", | ||
"based on patch locations, and prepend the cls_token." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"import setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from torch import nn\n", | ||
"\n", | ||
"from mixed_res.patch_scorers.random_patch_scorer import RandomPatchScorer\n", | ||
"from mixed_res.quadtree_impl.quadtree_z_curve import ZCurveQuadtreeRunner\n", | ||
"from mixed_res.tokenization.patch_embed import FlatPatchEmbed, PatchEmbed\n", | ||
"from mixed_res.tokenization.tokenizers import (QuadtreeTokenizer,\n", | ||
" VanillaTokenizer)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"device = \"cuda\"\n", | ||
"image_size = 256\n", | ||
"channels = 3\n", | ||
"min_patch_size = 16\n", | ||
"max_patch_size = 64\n", | ||
"quadtree_num_patches = 100\n", | ||
"batch_size = 5\n", | ||
"embed_dim = 384\n", | ||
"\n", | ||
"images = torch.randn(batch_size, channels, image_size, image_size, device=device)\n", | ||
"cls_token = nn.Parameter(torch.randn(embed_dim)).to(device)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Tokenize images with a Quadtree tokenizer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([5, 101, 384])" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# These will probably be initialized inside your ViT's __init__ method\n", | ||
"patch_embed = FlatPatchEmbed(img_size=image_size, patch_size=min_patch_size, embed_dim=embed_dim).to(device)\n", | ||
"quadtree_runner = ZCurveQuadtreeRunner(quadtree_num_patches, min_patch_size, max_patch_size)\n", | ||
"patch_scorer = RandomPatchScorer()\n", | ||
"quadtree_tokenizer = QuadtreeTokenizer(patch_embed, cls_token, quadtree_runner, patch_scorer)\n", | ||
"\n", | ||
"# put this in your forward method\n", | ||
"token_embeds = quadtree_tokenizer.tokenize(images)\n", | ||
"token_embeds.shape # [batch_size, 1 + num_patches, embed_dim]" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Tokenize images with a vanilla ViT tokenizer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([5, 257, 384])" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# These will probably be initialized inside your ViT's __init__ method\n", | ||
"patch_embed = PatchEmbed(img_size=image_size, patch_size=min_patch_size, embed_dim=embed_dim).to(device)\n", | ||
"vanilla_tokenizer = VanillaTokenizer(patch_embed, cls_token)\n", | ||
"\n", | ||
"# put this in your forward method\n", | ||
"token_embeds = vanilla_tokenizer.tokenize(images)\n", | ||
"token_embeds.shape # [batch_size, 1 + (image_size / patch_size)**2, embed_dim]" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "mlskel", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.16" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Ensure that all Quadtree implementations are equivalent\n", | ||
"We provide 3 different Quadtree implementations.\n", | ||
"\n", | ||
"This notebook shows that they produce identical results, and compares runtime." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"import setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from mixed_res.quadtree_impl.quadtree_dict_lookup import DictLookupQuadtreeRunner\n", | ||
"from mixed_res.quadtree_impl.quadtree_tensor_lookup import TensorLookupQuadtreeRunner\n", | ||
"from mixed_res.quadtree_impl.quadtree_z_curve import ZCurveQuadtreeRunner\n", | ||
"from mixed_res.patch_scorers.random_patch_scorer import RandomPatchScorer\n", | ||
"from mixed_res.quadtree_impl.utils import sort_by_meta, is_power_of_2\n", | ||
"\n", | ||
"device = \"cuda\"\n", | ||
"batch_size = 128\n", | ||
"image_size = 256\n", | ||
"num_patches = 100\n", | ||
"min_patch_size = 16\n", | ||
"max_patch_size = 64\n", | ||
"\n", | ||
"images = torch.randn(batch_size, 3, image_size, image_size, device=device)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Assert equivalence" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"dict_lookup and tensor_lookup are equivalent\n", | ||
"dict_lookup and z_curve are equivalent\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"patch_scorer = RandomPatchScorer(seed=1337)\n", | ||
"\n", | ||
"# Init Quadtree runners from different implementations\n", | ||
"runner_dict_lookup = DictLookupQuadtreeRunner(num_patches, min_patch_size, max_patch_size)\n", | ||
"runner_tensor_lookup = TensorLookupQuadtreeRunner(num_patches, min_patch_size, max_patch_size)\n", | ||
"if is_power_of_2(image_size):\n", | ||
" runner_z_curve = ZCurveQuadtreeRunner(num_patches, min_patch_size, max_patch_size)\n", | ||
"\n", | ||
"# Run Quadtrees\n", | ||
"res_dict_lookup = runner_dict_lookup.run_batch_quadtree(images, patch_scorer)\n", | ||
"res_tensor_lookup = runner_tensor_lookup.run_batch_quadtree(images, patch_scorer)\n", | ||
"if is_power_of_2(image_size):\n", | ||
" res_z_curve = runner_z_curve.run_batch_quadtree(images, patch_scorer)\n", | ||
"\n", | ||
"# Sort results by metadata (patch location and scale) to make them comparable\n", | ||
"res_dict_lookup = sort_by_meta(res_dict_lookup)\n", | ||
"res_tensor_lookup = sort_by_meta(res_tensor_lookup)\n", | ||
"if is_power_of_2(image_size):\n", | ||
" res_z_curve = sort_by_meta(res_z_curve)\n", | ||
"\n", | ||
"# Assert that results are equivalent\n", | ||
"assert torch.allclose(res_dict_lookup, res_tensor_lookup)\n", | ||
"print(\"dict_lookup and tensor_lookup are equivalent\")\n", | ||
"if is_power_of_2(image_size):\n", | ||
" assert torch.allclose(res_dict_lookup, res_z_curve)\n", | ||
" print(\"dict_lookup and z_curve are equivalent\")" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Compare runtimes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"7.35 ms ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | ||
"21.2 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | ||
"42.1 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%timeit runner_z_curve.run_batch_quadtree(images, patch_scorer)\n", | ||
"%timeit runner_tensor_lookup.run_batch_quadtree(images, patch_scorer)\n", | ||
"%timeit runner_dict_lookup.run_batch_quadtree(images, patch_scorer)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "mixed_res", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.16" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n01514859/ILSVRC2012_val_00033538.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n02101556/ILSVRC2012_val_00030747.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n02102973/ILSVRC2012_val_00030888.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n02104365/ILSVRC2012_val_00043873.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n02112137/ILSVRC2012_val_00042887.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n02119022/ILSVRC2012_val_00027828.pt
Binary file not shown.
Binary file added
BIN
+40.7 KB
examples/imagenet_val/oracle_importance_maps/n04152593/ILSVRC2012_val_00002810.pt
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
import sys | ||
from pathlib import Path | ||
|
||
cache_path = Path("cache_path.txt") | ||
if cache_path.exists(): | ||
os.environ["XDG_CACHE_HOME"] = cache_path.read_text() | ||
|
||
if '..' not in sys.path: | ||
sys.path.append('..') |
Oops, something went wrong.