Skip to content

Add option for ML-Decoder - an improved classification head #1012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions timm/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def create_model(
scriptable=None,
exportable=None,
no_jit=None,
use_ml_decoder_head=False,
**kwargs):
"""Create a model

Expand Down Expand Up @@ -80,6 +81,10 @@ def create_model(
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, **kwargs)

if use_ml_decoder_head:
from timm.models.layers.ml_decoder import add_ml_decoder_head
model = add_ml_decoder_head(model)

if checkpoint_path:
load_checkpoint(model, checkpoint_path)

Expand Down
156 changes: 156 additions & 0 deletions timm/models/layers/ml_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Optional

import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn


def add_ml_decoder_head(model):
    """Swap a model's classification head for an ``MLDecoder`` head, in place.

    The head attribute is chosen by inspecting the model:

    * ``global_pool`` + ``fc`` (e.g. ResNet-style models): ``fc`` is replaced
      and ``global_pool`` becomes ``nn.Identity``.
    * ``global_pool`` + ``classifier`` (e.g. EfficientNet): ``classifier`` is
      replaced and ``global_pool`` becomes ``nn.Identity``.
    * class name containing ``RegNet`` or ``TResNet``: ``head`` is replaced;
      ``global_pool`` is left untouched.

    The new head is built from ``model.num_classes`` and
    ``model.num_features``. If the model has a ``drop_rate`` attribute it is
    set to 0, because ML-Decoder applies its own dropout.

    Args:
        model: the model to modify.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        NotImplementedError: if the model matches none of the supported
            layouts. (Previously this printed a message and called
            ``exit(-1)``, which terminated the whole interpreter from
            library code.)
    """
    has_pool = hasattr(model, 'global_pool')
    if has_pool and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        head_attr, replace_pool = 'fc', True
    elif has_pool and hasattr(model, 'classifier'):  # EfficientNet
        head_attr, replace_pool = 'classifier', True
    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():
        head_attr, replace_pool = 'head', False
    else:
        raise NotImplementedError("Model code-writing is not aligned currently with ml-decoder")

    if replace_pool:
        # The decoder consumes the un-pooled feature map, so pooling is disabled.
        model.global_pool = nn.Identity()

    # Remove the old head first, then register the new one under the same name.
    delattr(model, head_attr)
    new_head = MLDecoder(num_classes=model.num_classes, initial_num_features=model.num_features)
    setattr(model, head_attr, new_head)

    if hasattr(model, 'drop_rate'):  # Ml-Decoder has inner dropout
        model.drop_rate = 0
    return model


class TransformerDecoderLayerOptimal(nn.Module):
    """Decoder layer with cross-attention and a feed-forward block, no self-attention.

    The forward pass is, in order:

    1. ``tgt = norm1(tgt + dropout1(tgt))``
    2. cross-attention of ``tgt`` (queries) over ``memory`` (keys/values),
       added residually and normalised by ``norm2``
    3. ``linear1 -> activation -> dropout -> linear2``, added residually and
       normalised by ``norm3``

    Args:
        d_model: feature size of ``tgt`` / ``memory``.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability used by all dropout layers and by the
            attention module.
        activation: activation name, resolved by torch's
            ``_get_activation_fn``.
        layer_norm_eps: epsilon for the three ``LayerNorm`` modules.
    """

    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super().__init__()
        # NOTE: sub-modules are registered in the original order so that
        # state_dict key order and RNG consumption at init stay unchanged.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Fall back to ReLU when the restored state carries no 'activation'
        # entry (presumably objects serialized before it was stored).
        if 'activation' not in state:
            state['activation'] = torch.nn.functional.relu
        super().__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        """Run the layer.

        Args:
            tgt: query tensor fed to the attention module as queries.
            memory: tensor used as both keys and values.
            tgt_mask, memory_mask, tgt_key_padding_mask,
            memory_key_padding_mask: accepted for signature compatibility
                with ``nn.TransformerDecoderLayer``; not used.
            tgt_is_causal, memory_is_causal: accepted (and ignored) because
                recent PyTorch ``nn.TransformerDecoder`` versions pass these
                keywords to every layer; without them the call raises
                ``TypeError``.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # There is no self-attention: the first residual adds tgt to a
        # dropped-out copy of itself before normalisation.
        tgt = self.norm1(tgt + self.dropout1(tgt))

        # Cross-attention; index 0 drops the attention weights.
        attended = self.multihead_attn(tgt, memory, memory)[0]
        tgt = self.norm2(tgt + self.dropout2(attended))

        # Position-wise feed-forward block.
        hidden = self.dropout(self.activation(self.linear1(tgt)))
        tgt = self.norm3(tgt + self.dropout3(self.linear2(hidden)))
        return tgt


