Fixed id_tensor registry, so reparametrization works when .cuda() is called #632

Merged 3 commits on Jan 22, 2023
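For context, the failure mode this PR addresses can be sketched in a few lines: a tensor stored as a plain module attribute is not moved by .cuda()/.to(), while a registered buffer follows the module. The toy modules below (PlainAttr and Buffered are invented names for illustration, not code from the PR) show the difference:

import torch
from torch import nn

# Illustrative sketch only (not part of this PR): a plain tensor attribute
# stays on the CPU when the module is moved; a registered buffer does not.
class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)
        self.id_tensor = torch.eye(4).reshape(4, 4, 1, 1)  # plain attribute

class Buffered(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)
        # persistent=False keeps the buffer out of state_dict, as in the PR
        self.register_buffer("id_tensor", torch.eye(4).reshape(4, 4, 1, 1), persistent=False)

if torch.cuda.is_available():
    print(PlainAttr().cuda().id_tensor.device)  # cpu -> device mismatch during fusion
    print(Buffered().cuda().id_tensor.device)   # cuda:0 -> moves with the module
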
28 changes: 20 additions & 8 deletions src/super_gradients/modules/qarepvgg_block.py
@@ -1,4 +1,4 @@
-from typing import Type, Union, Mapping, Any
+from typing import Type, Union, Mapping, Any, Optional

 import torch
 from torch import nn
@@ -64,7 +64,7 @@ def __init__(
         """
         :param in_channels: Number of input channels
         :param out_channels: Number of output channels
-        :param activation_type: Type of the nonlinearity
+        :param activation_type: Type of the nonlinearity (nn.ReLU by default)
         :param se_type: Type of the se block (Use nn.Identity to disable SE)
         :param stride: Output stride
         :param dilation: Dilation factor for 3x3 conv
@@ -133,10 +133,16 @@ def __init__(
             self.identity = Residual()

             input_dim = self.in_channels // self.groups
-            self.id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
+            id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
             for i in range(self.in_channels):
-                self.id_tensor[i, i % input_dim, 1, 1] = 1.0
-            self.id_tensor = self.id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device)
+                id_tensor[i, i % input_dim, 1, 1] = 1.0
+
+            self.id_tensor: Optional[torch.Tensor]
+            self.register_buffer(
+                name="id_tensor",
+                tensor=id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device),
+                persistent=False,  # so it's not saved in state_dict
+            )
         else:
             self.identity = None

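As an aside (not part of the diff itself), the id_tensor built in this hunk is the 3x3 identity kernel that lets the Residual branch be expressed as a convolution during reparametrization. A quick self-contained check, assuming groups=1 and matching input/output channel counts:

import torch
import torch.nn.functional as F

in_channels, groups = 4, 1              # assumed values for illustration
input_dim = in_channels // groups
id_tensor = torch.zeros((in_channels, input_dim, 3, 3))
for i in range(in_channels):
    id_tensor[i, i % input_dim, 1, 1] = 1.0   # delta at the kernel center

x = torch.randn(1, in_channels, 8, 8)
y = F.conv2d(x, id_tensor, padding=1, groups=groups)
assert torch.allclose(x, y)             # identity convolution reproduces the input
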
@@ -234,7 +240,10 @@ def _fuse_bn_tensor(self, kernel, bias, running_mean, running_var, gamma, beta,
         A = gamma / std
         A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)

-        return kernel * A_, bias * A + b
+        fused_kernel = kernel * A_
+        fused_bias = bias * A + b
+
+        return fused_kernel, fused_bias

     def full_fusion(self):
         """Fuse everything into Conv-Act-SE, non-trainable, parameters detached
@@ -299,7 +308,10 @@ def partial_fusion(self):
         self.fully_fused = False

     def fuse_block_residual_branches(self):
-        self.full_fusion()
+        # inference frameworks will take care of resulting conv-bn-act-se
+        # no need to fuse post_bn prematurely if it is there
+        # call self.full_fusion() if you need it
+        self.partial_fusion()

-    def from_repvgg(self, repvgg_block: RepVGGBlock):
+    def from_repvgg(self, src: RepVGGBlock):
         raise NotImplementedError