 
 A convolution based approach to patchifying a 2D image w/ embedding projection.
 
-Based on the impl in https://github.com/google-research/vision_transformer
+Based on code in:
+  * https://github.com/google-research/vision_transformer
+  * https://github.com/google-research/big_vision/tree/main/big_vision
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from typing import List
+
+import torch
 from torch import nn as nn
+import torch.nn.functional as F
 
 from .helpers import to_2tuple
 from .trace_utils import _assert
 
+_logger = logging.getLogger(__name__)
+
 
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
@@ -46,3 +55,122 @@ def forward(self, x):
         x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
         x = self.norm(x)
         return x
+
+
+def resample_patch_embed(
+        patch_embed,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    """Resample the weights of the patch embedding kernel to target resolution.
+
+    We resample the patch embedding kernel by approximately inverting the effect
+    of patch resizing. With this resizing, we can for example load a B/8 filter
+    into a B/16 model and, on a 2x larger input image, the result will match.
+
+    Code based on:
+      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
+
+    Args:
+        patch_embed: original parameter to be resized.
+        new_size (tuple(int, int)): target shape (height, width) only.
+        interpolation (str): interpolation for resize
+        antialias (bool): use anti-aliasing filter in resize
+        verbose (bool): log operation
+    Returns:
+        Resized patch embedding kernel.
+    """
+    import numpy as np
+
+    assert len(patch_embed.shape) == 4, "Four dimensions expected"
+    assert len(new_size) == 2, "New shape should only be hw"
+    old_size = patch_embed.shape[-2:]
+    if tuple(old_size) == tuple(new_size):
+        return patch_embed
+
+    if verbose:
+        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
+
+    def resize(x_np, _new_size):
+        x_tf = torch.Tensor(x_np)[None, None, ...]
+        x_upsampled = F.interpolate(
+            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
+        return x_upsampled
+
+    def get_resize_mat(_old_size, _new_size):
+        # resizing is linear, so its matrix can be recovered by resizing each basis vector
+        mat = []
+        for i in range(np.prod(_old_size)):
+            basis_vec = np.zeros(_old_size)
+            basis_vec[np.unravel_index(i, _old_size)] = 1.
+            mat.append(resize(basis_vec, _new_size).reshape(-1))
+        return np.stack(mat).T
+
+    resize_mat = get_resize_mat(old_size, new_size)
+    resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
+
+    def resample_kernel(kernel):
+        # project through the pseudo-inverse so <kernel, patch> inner products are preserved
+        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
+        return resampled_kernel.reshape(new_size)
+
+    # vmap over the output (dim 0) and input (dim 1) channel dims of the conv kernel
+    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    return v_resample_kernel(patch_embed)
+
+
+# def divs(n, m=None):
+#     m = m or n // 2
+#     if m == 1:
+#         return [1]
+#     if n % m == 0:
+#         return [m] + divs(n, m - 1)
+#     return divs(n, m - 1)
+#
+#
+# class FlexiPatchEmbed(nn.Module):
+#     """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
+#     FIXME WIP
+#     """
+#     def __init__(
+#             self,
+#             img_size=240,
+#             patch_size=16,
+#             in_chans=3,
+#             embed_dim=768,
+#             base_img_size=240,
+#             base_patch_size=32,
+#             norm_layer=None,
+#             flatten=True,
+#             bias=True,
+#     ):
+#         super().__init__()
+#         self.img_size = to_2tuple(img_size)
+#         self.patch_size = to_2tuple(patch_size)
+#         self.num_patches = 0
+#
+#         # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
+#         self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
+#
+#         self.base_img_size = to_2tuple(base_img_size)
+#         self.base_patch_size = to_2tuple(base_patch_size)
+#         self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
+#         self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
+#
+#         self.flatten = flatten
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
+#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+#
+#     def forward(self, x):
+#         B, C, H, W = x.shape
+#         patch_size = self.patch_size
+#         if patch_size == self.base_patch_size:
+#             weight = self.proj.weight
+#         else:
+#             # resample the base kernel to the active patch size on the fly
+#             weight = resample_patch_embed(self.proj.weight, list(patch_size))
+#         x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
+#         if self.flatten:
+#             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+#         x = self.norm(x)
+#         return x
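
Because resizing is linear and the resampled kernel is projected through the pseudo-inverse of the resize matrix, token outputs are preserved under upsampling (exactly, up to numerics, when the resize matrix has full column rank): the resampled kernel applied to a resized patch matches the original kernel applied to the original patch. A minimal sanity check of that property; the `timm.layers.patch_embed` import path is an assumption:

    import torch
    import torch.nn.functional as F

    from timm.layers.patch_embed import resample_patch_embed  # import path assumed

    w8 = torch.randn(768, 3, 8, 8)             # a B/8-style projection kernel
    w16 = resample_patch_embed(w8, [16, 16])   # resampled for a B/16 model

    x8 = torch.randn(1, 3, 8, 8)               # one input patch at the old size
    # resize the patch the same way a 2x larger image would present it,
    # matching the interpolation settings used for the kernel resampling
    x16 = F.interpolate(x8, size=(16, 16), mode='bicubic', antialias=True)

    t8 = (w8 * x8).sum(dim=(1, 2, 3))          # tokens from original kernel/patch
    t16 = (w16 * x16).sum(dim=(1, 2, 3))       # tokens from resampled kernel/patch
    torch.testing.assert_close(t8, t16, rtol=1e-2, atol=1e-4)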
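
The commented-out `FlexiPatchEmbed` above points at the intended use: keep one set of learned weights at a base patch size and resample them at forward time for whatever patch size is active, as in the FlexiViT recipe. A rough standalone sketch of that forward step, reusing `resample_patch_embed` from above; `flexi_patchify` is a hypothetical helper, not part of this commit:

    import torch
    import torch.nn.functional as F

    def flexi_patchify(x, weight, bias=None, patch_size=(12, 12)):
        # hypothetical helper: embed non-overlapping patches of `x` using
        # `weight` resampled from its native size to `patch_size`
        if tuple(weight.shape[-2:]) != tuple(patch_size):
            weight = resample_patch_embed(weight, list(patch_size))
        x = F.conv2d(x, weight, bias=bias, stride=patch_size)
        return x.flatten(2).transpose(1, 2)  # BCHW -> BNC

    base_weight = torch.randn(768, 3, 32, 32)  # base 32x32 kernel, as in the WIP class
    img = torch.randn(1, 3, 240, 240)
    tokens = flexi_patchify(img, base_weight, patch_size=(12, 12))
    print(tokens.shape)  # torch.Size([1, 400, 768]): a 20x20 grid of 12x12 patches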