# @torch.jit.script
# class ExtrapClasses(object):
# def __init__(self, num_queries: int, group_size: int):
# self.num_queries = num_queries
# self.group_size = group_size
#
# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
# torch.Tensor):
# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
# out = (h * w).sum(dim=2) + class_embed_b
# out = out.view((h.shape[0], self.group_size * self.num_queries))
# return out

# Per-group linear projection, compiled as a TorchScript class.
# For every group index g below ``embed_len_decoder`` it multiplies the slice
# h[:, g, :] by the matrix duplicate_pooling[g, :, :] and stores the product
# in out_extrap[:, g, :].
@torch.jit.script
class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        # Number of leading groups (dim 1 of ``h``) that get projected.
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        # Results are written into ``out_extrap`` in place; nothing is returned.
        for group in range(self.embed_len_decoder):
            out_extrap[:, group, :] = torch.matmul(h[:, group, :], duplicate_pooling[group, :, :])


class MLDecoder(nn.Module):
    """ML-Decoder classification head.

    Input features are projected to ``decoder_embedding`` channels, a fixed
    (non-trainable) set of query embeddings cross-attends to them through one
    ``TransformerDecoderLayerOptimal``, and each query's output is expanded
    to ``duplicate_factor`` logits by a per-group linear projection.

    Args:
        num_classes: number of output logits.
        num_of_groups: number of queries; a negative value means 100. The
            count is capped at ``num_classes``.
        decoder_embedding: decoder width; a negative value means 768.
        initial_num_features: channel count of the incoming features.
    """

    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
        super().__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        embed_len_decoder = min(embed_len_decoder, num_classes)

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
                                                      dim_feedforward=dim_feedforward, dropout=decoder_dropout)
        # Kept as an nn.TransformerDecoder container so parameter names
        # ("decoder.layers.N. ...") stay compatible with existing checkpoints.
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)

        # non-learnable queries
        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
        self.query_embed.requires_grad_(False)

        # group fully-connected
        self.num_classes = num_classes
        # Exact ceil(num_classes / embed_len_decoder). The former
        # ``int(a / b + 0.999)`` rounded down whenever the fractional part
        # was below 0.001 (e.g. 2001 classes with 2000 groups), yielding too
        # few logits for the bias addition in ``forward``.
        self.duplicate_factor = (num_classes + embed_len_decoder - 1) // embed_len_decoder
        self.duplicate_pooling = torch.nn.Parameter(
            torch.empty(embed_len_decoder, decoder_embedding, self.duplicate_factor))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.empty(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
        self.group_fc = GroupFC(embed_len_decoder)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: either a 4-D feature map, whose trailing two dims are flattened
                into a token axis, or a tensor used as-is (assumed to be
                ``(batch, tokens, features)`` — TODO confirm against callers).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        if x.dim() == 4:  # e.g. [bs, 2048, 7, 7] -> [bs, 49, 2048]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:
            embedding_spatial = x
        tokens = torch.nn.functional.relu(self.embed_standart(embedding_spatial), inplace=True)

        bs = tokens.shape[0]
        # expand() broadcasts the queries over the batch without copying memory.
        tgt = self.query_embed.weight.unsqueeze(1).expand(-1, bs, -1)
        memory = tokens.transpose(0, 1)

        # Run the layers directly rather than through
        # nn.TransformerDecoder.forward: recent PyTorch versions of that
        # method probe ``layers[0].self_attn`` and pass causal-mask keywords
        # that this custom layer does not provide.
        h = tgt
        for layer in self.decoder.layers:
            h = layer(h, memory)
        if self.decoder.norm is not None:
            h = self.decoder.norm(h)
        h = h.transpose(0, 1)  # [batch, embed_len_decoder, decoder_embedding]

        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)
        # Groups * duplicate_factor may exceed num_classes; drop the surplus.
        logits = out_extrap.flatten(1)[:, :self.num_classes]
        logits += self.duplicate_pooling_bias
        return logits
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
parser.add_argument('--use-ml-decoder-head', type=int, default=0)

# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
Expand Down Expand Up @@ -379,7 +380,8 @@ def main():
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
checkpoint_path=args.initial_checkpoint,
use_ml_decoder_head=args.use_ml_decoder_head)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
Expand Down