From 5f527d78e31d1fee4370414ef926f5c8657de2b3 Mon Sep 17 00:00:00 2001 From: willpa11213 Date: Mon, 24 Oct 2022 15:43:23 +0800 Subject: [PATCH 1/4] add svtr backbone --- mmocr/models/textrecog/backbones/__init__.py | 3 +- mmocr/models/textrecog/backbones/svtr.py | 650 ++++++++++++++++++ .../test_backbones/test_svtr.py | 84 +++ 3 files changed, 736 insertions(+), 1 deletion(-) create mode 100644 mmocr/models/textrecog/backbones/svtr.py create mode 100644 tests/test_models/test_textrecog/test_backbones/test_svtr.py diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py index 3201de388..43bd3926a 100644 --- a/mmocr/models/textrecog/backbones/__init__.py +++ b/mmocr/models/textrecog/backbones/__init__.py @@ -6,8 +6,9 @@ from .resnet31_ocr import ResNet31OCR from .resnet_abi import ResNetABI from .shallow_cnn import ShallowCNN +from .svtr import SVTR __all__ = [ 'ResNet31OCR', 'MiniVGG', 'NRTRModalityTransform', 'ShallowCNN', - 'ResNetABI', 'ResNet', 'MobileNetV2' + 'ResNetABI', 'ResNet', 'MobileNetV2', 'SVTR' ] diff --git a/mmocr/models/textrecog/backbones/svtr.py b/mmocr/models/textrecog/backbones/svtr.py new file mode 100644 index 000000000..79f51351f --- /dev/null +++ b/mmocr/models/textrecog/backbones/svtr.py @@ -0,0 +1,650 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmocr.registry import MODELS + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +def truncated_normal_(tensor, mean=0, std=0.02): + with torch.no_grad(): + size = tensor.size() + tmp = tensor.new_empty(size + (4, )).normal_().cuda() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind.cuda()).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor + + +class Identity(nn.Module): + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class OverlapPatchEmbed(BaseModule): + """Image to the progressive overlapping Patch Embedding. + + Args: + img_size (int or tuple): The size of input, which will be used to + calculate the out size. Defaults to [32, 100]. + in_channels (int): Number of input channels. Defaults to 3. + embed_dims (int): The dimensions of embedding. Defaults to 768. + num_layers (int, optional): Number of Conv_BN_Layer. Defaults to 2 and + limit to [2, 3]. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + + def __init__(self, + img_size: Union[int, Tuple[int, int]] = [32, 100], + in_channels: int = 3, + embed_dims: int = 768, + num_layers: int = 2, + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + + super().__init__(init_cfg=init_cfg) + + assert num_layers in [2, 3], \ + 'The number of layers must belong to [2, 3]' + self.img_size = img_size + self.net = nn.Sequential() + for num in range(num_layers, 0, -1): + if (num == num_layers): + _input = in_channels + _output = embed_dims // (2**(num - 1)) + self.net.add_module( + f'ConvModule{str(num_layers - num)}', + ConvModule( + in_channels=_input, + out_channels=_output, + kernel_size=3, + stride=2, + padding=1, + bias=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'))) + _input = _output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (Tensor): A Tensor of shape :math:`(N, C, H, W)`. + + Returns: + Tensor: A tensor of shape math:`(N, HW//16, C)`. + """ + _, _, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model \ + ({self.img_size[0]}*{self.img_size[1]})." + + x = self.net(x).flatten(2).permute(0, 2, 1) + return x + + +class ConvMixer(BaseModule): + """The conv Mixer. + + Args: + dim (int): Number of character components. + num_heads (int, optional): Number of heads. Defaults to 8. + HW (Tuple[int, int], optional): Number of H x W. Defaults to [8, 25]. + local_k (Tuple[int, int], optional): Window size. Defaults to [3, 3]. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + HW: Tuple[int, int] = [8, 25], + local_k: Tuple[int, int] = [3, 3], + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + self.HW = HW + self.embed_dims = embed_dims + self.local_mixer = nn.Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=local_k, + stride=1, + padding=(local_k[0] // 2, local_k[1] // 2), + groups=num_heads) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, HW, C)`. + + Returns: + torch.Tensor: Tensor: A tensor of shape math:`(N, HW, C)`. + """ + h, w = self.HW + x = x.permute(0, 2, 1).reshape([-1, self.embed_dims, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).permute(0, 2, 1) + return x + + +class AttnMixer(BaseModule): + """One of mixer of {'Global', 'Local'}. Defaults to Global Mixer. + + Args: + embed_dims (int): Number of character components. + num_heads (int, optional): Number of heads. Defaults to 8. + mixer (str, optional): The mixer type. Defaults to 'Global'. + HW (Tuple[int, int], optional): Number of H x W. Defaults to [8, 25]. + local_k (Tuple[int, int], optional): Window size. Defaults to [7, 11]. + qkv_bias (bool, optional): Whether a additive bias is required. + Defaults to False. + qk_scale (float, optional): A scaling factor. Defaults to None. + attn_drop (float, optional): A Dropout layer. Defaults to 0.0. + proj_drop (float, optional): A Dropout layer. Defaults to 0.0. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + mixer: str = 'Global', + HW: Tuple[int, int] = [8, 25], + local_k: Tuple[int, int] = [7, 11], + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + assert mixer in {'Global', 'Local'}, \ + "The type of mixer must belong to {'Global', 'Local'}" + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H, W = HW + self.N = H * W + self.C = embed_dims + if mixer == 'Local' and HW is not None: + hk = local_k[0] + wk = local_k[1] + mask = torch.ones([H * W, H + hk - 1, W + wk - 1], + dtype=torch.float32) + for h in range(0, H): + for w in range(0, W): + mask[h * w + w, h:h + hk, w:w + wk] = 0. + mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1) + mask[mask < -1] = -np.inf + self.mask = mask[None, None, :, :] + self.mixer = mixer + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H, W, C)`. + """ + if self.HW is not None: + N, C = self.N, self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape( + (-1, N, 3, self.num_heads, C // self.num_heads)).permute( + (2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = (q.matmul(k.permute(0, 1, 3, 2))) + if self.mixer == 'Local': + attn += self.mask + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn.matmul(v).permute(0, 2, 1, 3).reshape(-1, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MLP(BaseModule): + """The MLP block. + + Args: + in_features (int): The input features. + hidden_features (int, optional): The hidden features. + Defaults to None. + out_features (int, optional): The output features. + Defaults to None. + drop (float, optional): cfg of dropout function. Defaults to 0.0. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + + def __init__(self, + in_features: int, + hidden_features: int = None, + out_features: int = None, + drop: float = 0., + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + hidden_features = hidden_features or in_features + out_features = out_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = nn.GELU() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H, W, C)`. + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MixingBlock(BaseModule): + """The Mixing block. + + Args: + embed_dims (int): Number of character components. + num_heads (int): Number of heads + mixer (str, optional): The mixer type. Defaults to 'Global'. + window_size (Tuple[int ,int], optional): Local window size. + Defaults to [7, 11]. + HW (Tuple[int, int], optional): The size of [H, W]. + Defaults to [8, 25]. 
+ mlp_ratio (float, optional): The ratio of hidden features to input. + Defaults to 4.0. + qkv_bias (bool, optional): Whether a additive bias is required. + Defaults to False. + qk_scale (float, optional): A scaling factor. Defaults to None. + drop (float, optional): cfg of Dropout. Defaults to 0.. + attn_drop (float, optional): cfg of Dropout. Defaults to 0.0. + drop_path (float, optional): The probability of drop path. + Defaults to 0.0. + pernorm (bool, optional): Whether to place the MxingBlock before norm. + Defaults to True. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + mixer: str = 'Global', + window_size: Tuple[int, int] = [7, 11], + HW: Tuple[int, int] = [8, 25], + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path=0., + prenorm: bool = True, + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + self.norm1 = nn.LayerNorm(embed_dims, eps=1e-6) + if mixer in {'Global', 'Local'}: + self.mixer = AttnMixer( + embed_dims, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + elif mixer == 'Conv': + self.mixer = ConvMixer( + embed_dims, num_heads=num_heads, HW=HW, local_k=window_size) + else: + raise TypeError('The mixer must be one of [Global, Local, Conv]') + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.norm2 = nn.LayerNorm(embed_dims, eps=1e-6) + mlp_hidden_dim = int(embed_dims * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = MLP( + in_features=embed_dims, hidden_features=mlp_hidden_dim, drop=drop) + self.prenorm = prenorm + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H*W, C)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H*W, C)`. + """ + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class MerigingBlock(BaseModule): + """The last block of any stage, except for the last stage. + + Args: + in_channels (int): The channels of input. + out_channels (int): The channels of output. + types (str, optional): Which downsample operation of ['Pool', 'Conv']. + Defaults to 'Pool'. + stride (Union[int, Tuple[int, int]], optional): Stride of the Conv. + Defaults to [2, 1]. + act (bool, optional): activation function. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + types: str = 'Pool', + stride: Union[int, Tuple[int, int]] = [2, 1], + act: bool = None, + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2d( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.maxpool = nn.MaxPool2d( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1) + self.norm = nn.LayerNorm(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H/2, W, 2C)`. + """ + if self.types == 'Pool': + x = (self.avgpool(x) + self.maxpool(x)) * 0.5 + out = self.proj(x.flatten(2).permute(0, 2, 1)) + + else: + x = self.conv(x) + out = x.flatten(2).permute(0, 2, 1) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +@MODELS.register_module() +class SVTRNet(BaseModule): + """A PyTorch implement of : `SVTR: Scene Text Recognition with a Single + Visual Model `_ + + Code is partially modified from https://github.com/PaddlePaddle/PaddleOCR. + + Args: + img_size (Tuple[int, int], optional): The expected input image shape. + Defaults to [32, 100]. + in_channels (int, optional): The num of input channels. Defaults to 3. + embed_dims (Tuple[int, int, int], optional): Number of input channels. + Defaults to [64, 128, 256]. + depth (Tuple[int, int, int], optional): + The number of MixingBlock at each stage. Defaults to [3, 6, 3]. + num_heads (Tuple[int, int, int], optional): Number of attention heads. + Defaults to [2, 4, 8]. + mixer_types (Tuple[str], optional): Mixing type in a MixingBlock. + Defaults to ['Local']*6+['Global']*6. + window_size (Tuple[Tuple[int, int]], optional): + The height and width of the window at eeach stage. + Defaults to [[7, 11], [7, 11], [7, 11]]. + merging_types (str, optional): The way of downsample in MergingBlock. + Defaults to 'Conv'. + mlp_ratio (int, optional): Ratio of hidden features to input in MLP. + Defaults to 4. + qkv_bias (bool, optional): + Whether to add bias for qkv in attention modules. Defaults to True. + qk_scale (float, optional): A scaling factor. Defaults to None. + drop_rate (float, optional): Probability of an element to be zeroed. + Defaults to 0.0. + last_drop (float, optional): cfg of dropout at last stage. + Defaults to 0.1. + attn_drop_rate (float, optional): _description_. Defaults to 0.. + drop_path_rate (float, optional): stochastic depth rate. + Defaults to 0.1. + out_channels (int, optional): The num of output channels in backone. + Defaults to 192. + max_seq_len (int, optional): Maximum output sequence length :math:`T`. + Defaults to 25. + num_layers (int, optional): The num of conv in PatchEmbedding. + Defaults to 2. + prenorm (bool, optional): Whether to place the MxingBlock before norm. + Defaults to True. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + + def __init__(self, + img_size: Tuple[int, int] = [32, 100], + in_channels: int = 3, + embed_dims: Tuple[int, int, int] = [64, 128, 256], + depth: Tuple[int, int, int] = [3, 6, 3], + num_heads: Tuple[int, int, int] = [2, 4, 8], + mixer_types: Tuple[str] = ['Local'] * 6 + ['Global'] * 6, + window_size: Tuple[Tuple[int, int]] = [[7, 11], [7, 11], + [7, 11]], + merging_types: str = 'Conv', + mlp_ratio: int = 4, + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + last_drop: float = 0.1, + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + out_channels: int = 192, + max_seq_len: int = 25, + num_layers: int = 2, + prenorm: bool = True, + init_cfg: Optional[Union[Dict, List[Dict]]] = None): + super().__init__(init_cfg) + self.img_size = img_size + self.embed_dims = embed_dims + self.out_channels = out_channels + self.prenorm = prenorm + self.patch_embed = OverlapPatchEmbed( + img_size=img_size, + in_channels=in_channels, + embed_dims=embed_dims[0], + num_layers=num_layers) + num_patches = (img_size[1] // (2**num_layers)) * ( + img_size[0] // (2**num_layers)) + self.HW = [ + img_size[0] // (2**num_layers), img_size[1] // (2**num_layers) + ] + self.absolute_pos_embed = nn.Parameter( + torch.zeros([1, num_patches, embed_dims[0]], dtype=torch.float32), + requires_grad=True) + self.pos_drop = nn.Dropout(drop_rate) + dpr = np.linspace(0, drop_path_rate, sum(depth)) + + self.blocks1 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[0], + num_heads=num_heads[0], + mixer=mixer_types[0:depth[0]][i], + window_size=window_size[0], + HW=self.HW, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + prenorm=prenorm) for i in range(depth[0]) + ]) + self.downsample1 = MerigingBlock( + in_channels=embed_dims[0], + out_channels=embed_dims[1], + types=merging_types, + stride=[2, 1]) + HW = [self.HW[0] // 2, self.HW[1]] + self.merging_types = merging_types + + self.blocks2 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[1], + num_heads=num_heads[1], + mixer=mixer_types[depth[0]:depth[0] + depth[1]][i], + window_size=window_size[1], + HW=HW, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + prenorm=prenorm) for i in range(depth[1]) + ]) + self.downsample2 = MerigingBlock( + in_channels=embed_dims[1], + out_channels=embed_dims[2], + types=merging_types, + stride=[2, 1]) + HW = [self.HW[0] // 4, self.HW[1]] + + self.blocks3 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[2], + num_heads=num_heads[2], + mixer=mixer_types[depth[0] + depth[1]:][i], + window_size=window_size[2], + HW=HW, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + prenorm=prenorm) for i in range(depth[2]) + ]) + self.layer_norm = nn.LayerNorm(self.embed_dims[-1], eps=1e-6) + self.avgpool = nn.AdaptiveAvgPool2d([1, max_seq_len]) + self.last_conv = ConvModule( + in_channels=embed_dims[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0) + self.hardwish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop) + + trunc_normal_(self.absolute_pos_embed) + self.apply(self._init_weights) + print('------------model weight inits-------------') + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + truncated_normal_(m.weight) + if isinstance(m, 
nn.Linear) and m.bias is not None: + nn.init.zeros_(m.bias) + if isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward function except the last combing operation. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H/16, W/4, 256)`. + """ + x = self.patch_embed(x) + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + x = self.downsample1( + x.permute(0, 2, 1).reshape( + [-1, self.embed_dims[0], self.HW[0], self.HW[1]])) + + for blk in self.blocks2: + x = blk(x) + x = self.downsample2( + x.permute(0, 2, 1).reshape( + [-1, self.embed_dims[1], self.HW[0] // 2, self.HW[1]])) + + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.layer_norm(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H/16, W/4, 256)`. + + Returns: + torch.Tensor: A Tensor of shape :math:`(N, 1, W/4, 192)`. + """ + x = self.forward_features(x) + x = self.avgpool( + x.permute(0, 2, 1).reshape( + [-1, self.embed_dims[2], self.HW[0] // 4, self.HW[1]])) + x = self.last_conv(x) + x = self.hardwish(x) + x = self.dropout(x) + return x diff --git a/tests/test_models/test_textrecog/test_backbones/test_svtr.py b/tests/test_models/test_textrecog/test_backbones/test_svtr.py new file mode 100644 index 000000000..d924815dd --- /dev/null +++ b/tests/test_models/test_textrecog/test_backbones/test_svtr.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch + +from mmocr.models.textrecog.backbones.svtr import (AttnMixer, ConvMixer, + MerigingBlock, MixingBlock, + OverlapPatchEmbed, SVTRNet) + + +class TestOverlapPatchEmbed(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 3, 32, 100) + + def test_overlap_patch_embed(self): + Overlap_Patch_Embed = OverlapPatchEmbed( + img_size=self.img.shape[-2:], in_channels=self.img.shape[1]) + self.assertEqual( + Overlap_Patch_Embed(self.img).shape, torch.Size([1, 8 * 25, 768])) + + +class TestConvMixer(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 8 * 25, 768) + + def test_conv_mixer(self): + conv_mixer = ConvMixer(embed_dims=self.img.shape[-1]) + self.assertEqual( + conv_mixer(self.img).shape, torch.Size([1, 8 * 25, 768])) + + +class TestAttnMixer(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 8 * 25, 768) + + def test_attn_mixer(self): + attn_mixer = AttnMixer(embed_dims=self.img.shape[-1]) + self.assertEqual( + attn_mixer(self.img).shape, torch.Size([1, 8 * 25, 768])) + + +class TestMixingBlock(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 8 * 25, 768) + + def test_mixing_block(self): + mixing_block = MixingBlock(self.img.shape[-1], num_heads=8) + self.assertEqual( + mixing_block(self.img).shape, torch.Size([1, 8 * 25, 768])) + + +class TestMergingBlock(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 64, 8, 25) + + def test_mergingblock(self): + mergingblock1 = MerigingBlock( + self.img.shape[1], self.img.shape[1] * 2, types='Pool') + mergingblock2 = MerigingBlock( + self.img.shape[1], self.img.shape[1] * 2, types='Conv') + self.assertEqual( + [mergingblock1(self.img).shape, + mergingblock2(self.img).shape], + [torch.Size([1, 4 * 25, 64 * 2]), + torch.Size([1, 4 * 25, 64 * 2])]) + + +class TestSvtrNet(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 3, 32, 100) + + def test_svtrnet(self): + model = SVTRNet( + img_size=self.img.shape[-2:], + in_channels=self.img.shape[1], + ) + model.train() + self.assertEqual(model(self.img).shape, torch.Size([1, 192, 1, 25])) From a45f11e78f00bdc81c54bc2c92302e5d1dcc9331 Mon Sep 17 00:00:00 2001 From: gaotongxiao Date: Tue, 13 Dec 2022 11:44:56 +0800 Subject: [PATCH 2/4] update backbone --- mmocr/models/textrecog/backbones/__init__.py | 4 +- mmocr/models/textrecog/backbones/svtr.py | 54 ++++++------------- .../test_backbones/test_svtr.py | 5 +- 3 files changed, 19 insertions(+), 44 deletions(-) diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py index 43bd3926a..19432e49b 100644 --- a/mmocr/models/textrecog/backbones/__init__.py +++ b/mmocr/models/textrecog/backbones/__init__.py @@ -6,9 +6,9 @@ from .resnet31_ocr import ResNet31OCR from .resnet_abi import ResNetABI from .shallow_cnn import ShallowCNN -from .svtr import SVTR +from .svtr import SVTRNet __all__ = [ 'ResNet31OCR', 'MiniVGG', 'NRTRModalityTransform', 'ShallowCNN', - 'ResNetABI', 'ResNet', 'MobileNetV2', 'SVTR' + 'ResNetABI', 'ResNet', 'MobileNetV2', 'SVTRNet' ] diff --git a/mmocr/models/textrecog/backbones/svtr.py b/mmocr/models/textrecog/backbones/svtr.py index 79f51351f..beab610fa 100644 --- a/mmocr/models/textrecog/backbones/svtr.py +++ b/mmocr/models/textrecog/backbones/svtr.py @@ -8,23 +8,10 @@ from mmcv.cnn import ConvModule from mmcv.cnn.bricks import DropPath from mmengine.model import BaseModule -from mmengine.model.weight_init import trunc_normal_ +from mmengine.model.weight_init 
import trunc_normal_init from mmocr.registry import MODELS -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - -def truncated_normal_(tensor, mean=0, std=0.02): - with torch.no_grad(): - size = tensor.size() - tmp = tensor.new_empty(size + (4, )).normal_().cuda() - valid = (tmp < 2) & (tmp > -2) - ind = valid.max(-1, keepdim=True)[1] - tensor.data.copy_(tmp.gather(-1, ind.cuda()).squeeze(-1)) - tensor.data.mul_(std).add_(mean) - return tensor - class Identity(nn.Module): @@ -39,8 +26,6 @@ class OverlapPatchEmbed(BaseModule): """Image to the progressive overlapping Patch Embedding. Args: - img_size (int or tuple): The size of input, which will be used to - calculate the out size. Defaults to [32, 100]. in_channels (int): Number of input channels. Defaults to 3. embed_dims (int): The dimensions of embedding. Defaults to 768. num_layers (int, optional): Number of Conv_BN_Layer. Defaults to 2 and @@ -50,7 +35,6 @@ class OverlapPatchEmbed(BaseModule): """ def __init__(self, - img_size: Union[int, Tuple[int, int]] = [32, 100], in_channels: int = 3, embed_dims: int = 768, num_layers: int = 2, @@ -60,7 +44,6 @@ def __init__(self, assert num_layers in [2, 3], \ 'The number of layers must belong to [2, 3]' - self.img_size = img_size self.net = nn.Sequential() for num in range(num_layers, 0, -1): if (num == num_layers): @@ -74,7 +57,6 @@ def __init__(self, kernel_size=3, stride=2, padding=1, - bias=False, norm_cfg=dict(type='BN'), act_cfg=dict(type='GELU'))) _input = _output @@ -88,11 +70,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor: A tensor of shape math:`(N, HW//16, C)`. """ - _, _, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model \ - ({self.img_size[0]}*{self.img_size[1]})." - x = self.net(x).flatten(2).permute(0, 2, 1) return x @@ -101,7 +78,7 @@ class ConvMixer(BaseModule): """The conv Mixer. Args: - dim (int): Number of character components. + embed_dims (int): Number of character components. num_heads (int, optional): Number of heads. Defaults to 8. HW (Tuple[int, int], optional): Number of H x W. Defaults to [8, 25]. local_k (Tuple[int, int], optional): Window size. Defaults to [3, 3]. @@ -154,8 +131,8 @@ class AttnMixer(BaseModule): qkv_bias (bool, optional): Whether a additive bias is required. Defaults to False. qk_scale (float, optional): A scaling factor. Defaults to None. - attn_drop (float, optional): A Dropout layer. Defaults to 0.0. - proj_drop (float, optional): A Dropout layer. Defaults to 0.0. + attn_drop (float, optional): Attn dropout probability. Defaults to 0.0. + proj_drop (float, optional): Proj dropout layer. Defaults to 0.0. init_cfg (dict or list[dict], optional): Initialization configs. Defaults to None. """ @@ -193,10 +170,10 @@ def __init__(self, dtype=torch.float32) for h in range(0, H): for w in range(0, W): - mask[h * w + w, h:h + hk, w:w + wk] = 0. + mask[h * W + w, h:h + hk, w:w + wk] = 0. 
mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1) - mask[mask < -1] = -np.inf - self.mask = mask[None, None, :, :] + mask[mask >= 1] = -np.inf + self.mask = mask[None, None, :, :].to(self.qkv.weight.device) self.mixer = mixer def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -216,7 +193,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: (-1, N, 3, self.num_heads, C // self.num_heads)).permute( (2, 0, 3, 1, 4)) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = (q.matmul(k.permute(0, 1, 3, 2))) + attn = q.matmul(k.permute(0, 1, 3, 2)) if self.mixer == 'Local': attn += self.mask attn = F.softmax(attn, dim=-1) @@ -425,8 +402,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @MODELS.register_module() class SVTRNet(BaseModule): - """A PyTorch implement of : `SVTR: Scene Text Recognition with a Single - Visual Model `_ + """A PyTorch implementation of `SVTR: Scene Text Recognition with a Single + Visual Model `_ Code is partially modified from https://github.com/PaddlePaddle/PaddleOCR. @@ -465,7 +442,7 @@ class SVTRNet(BaseModule): Defaults to 25. num_layers (int, optional): The num of conv in PatchEmbedding. Defaults to 2. - prenorm (bool, optional): Whether to place the MxingBlock before norm. + prenorm (bool, optional): Whether to place the MixingBlock before norm. Defaults to True. init_cfg (dict or list[dict], optional): Initialization configs. Defaults to None. @@ -499,7 +476,6 @@ def __init__(self, self.out_channels = out_channels self.prenorm = prenorm self.patch_embed = OverlapPatchEmbed( - img_size=img_size, in_channels=in_channels, embed_dims=embed_dims[0], num_layers=num_layers) @@ -576,22 +552,22 @@ def __init__(self, ]) self.layer_norm = nn.LayerNorm(self.embed_dims[-1], eps=1e-6) self.avgpool = nn.AdaptiveAvgPool2d([1, max_seq_len]) - self.last_conv = ConvModule( + self.last_conv = nn.Conv2d( in_channels=embed_dims[2], out_channels=self.out_channels, kernel_size=1, + bias=False, stride=1, padding=0) self.hardwish = nn.Hardswish() self.dropout = nn.Dropout(p=last_drop) - trunc_normal_(self.absolute_pos_embed) + trunc_normal_init(self.absolute_pos_embed, mean=0, std=0.02) self.apply(self._init_weights) - print('------------model weight inits-------------') def _init_weights(self, m): if isinstance(m, nn.Linear): - truncated_normal_(m.weight) + trunc_normal_init(m.weight, mean=0, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) if isinstance(m, nn.LayerNorm): diff --git a/tests/test_models/test_textrecog/test_backbones/test_svtr.py b/tests/test_models/test_textrecog/test_backbones/test_svtr.py index d924815dd..cb4d1b650 100644 --- a/tests/test_models/test_textrecog/test_backbones/test_svtr.py +++ b/tests/test_models/test_textrecog/test_backbones/test_svtr.py @@ -14,8 +14,7 @@ def setUp(self) -> None: self.img = torch.rand(1, 3, 32, 100) def test_overlap_patch_embed(self): - Overlap_Patch_Embed = OverlapPatchEmbed( - img_size=self.img.shape[-2:], in_channels=self.img.shape[1]) + Overlap_Patch_Embed = OverlapPatchEmbed(in_channels=self.img.shape[1]) self.assertEqual( Overlap_Patch_Embed(self.img).shape, torch.Size([1, 8 * 25, 768])) @@ -70,7 +69,7 @@ def test_mergingblock(self): torch.Size([1, 4 * 25, 64 * 2])]) -class TestSvtrNet(TestCase): +class TestSVTRNet(TestCase): def setUp(self) -> None: self.img = torch.rand(1, 3, 32, 100) From df90804e39139631cd801468975195bac18eb0da Mon Sep 17 00:00:00 2001 From: gaotongxiao Date: Tue, 27 Dec 2022 16:14:22 +0800 Subject: [PATCH 3/4] fix --- 
mmocr/models/textrecog/backbones/svtr.py | 98 ++++++++++++++---------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/mmocr/models/textrecog/backbones/svtr.py b/mmocr/models/textrecog/backbones/svtr.py index beab610fa..ad6556148 100644 --- a/mmocr/models/textrecog/backbones/svtr.py +++ b/mmocr/models/textrecog/backbones/svtr.py @@ -80,7 +80,8 @@ class ConvMixer(BaseModule): Args: embed_dims (int): Number of character components. num_heads (int, optional): Number of heads. Defaults to 8. - HW (Tuple[int, int], optional): Number of H x W. Defaults to [8, 25]. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. + Defaults to [8, 25]. local_k (Tuple[int, int], optional): Window size. Defaults to [3, 3]. init_cfg (dict or list[dict], optional): Initialization configs. Defaults to None. @@ -89,11 +90,11 @@ class ConvMixer(BaseModule): def __init__(self, embed_dims: int, num_heads: int = 8, - HW: Tuple[int, int] = [8, 25], + input_shape: Tuple[int, int] = [8, 25], local_k: Tuple[int, int] = [3, 3], init_cfg: Optional[Union[Dict, List[Dict]]] = None): super().__init__(init_cfg) - self.HW = HW + self.input_shape = input_shape self.embed_dims = embed_dims self.local_mixer = nn.Conv2d( in_channels=embed_dims, @@ -112,7 +113,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Tensor: A tensor of shape math:`(N, HW, C)`. """ - h, w = self.HW + h, w = self.input_shape x = x.permute(0, 2, 1).reshape([-1, self.embed_dims, h, w]) x = self.local_mixer(x) x = x.flatten(2).permute(0, 2, 1) @@ -126,7 +127,8 @@ class AttnMixer(BaseModule): embed_dims (int): Number of character components. num_heads (int, optional): Number of heads. Defaults to 8. mixer (str, optional): The mixer type. Defaults to 'Global'. - HW (Tuple[int, int], optional): Number of H x W. Defaults to [8, 25]. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. + Defaults to [8, 25]. local_k (Tuple[int, int], optional): Window size. Defaults to [7, 11]. qkv_bias (bool, optional): Whether a additive bias is required. Defaults to False. @@ -141,7 +143,7 @@ def __init__(self, embed_dims: int, num_heads: int = 8, mixer: str = 'Global', - HW: Tuple[int, int] = [8, 25], + input_shape: Tuple[int, int] = [8, 25], local_k: Tuple[int, int] = [7, 11], qkv_bias: bool = False, qk_scale: float = None, @@ -158,22 +160,24 @@ def __init__(self, self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(embed_dims, embed_dims) self.proj_drop = nn.Dropout(proj_drop) - self.HW = HW - if HW is not None: - H, W = HW - self.N = H * W - self.C = embed_dims - if mixer == 'Local' and HW is not None: + self.input_shape = input_shape + if input_shape is not None: + height, width = input_shape + self.input_size = height * width + self.embed_dims = embed_dims + if mixer == 'Local' and input_shape is not None: hk = local_k[0] wk = local_k[1] - mask = torch.ones([H * W, H + hk - 1, W + wk - 1], - dtype=torch.float32) - for h in range(0, H): - for w in range(0, W): - mask[h * W + w, h:h + hk, w:w + wk] = 0. - mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1) + mask = torch.ones( + [height * width, height + hk - 1, width + wk - 1], + dtype=torch.float32) + for h in range(0, height): + for w in range(0, width): + mask[h * width + w, h:h + hk, w:w + wk] = 0. 
+ mask = mask[:, hk // 2:height + hk // 2, + wk // 2:width + wk // 2].flatten(1) mask[mask >= 1] = -np.inf - self.mask = mask[None, None, :, :].to(self.qkv.weight.device) + self.register_buffer('mask', mask[None, None, :, :]) self.mixer = mixer def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -185,13 +189,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: A Tensor of shape :math:`(N, H, W, C)`. """ - if self.HW is not None: - N, C = self.N, self.C + if self.input_shape is not None: + input_size, embed_dims = self.input_size, self.embed_dims else: - _, N, C = x.shape - qkv = self.qkv(x).reshape( - (-1, N, 3, self.num_heads, C // self.num_heads)).permute( - (2, 0, 3, 1, 4)) + _, input_size, embed_dims = x.shape + qkv = self.qkv(x).reshape((-1, input_size, 3, self.num_heads, + embed_dims // self.num_heads)).permute( + (2, 0, 3, 1, 4)) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] attn = q.matmul(k.permute(0, 1, 3, 2)) if self.mixer == 'Local': @@ -199,7 +203,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn = F.softmax(attn, dim=-1) attn = self.attn_drop(attn) - x = attn.matmul(v).permute(0, 2, 1, 3).reshape(-1, N, C) + x = attn.matmul(v).permute(0, 2, 1, 3).reshape(-1, input_size, + embed_dims) x = self.proj(x) x = self.proj_drop(x) return x @@ -259,7 +264,7 @@ class MixingBlock(BaseModule): mixer (str, optional): The mixer type. Defaults to 'Global'. window_size (Tuple[int ,int], optional): Local window size. Defaults to [7, 11]. - HW (Tuple[int, int], optional): The size of [H, W]. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. Defaults to [8, 25]. mlp_ratio (float, optional): The ratio of hidden features to input. Defaults to 4.0. @@ -281,7 +286,7 @@ def __init__(self, num_heads: int, mixer: str = 'Global', window_size: Tuple[int, int] = [7, 11], - HW: Tuple[int, int] = [8, 25], + input_shape: Tuple[int, int] = [8, 25], mlp_ratio: float = 4., qkv_bias: bool = False, qk_scale: float = None, @@ -297,7 +302,7 @@ def __init__(self, embed_dims, num_heads=num_heads, mixer=mixer, - HW=HW, + input_shape=input_shape, local_k=window_size, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -305,7 +310,10 @@ def __init__(self, proj_drop=drop) elif mixer == 'Conv': self.mixer = ConvMixer( - embed_dims, num_heads=num_heads, HW=HW, local_k=window_size) + embed_dims, + num_heads=num_heads, + input_shape=input_shape, + local_k=window_size) else: raise TypeError('The mixer must be one of [Global, Local, Conv]') self.drop_path = DropPath(drop_path) if drop_path > 0. 
else Identity() @@ -481,7 +489,7 @@ def __init__(self, num_layers=num_layers) num_patches = (img_size[1] // (2**num_layers)) * ( img_size[0] // (2**num_layers)) - self.HW = [ + self.input_shape = [ img_size[0] // (2**num_layers), img_size[1] // (2**num_layers) ] self.absolute_pos_embed = nn.Parameter( @@ -496,7 +504,7 @@ def __init__(self, num_heads=num_heads[0], mixer=mixer_types[0:depth[0]][i], window_size=window_size[0], - HW=self.HW, + input_shape=self.input_shape, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -510,7 +518,7 @@ def __init__(self, out_channels=embed_dims[1], types=merging_types, stride=[2, 1]) - HW = [self.HW[0] // 2, self.HW[1]] + input_shape = [self.input_shape[0] // 2, self.input_shape[1]] self.merging_types = merging_types self.blocks2 = nn.ModuleList([ @@ -519,7 +527,7 @@ def __init__(self, num_heads=num_heads[1], mixer=mixer_types[depth[0]:depth[0] + depth[1]][i], window_size=window_size[1], - HW=HW, + input_shape=input_shape, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -533,7 +541,7 @@ def __init__(self, out_channels=embed_dims[2], types=merging_types, stride=[2, 1]) - HW = [self.HW[0] // 4, self.HW[1]] + input_shape = [self.input_shape[0] // 4, self.input_shape[1]] self.blocks3 = nn.ModuleList([ MixingBlock( @@ -541,7 +549,7 @@ def __init__(self, num_heads=num_heads[2], mixer=mixer_types[depth[0] + depth[1]:][i], window_size=window_size[2], - HW=HW, + input_shape=input_shape, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -592,14 +600,18 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: for blk in self.blocks1: x = blk(x) x = self.downsample1( - x.permute(0, 2, 1).reshape( - [-1, self.embed_dims[0], self.HW[0], self.HW[1]])) + x.permute(0, 2, 1).reshape([ + -1, self.embed_dims[0], self.input_shape[0], + self.input_shape[1] + ])) for blk in self.blocks2: x = blk(x) x = self.downsample2( - x.permute(0, 2, 1).reshape( - [-1, self.embed_dims[1], self.HW[0] // 2, self.HW[1]])) + x.permute(0, 2, 1).reshape([ + -1, self.embed_dims[1], self.input_shape[0] // 2, + self.input_shape[1] + ])) for blk in self.blocks3: x = blk(x) @@ -618,8 +630,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ x = self.forward_features(x) x = self.avgpool( - x.permute(0, 2, 1).reshape( - [-1, self.embed_dims[2], self.HW[0] // 4, self.HW[1]])) + x.permute(0, 2, 1).reshape([ + -1, self.embed_dims[2], self.input_shape[0] // 4, + self.input_shape[1] + ])) x = self.last_conv(x) x = self.hardwish(x) x = self.dropout(x) From a62e0c53c11ef9155d5ec1ffe277ee749219e4da Mon Sep 17 00:00:00 2001 From: gaotongxiao Date: Fri, 30 Dec 2022 11:44:46 +0800 Subject: [PATCH 4/4] apply comments, move backbone to encoder --- mmocr/models/textrecog/backbones/__init__.py | 3 +-- mmocr/models/textrecog/encoders/__init__.py | 3 ++- .../svtr.py => encoders/svtr_encoder.py} | 17 +++++------------ .../test_svtr_encoder.py} | 14 ++++++++------ 4 files changed, 16 insertions(+), 21 deletions(-) rename mmocr/models/textrecog/{backbones/svtr.py => encoders/svtr_encoder.py} (98%) rename tests/test_models/test_textrecog/{test_backbones/test_svtr.py => test_encoders/test_svtr_encoder.py} (82%) diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py index 19432e49b..3201de388 100644 --- a/mmocr/models/textrecog/backbones/__init__.py +++ b/mmocr/models/textrecog/backbones/__init__.py @@ -6,9 +6,8 @@ from .resnet31_ocr import ResNet31OCR from .resnet_abi import ResNetABI from .shallow_cnn import ShallowCNN 
-from .svtr import SVTRNet __all__ = [ 'ResNet31OCR', 'MiniVGG', 'NRTRModalityTransform', 'ShallowCNN', - 'ResNetABI', 'ResNet', 'MobileNetV2', 'SVTRNet' + 'ResNetABI', 'ResNet', 'MobileNetV2' ] diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py index 4ef77de0c..69896ba4a 100644 --- a/mmocr/models/textrecog/encoders/__init__.py +++ b/mmocr/models/textrecog/encoders/__init__.py @@ -5,8 +5,9 @@ from .nrtr_encoder import NRTREncoder from .sar_encoder import SAREncoder from .satrn_encoder import SATRNEncoder +from .svtr_encoder import SVTREncoder __all__ = [ 'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder', - 'SATRNEncoder', 'ABIEncoder' + 'SATRNEncoder', 'ABIEncoder', 'SVTREncoder' ] diff --git a/mmocr/models/textrecog/backbones/svtr.py b/mmocr/models/textrecog/encoders/svtr_encoder.py similarity index 98% rename from mmocr/models/textrecog/backbones/svtr.py rename to mmocr/models/textrecog/encoders/svtr_encoder.py index ad6556148..f97550c1b 100644 --- a/mmocr/models/textrecog/backbones/svtr.py +++ b/mmocr/models/textrecog/encoders/svtr_encoder.py @@ -13,15 +13,6 @@ from mmocr.registry import MODELS -class Identity(nn.Module): - - def __init__(self): - super(Identity, self).__init__() - - def forward(self, input): - return input - - class OverlapPatchEmbed(BaseModule): """Image to the progressive overlapping Patch Embedding. @@ -126,7 +117,8 @@ class AttnMixer(BaseModule): Args: embed_dims (int): Number of character components. num_heads (int, optional): Number of heads. Defaults to 8. - mixer (str, optional): The mixer type. Defaults to 'Global'. + mixer (str, optional): The mixer type, choices are 'Global' and + 'Local'. Defaults to 'Global'. input_shape (Tuple[int, int], optional): The shape of input [H, W]. Defaults to [8, 25]. local_k (Tuple[int, int], optional): Window size. Defaults to [7, 11]. @@ -316,7 +308,8 @@ def __init__(self, local_k=window_size) else: raise TypeError('The mixer must be one of [Global, Local, Conv]') - self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = nn.LayerNorm(embed_dims, eps=1e-6) mlp_hidden_dim = int(embed_dims * mlp_ratio) self.mlp_ratio = mlp_ratio @@ -409,7 +402,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @MODELS.register_module() -class SVTRNet(BaseModule): +class SVTREncoder(BaseModule): """A PyTorch implementation of `SVTR: Scene Text Recognition with a Single Visual Model `_ diff --git a/tests/test_models/test_textrecog/test_backbones/test_svtr.py b/tests/test_models/test_textrecog/test_encoders/test_svtr_encoder.py similarity index 82% rename from tests/test_models/test_textrecog/test_backbones/test_svtr.py rename to tests/test_models/test_textrecog/test_encoders/test_svtr_encoder.py index cb4d1b650..8dab69459 100644 --- a/tests/test_models/test_textrecog/test_backbones/test_svtr.py +++ b/tests/test_models/test_textrecog/test_encoders/test_svtr_encoder.py @@ -3,9 +3,11 @@ import torch -from mmocr.models.textrecog.backbones.svtr import (AttnMixer, ConvMixer, - MerigingBlock, MixingBlock, - OverlapPatchEmbed, SVTRNet) +from mmocr.models.textrecog.encoders.svtr_encoder import (AttnMixer, ConvMixer, + MerigingBlock, + MixingBlock, + OverlapPatchEmbed, + SVTREncoder) class TestOverlapPatchEmbed(TestCase): @@ -69,13 +71,13 @@ def test_mergingblock(self): torch.Size([1, 4 * 25, 64 * 2])]) -class TestSVTRNet(TestCase): +class TestSVTREncoder(TestCase): def setUp(self) -> None: self.img = torch.rand(1, 3, 32, 100) - def test_svtrnet(self): - model = SVTRNet( + def test_svtr_encoder(self): + model = SVTREncoder( img_size=self.img.shape[-2:], in_channels=self.img.shape[1], )
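
Usage note (not part of the patch series): a minimal sketch of how the encoder added by PATCH 4/4 can be exercised, mirroring the shapes asserted in tests/test_models/test_textrecog/test_encoders/test_svtr_encoder.py. The registry-style config dict at the end is an assumed usage pattern inferred from the @MODELS.register_module() decorator; no config file is included in this series.

    # Minimal sketch, assuming the final state of this patch series is installed.
    import torch

    from mmocr.models.textrecog.encoders.svtr_encoder import SVTREncoder

    # Defaults: embed_dims=[64, 128, 256], out_channels=192, max_seq_len=25.
    encoder = SVTREncoder(img_size=(32, 100), in_channels=3)
    encoder.eval()

    x = torch.rand(1, 3, 32, 100)  # (N, C, H, W) text-line image
    with torch.no_grad():
        feats = encoder(x)
    print(feats.shape)  # torch.Size([1, 192, 1, 25]), as asserted in the unit test

    # Assumed registry usage: since SVTREncoder is registered via
    # @MODELS.register_module(), a recognizer config would typically refer to it
    # by type name rather than importing the class directly.
    encoder_cfg = dict(type='SVTREncoder', img_size=(32, 100), in_channels=3)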