Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeCamp #117 [UT] Add missing unit tests #1651

Merged
merged 2 commits into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions mmocr/models/textdet/necks/fpem_ffm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.model import BaseModule, ModuleList
from torch import nn
Expand All @@ -14,7 +17,9 @@ class FPEM(BaseModule):
init_cfg (dict or list[dict], optional): Initialization configs.
"""

def __init__(self, in_channels=128, init_cfg=None):
def __init__(self,
in_channels: int = 128,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
Expand All @@ -23,7 +28,8 @@ def __init__(self, in_channels=128, init_cfg=None):
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)

def forward(self, c2, c3, c4, c5):
def forward(self, c2: torch.Tensor, c3: torch.Tensor, c4: torch.Tensor,
c5: torch.Tensor) -> List[torch.Tensor]:
"""
Args:
c2, c3, c4, c5 (Tensor): Each has the shape of
Expand All @@ -48,8 +54,21 @@ def _upsample_add(self, x, y):


class SeparableConv2d(BaseModule):
"""Implementation of separable convolution, which is consisted of depthwise
convolution and pointwise convolution.

Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
stride (int): Stride of the depthwise convolution.
init_cfg (dict or list[dict], optional): Initialization configs.
"""

def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
def __init__(self,
in_channels: int,
out_channels: int,
stride: int = 1,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)

self.depthwise_conv = nn.Conv2d(
Expand All @@ -64,7 +83,15 @@ def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.

Args:
x (Tensor): Input tensor.

Returns:
Tensor: Output tensor.
"""
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
x = self.bn(x)
Expand All @@ -85,13 +112,15 @@ class FPEM_FFM(BaseModule):
init_cfg (dict or list[dict], optional): Initialization configs.
"""

def __init__(self,
in_channels,
conv_out=128,
fpem_repeat=2,
align_corners=False,
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
def __init__(
self,
in_channels: List[int],
conv_out: int = 128,
fpem_repeat: int = 2,
align_corners: bool = False,
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Xavier', layer='Conv2d', distribution='uniform')
) -> None:
super().__init__(init_cfg=init_cfg)
# reduce layers
self.reduce_conv_c2 = nn.Sequential(
Expand Down Expand Up @@ -119,7 +148,7 @@ def __init__(self,
for _ in range(fpem_repeat):
self.fpems.append(FPEM(conv_out))

def forward(self, x):
def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]:
"""
Args:
x (list[Tensor]): A list of four tensors of shape
Expand All @@ -128,7 +157,7 @@ def forward(self, x):
``in_channels``.

Returns:
list[Tensor]: Four tensors of shape
tuple[Tensor]: Four tensors of shape
:math:`(N, C_{out}, H_0, W_0)` where :math:`C_{out}` is
``conv_out``.
"""
Expand Down
31 changes: 16 additions & 15 deletions mmocr/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
bezier2polygon, is_on_same_line, rescale_bboxes,
stitch_boxes_into_lines)
bezier2polygon, is_on_same_line, rescale_bbox,
rescale_bboxes, stitch_boxes_into_lines)
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
is_type_list, valid_boundary)
from .collect_env import collect_env
Expand Down Expand Up @@ -34,17 +34,18 @@
'is_2dlist', 'valid_boundary', 'list_to_file', 'list_from_file',
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStripper',
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
'ColorType', 'OptKIESampleList', 'KIESampleList', 'is_archive',
'check_integrity', 'list_files', 'get_md5', 'InstanceList', 'LabelList',
'OptInstanceList', 'OptLabelList', 'RangeType', 'remove_pipeline_elements'
'rescale_polygons', 'rescale_polygon', 'rescale_bbox', 'rescale_bboxes',
'bbox2poly', 'crop_polygon', 'is_poly_inside_rect', 'poly2bbox',
'poly_intersection', 'poly_iou', 'poly_make_valid', 'poly_union',
'poly2shapely', 'polys2shapely', 'register_all_modules', 'offset_polygon',
'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
'bbox_diag_distance', 'boundary_iou', 'point_distance', 'points_center',
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
'warp_img', 'ConfigType', 'DetSampleList', 'RecForwardResults',
'InitConfigType', 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType',
'OptMultiConfig', 'OptRecSampleList', 'RecSampleList', 'MultiConfig',
'OptTensor', 'ColorType', 'OptKIESampleList', 'KIESampleList',
'is_archive', 'check_integrity', 'list_files', 'get_md5', 'InstanceList',
'LabelList', 'OptInstanceList', 'OptLabelList', 'RangeType',
'remove_pipeline_elements'
]
46 changes: 46 additions & 0 deletions tests/test_models/test_textdet/test_necks/test_fpem_ffm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch

