DyHead PyTorch Implementation #10

Open
wants to merge 17 commits into base: master
20 changes: 20 additions & 0 deletions README.md
@@ -56,3 +56,23 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

------
# My Notes

Hi there, I am a recent undergraduate and am currently looking for ML positions. I have always wanted to learn how to implement code from a paper, and I was happy to implement the DyHead attachment so that others can use it.

All the code I wrote uses PyTorch; these are the modules (a short usage sketch follows the list):
1. [`concat_fpn_output.py`](./torch/concat_fpn_output.py) - This takes the FPN output and resizes every level to the median height and width of all the levels (via upsampling or downsampling) before concatenating them.
2. [`attention_layers.py`](./torch/attention_layers.py) - This contains all the classes for the three attention mechanisms.
   - Big thanks to GitHub user [Islanna](https://github.com/Islanna/), who implemented the [Dynamic ReLU paper](https://arxiv.org/pdf/2003.10027.pdf). The Task-aware Attention layer uses the same technique as DyReLU-A, which constructs a dynamic ReLU function that is both spatially and channel shared. I used her code to understand how to implement it, and I applied the same techniques while simplifying the code for my own learning process; all credit goes to her. This is her repository: https://github.com/Islanna/DynamicReLU.

3. [`DyHead.py`](./torch/DyHead.py) - This contains the classes to construct a single DyHead block or the entire DyHead.
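
As a quick orientation, here is a minimal sketch of how these modules fit together. It fakes the FPN output with an `OrderedDict` of random tensors (the level shapes are illustrative, not taken from the notebook) and assumes the files in `torch/` are on the Python path:

```python
import torch
from collections import OrderedDict

from concat_fpn_output import concat_feature_maps
from DyHead import DyHead

# Fake FPN output: four pyramid levels plus the extra 'pool' level (skipped by concat_feature_maps)
fpn_output = OrderedDict([
    ('0', torch.randn(2, 256, 64, 64)),
    ('1', torch.randn(2, 256, 32, 32)),
    ('2', torch.randn(2, 256, 16, 16)),
    ('3', torch.randn(2, 256, 8, 8)),
    ('pool', torch.randn(2, 256, 4, 4)),
])

# Resize every level to the median size (24x24 here) and flatten to (batch, L, S, C)
F_tensor = concat_feature_maps()(fpn_output)
print(F_tensor.shape)  # torch.Size([2, 4, 576, 256])

# Stack two DyHead blocks on top of the concatenated features
dyhead = DyHead(num_blocks=2, L=4, S=576, C=256)
print(dyhead(F_tensor).shape)  # torch.Size([2, 4, 576, 256]); each block preserves the shape of F
```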

The [`DyHead_Example.ipynb`](./torch/DyHead_Example.ipynb) notebook demonstrates how all the classes above work; I would encourage you to have a look.

The code is not the most efficient, but it is well documented and easy to understand. Making it more efficient should not be a problem.

## Future Additions:
The code does not construct a full object detection model with a DyHead. This is partly because I currently need to focus on finding a new position, and partly because I was confused about how ROI Pooling is applied to the tensor *F*: its dimensions no longer contain the spatial dimensions, since it was reshaped to LxSxC rather than LxHxWxC. I would like to hear more about how this is implemented; a hedged sketch of one way to recover the spatial dimensions follows.
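
For anyone with the same question, here is a sketch of my own guess (not verified against the official implementation): since S = H x W after `concat_fpn_output.py` resizes every level to the median size, the flattening can be inverted with a reshape, after which each level is a normal (C, H, W) feature map again and ROI pooling can be applied per level. The helper name below is hypothetical:

```python
import torch

def unflatten_levels(F_tensor, H, W):
    """Hypothetical helper: invert the (batch, L, S, C) flattening, assuming S == H * W."""
    B, L, S, C = F_tensor.shape
    assert S == H * W, "S must equal H * W"
    # (batch, L, S, C) -> (batch, L, H, W, C) -> (batch, L, C, H, W)
    return F_tensor.reshape(B, L, H, W, C).permute(0, 1, 4, 2, 3)

# Example: S = 576 came from levels resized to 24x24
F_tensor = torch.randn(2, 4, 576, 256)
levels = unflatten_levels(F_tensor, 24, 24)
print(levels.shape)  # torch.Size([2, 4, 256, 24, 24]); each level can now be fed to ROI pooling
```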

In the future, when I have more time and a better understanding, I would like to implement both one-stage and two-stage detectors using PyTorch's built-in Faster R-CNN modules so that DyHead can easily be included for detection purposes.
Binary file added imgs/.DS_Store
Binary file not shown.
Binary file added imgs/DyHead_Block.png
Binary file added imgs/Figure_1.png
Binary file added imgs/scale_attention.png
Binary file added imgs/spatial_attention.png
Binary file added imgs/task_aware.png
28 changes: 28 additions & 0 deletions torch/DyHead.py
@@ -0,0 +1,28 @@
import torch.nn as nn
from attention_layers import Scale_Aware_Layer, Spatial_Aware_Layer, Task_Aware_Layer
from collections import OrderedDict

class DyHead_Block(nn.Module):
    def __init__(self, L, S, C):
        super(DyHead_Block, self).__init__()
        # Saving all dimension sizes of F
        self.L_size = L
        self.S_size = S
        self.C_size = C

        # Initializing all attention layers
        self.scale_attention = Scale_Aware_Layer(s_size=self.S_size)
        self.spatial_attention = Spatial_Aware_Layer(L_size=self.L_size)
        self.task_attention = Task_Aware_Layer(num_channels=self.C_size)

    def forward(self, F_tensor):
        scale_output = self.scale_attention(F_tensor)
        spatial_output = self.spatial_attention(scale_output)
        task_output = self.task_attention(spatial_output)

        return task_output

def DyHead(num_blocks, L, S, C):
    blocks = [('Block_{}'.format(i+1), DyHead_Block(L, S, C)) for i in range(num_blocks)]

    return nn.Sequential(OrderedDict(blocks))
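
A minimal shape check for this file (my own sketch; it assumes the imports resolve when run from the `torch/` directory):

```python
import torch
from DyHead import DyHead_Block

F_tensor = torch.randn(2, 4, 576, 256)   # (batch, L, S, C)
block = DyHead_Block(L=4, S=576, C=256)
print(block(F_tensor).shape)             # torch.Size([2, 4, 576, 256]); the block preserves the shape of F
```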
596 changes: 596 additions & 0 deletions torch/DyHead_Example.ipynb

Large diffs are not rendered by default.

157 changes: 157 additions & 0 deletions torch/attention_layers.py
@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d


class Scale_Aware_Layer(nn.Module):
    # Constructor
    def __init__(self, s_size):
        super(Scale_Aware_Layer, self).__init__()

        # Average Pooling
        self.avg_layer = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        # 1x1 Conv layer
        self.conv = nn.Conv2d(in_channels=s_size, out_channels=1, kernel_size=1)

        # Hard Sigmoid
        self.hard_sigmoid = nn.Hardsigmoid()

        # ReLU function
        self.relu = nn.ReLU()

    def forward(self, F_tensor):

        # Transposing input from (batch_size, L, S, C) to (batch_size, S, L, C) so the convolutional layer runs over the level dimension L
        x = F_tensor.transpose(dim0=2, dim1=1)

        # Passing tensor through avg pool layer
        x = self.avg_layer(x)

        # Passing tensor through Conv layer
        x = self.conv(x)

        # Reshaping tensor from (batch_size, 1, L, C) to (batch_size, L, 1, C) so it can be multiplied with F
        x = x.transpose(dim0=1, dim1=2)

        # Passing conv output to relu
        x = self.relu(x)

        # Passing tensor to hard sigmoid function
        pi_L = self.hard_sigmoid(x)

        # pi_L: (batch_size, L, 1, C)
        # F: (batch_size, L, S, C)
        return pi_L * F_tensor

class Spatial_Aware_Layer(nn.Module):
    # Constructor
    def __init__(self, L_size, kernel_height=3, kernel_width=3, padding=1, stride=1, dilation=1, groups=1):
        super(Spatial_Aware_Layer, self).__init__()

        self.in_channels = L_size
        self.out_channels = L_size

        self.kernel_size = (kernel_height, kernel_width)
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.K = kernel_height * kernel_width
        self.groups = groups

        # 3x3 convolution with a 3K-channel output, as described in the Deformable ConvNets v2 paper
        self.offset_and_mask_conv = nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=3*self.K,  # 3K depth
                                              kernel_size=self.kernel_size,
                                              stride=self.stride,
                                              padding=self.padding,
                                              dilation=dilation)

        self.deform_conv = DeformConv2d(in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=self.kernel_size,
                                        stride=self.stride,
                                        padding=self.padding,
                                        dilation=self.dilation,
                                        groups=self.groups)

    def forward(self, F_tensor):
        # Generating offsets and masks (or modulators) for the convolution operation
        offsets_and_masks = self.offset_and_mask_conv(F_tensor)

        # Separating offsets and masks as described in the Deformable ConvNets v2 paper
        offset = offsets_and_masks[:, :2*self.K, :, :]  # First 2K channels
        mask = torch.sigmoid(offsets_and_masks[:, 2*self.K:, :, :])  # Last K channels, passed through a sigmoid

        # Passing offsets, masks, and F into the deformable conv layer
        spatial_output = self.deform_conv(F_tensor, offset, mask)
        return spatial_output

# DyReLU-A technique from the Dynamic ReLU paper
class DyReLUA(nn.Module):
    def __init__(self, channels, reduction=8, k=2, lambdas=None, init_values=None):
        super(DyReLUA, self).__init__()

        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, 2*k)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        # Defining lambdas in the form [La1, La2, Lb1, Lb2]
        if lambdas is not None:
            lambdas = torch.as_tensor(lambdas, dtype=torch.float)
        else:
            # Default lambdas from the DyReLU paper
            lambdas = torch.tensor([1.0, 1.0, 0.5, 0.5], dtype=torch.float)

        # Defining initialization values in the form [alpha1, alpha2, beta1, beta2]
        if init_values is not None:
            init_values = torch.as_tensor(init_values, dtype=torch.float)
        else:
            # Default initialization values from the DyReLU paper
            init_values = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float)

        # Registering as buffers so they move with the module across devices
        self.register_buffer('lambdas', lambdas)
        self.register_buffer('init_values', init_values)

    def forward(self, F_tensor):

        # Global average pooling over F
        kernel_size = F_tensor.shape[2:]  # Getting HxW of F
        gap_output = F.avg_pool2d(F_tensor, kernel_size)

        # Flattening gap_output from (batch_size, C, 1, 1) to (batch_size, C)
        gap_output = gap_output.flatten(start_dim=1)

        # Passing the global average output through the fully-connected layers
        x = self.relu(self.fc1(gap_output))
        x = self.fc2(x)

        # Normalization between (-1, 1)
        residuals = 2 * self.sigmoid(x) - 1

        # Getting values of theta per sample, and separating alphas and betas
        theta = self.init_values + self.lambdas * residuals  # (batch_size, 2k): [alpha1(x), alpha2(x), beta1(x), beta2(x)]
        alphas = theta[:, :2].reshape(-1, 2, 1, 1, 1)
        betas = theta[:, 2:].reshape(-1, 2, 1, 1, 1)

        # Taking the maximum over both piecewise linear functions
        output = torch.maximum(alphas[:, 0] * F_tensor + betas[:, 0], alphas[:, 1] * F_tensor + betas[:, 1])

        return output

class Task_Aware_Layer(nn.Module):
    # Defining constructor
    def __init__(self, num_channels):
        super(Task_Aware_Layer, self).__init__()

        # DyReLU-A dynamic ReLU
        self.dynamic_relu = DyReLUA(num_channels)

    def forward(self, F_tensor):
        # Permuting F from (batch_size, L, S, C) to (batch_size, C, L, S) so the dimensions can be reduced over LxS
        F_tensor = F_tensor.permute(0, 3, 1, 2)

        output = self.dynamic_relu(F_tensor)

        # Reversing the permutation
        output = output.permute(0, 2, 3, 1)

        return output
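
A quick sanity check I would run on this file (my own sketch, not taken from the notebook): each attention layer preserves the (batch, L, S, C) shape of F, and DyReLUA preserves the shape of the channels-first tensor it receives.

```python
import torch
from attention_layers import Scale_Aware_Layer, Spatial_Aware_Layer, Task_Aware_Layer, DyReLUA

# The three attention layers all keep the (batch, L, S, C) shape
F_tensor = torch.randn(2, 4, 576, 256)
for layer in (Scale_Aware_Layer(s_size=576), Spatial_Aware_Layer(L_size=4), Task_Aware_Layer(num_channels=256)):
    F_tensor = layer(F_tensor)
    print(type(layer).__name__, tuple(F_tensor.shape))  # always (2, 4, 576, 256)

# DyReLUA on its own expects a channels-first tensor, here (batch, C, L, S)
x = torch.randn(2, 256, 4, 576)
print(DyReLUA(channels=256)(x).shape)  # torch.Size([2, 256, 4, 576])
```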
36 changes: 36 additions & 0 deletions torch/concat_fpn_output.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class concat_feature_maps(nn.Module):
    def __init__(self):
        super(concat_feature_maps, self).__init__()

    def forward(self, fpn_output):
        # Calculating the median height, used to upsample or downsample each FPN level
        heights = []
        level_tensors = []
        for key, values in fpn_output.items():
            if key != 'pool':
                heights.append(values.shape[2])
                level_tensors.append(values)
        median_height = int(np.median(heights))

        # Resizing every level to the median height and width (interpolate handles both downsampling and upsampling)
        for i in range(len(level_tensors)):
            level_tensors[i] = F.interpolate(input=level_tensors[i], size=(median_height, median_height), mode='nearest')

        # Stacking all levels into a tensor with dimensions (batch_size, levels, C, H, W)
        concat_levels = torch.stack(level_tensors, dim=1)

        # Reshaping tensor from (batch_size, levels, C, H, W) to (batch_size, levels, HxW=S, C)
        concat_levels = concat_levels.flatten(start_dim=3).transpose(dim0=2, dim1=3)
        return concat_levels
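
For context, the `fpn_output` dict this module expects matches what torchvision's FPN backbones return: an `OrderedDict` of levels plus a `'pool'` entry, which is skipped above. A hedged sketch, assuming a recent torchvision where `resnet_fpn_backbone` takes a `weights` argument (older versions use `pretrained` instead):

```python
import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from concat_fpn_output import concat_feature_maps

backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=None)
images = torch.randn(2, 3, 256, 256)

with torch.no_grad():
    fpn_output = backbone(images)        # OrderedDict with keys '0', '1', '2', '3', 'pool'

F_tensor = concat_feature_maps()(fpn_output)
print(F_tensor.shape)                    # torch.Size([2, 4, 576, 256]) for 256x256 inputs
```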