From a0cddf71823bb043a2ca1989402d322ff7fe3391 Mon Sep 17 00:00:00 2001 From: henrytsui000 Date: Mon, 3 Jun 2024 00:38:04 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20[Pass]=20test=20on=20new=20framewor?= =?UTF-8?q?k=20and=20file=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_tools/test_module_helper.py | 6 +++--- tests/test_utils/test_dataaugment.py | 9 +++++++-- tests/test_utils/test_loss.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_tools/test_module_helper.py b/tests/test_tools/test_module_helper.py index bfa087e..26dba64 100644 --- a/tests/test_tools/test_module_helper.py +++ b/tests/test_tools/test_module_helper.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent.parent sys.path.append(str(project_root)) -from yolo.tools.module_helper import auto_pad, get_activation +from yolo.utils.module_utils import auto_pad, create_activation_function @pytest.mark.parametrize( @@ -29,10 +29,10 @@ def test_auto_pad(kernel_size, dilation, expected): [("ReLU", nn.ReLU), ("leakyrelu", nn.LeakyReLU), ("none", nn.Identity), (None, nn.Identity), (False, nn.Identity)], ) def test_get_activation(activation_name, expected_type): - result = get_activation(activation_name) + result = create_activation_function(activation_name) assert isinstance(result, expected_type), f"get_activation does not return correct type for {activation_name}" def test_get_activation_invalid(): with pytest.raises(ValueError): - get_activation("unsupported_activation") + create_activation_function("unsupported_activation") diff --git a/tests/test_utils/test_dataaugment.py b/tests/test_utils/test_dataaugment.py index aacb00b..220c826 100644 --- a/tests/test_utils/test_dataaugment.py +++ b/tests/test_utils/test_dataaugment.py @@ -9,7 +9,12 @@ project_root = Path(__file__).resolve().parent.parent.parent sys.path.append(str(project_root)) -from yolo.utils.data_augmentation import Compose, HorizontalFlip, Mosaic, VerticalFlip +from yolo.tools.data_augmentation import ( + AugmentationComposer, + HorizontalFlip, + Mosaic, + VerticalFlip, +) def test_horizontal_flip(): @@ -33,7 +38,7 @@ def test_compose(): def mock_transform(image, boxes): return image, boxes - compose = Compose([mock_transform, mock_transform]) + compose = AugmentationComposer([mock_transform, mock_transform]) img = Image.new("RGB", (10, 10), color="blue") boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]]) diff --git a/tests/test_utils/test_loss.py b/tests/test_utils/test_loss.py index f3853e7..6efb17e 100644 --- a/tests/test_utils/test_loss.py +++ b/tests/test_utils/test_loss.py @@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent.parent sys.path.append(str(project_root)) -from yolo.utils.loss_functions import YOLOLoss +from yolo.tools.loss_functions import YOLOLoss @pytest.fixture