from mmocr.models.textdet.necks.fpem_ffm import FPEM, FPEM_FFM


class TestFPEM(unittest.TestCase):

def setUp(self):
self.c2 = torch.Tensor(1, 8, 64, 64)
self.c3 = torch.Tensor(1, 8, 32, 32)
self.c4 = torch.Tensor(1, 8, 16, 16)
self.c5 = torch.Tensor(1, 8, 8, 8)
self.fpem = FPEM(in_channels=8)

def test_forward(self):
neck = FPEM(in_channels=8)
neck.init_weights()
out = neck(self.c2, self.c3, self.c4, self.c5)
self.assertTrue(out[0].shape == self.c2.shape)
self.assertTrue(out[1].shape == self.c3.shape)
self.assertTrue(out[2].shape == self.c4.shape)
self.assertTrue(out[3].shape == self.c5.shape)


class TestFPEM_FFM(unittest.TestCase):

def setUp(self):
self.c2 = torch.Tensor(1, 8, 64, 64)
self.c3 = torch.Tensor(1, 16, 32, 32)
self.c4 = torch.Tensor(1, 32, 16, 16)
self.c5 = torch.Tensor(1, 64, 8, 8)
self.in_channels = [8, 16, 32, 64]
self.conv_out = 8
self.features = [self.c2, self.c3, self.c4, self.c5]

def test_forward(self):
neck = FPEM_FFM(in_channels=self.in_channels, conv_out=self.conv_out)
neck.init_weights()
out = neck(self.features)
self.assertTrue(out[0].shape == torch.Size([1, 8, 64, 64]))
self.assertTrue(out[1].shape == out[0].shape)
self.assertTrue(out[2].shape == out[0].shape)
self.assertTrue(out[3].shape == out[0].shape)
32 changes: 30 additions & 2 deletions tests/test_utils/test_bbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch

from mmocr.utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
bezier2polygon, is_on_same_line,
stitch_boxes_into_lines)
bezier2polygon, is_on_same_line, rescale_bbox,
rescale_bboxes, stitch_boxes_into_lines)
from mmocr.utils.bbox_utils import bbox_jitter


Expand Down Expand Up @@ -236,3 +236,31 @@ def test_stitch_boxes_into_lines(self):
result.sort(key=lambda x: x['box'][0])
expected_result.sort(key=lambda x: x['box'][0])
self.assertEqual(result, expected_result)


class TestRescaleBbox(unittest.TestCase):

def setUp(self) -> None:
self.bbox = np.array([0, 0, 1, 1])
self.bboxes = np.array([[0, 0, 1, 1], [1, 1, 2, 2]])
self.scale = 2

def test_rescale_bbox(self):
# mul
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='mul')
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 2, 2])))
# div
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='div')
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 0.5, 0.5])))

def test_rescale_bboxes(self):
# mul
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='mul')
self.assertTrue(
np.allclose(rescaled_bboxes, np.array([[0, 0, 2, 2], [2, 2, 4,
4]])))
# div
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='div')
self.assertTrue(
np.allclose(rescaled_bboxes,
np.array([[0, 0, 0.5, 0.5], [0.5, 0.5, 1, 1]])))
6 changes: 6 additions & 0 deletions tests/test_utils/test_check_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ def test_valid_boundary():
assert utils.valid_boundary(x, False)
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
assert utils.valid_boundary(x, True)


def test_equal_len():

assert utils.equal_len([1, 2, 3], [1, 2, 3])
assert not utils.equal_len([1, 2, 3], [1, 2, 3, 4])