From a80fd8cc06b3100e6e0a0c422689208ad17ae139 Mon Sep 17 00:00:00 2001 From: henrytsui000 Date: Thu, 21 Nov 2024 15:38:09 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20[Pass]=20Test,=20mock=20dataset=20a?= =?UTF-8?q?re=205=20images?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_tools/test_data_loader.py | 13 +++++++------ tests/test_tools/test_loss_functions.py | 7 +++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_tools/test_data_loader.py b/tests/test_tools/test_data_loader.py index d95d387..f40aa59 100644 --- a/tests/test_tools/test_data_loader.py +++ b/tests/test_tools/test_data_loader.py @@ -42,15 +42,16 @@ def test_training_data_loader_correctness(train_dataloader: DataLoader): def test_validation_data_loader_correctness(validation_dataloader: DataLoader): batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader)) - assert batch_size == 4 - assert images.shape == (4, 3, 512, 768) - assert targets.shape == (4, 18, 5) - assert reverse_tensors.shape == (4, 5) + assert batch_size == 5 + assert images.shape == (5, 3, 640, 640) + assert targets.shape == (5, 18, 5) + assert reverse_tensors.shape == (5, 5) expected_paths = [ - Path("tests/data/images/val/000000284106.jpg"), Path("tests/data/images/val/000000151480.jpg"), - Path("tests/data/images/val/000000570456.jpg"), + Path("tests/data/images/val/000000284106.jpg"), Path("tests/data/images/val/000000323571.jpg"), + Path("tests/data/images/val/000000556498.jpg"), + Path("tests/data/images/val/000000570456.jpg"), ] assert list(image_paths) == list(expected_paths) diff --git a/tests/test_tools/test_loss_functions.py b/tests/test_tools/test_loss_functions.py index f011d86..450361b 100644 --- a/tests/test_tools/test_loss_functions.py +++ b/tests/test_tools/test_loss_functions.py @@ -51,7 +51,6 @@ def data(): def test_yolo_loss(loss_function, data): predicts, targets = data loss, loss_dict = loss_function(predicts, predicts, targets) - assert torch.isnan(loss) - assert isnan(loss_dict["Loss/BoxLoss"]) - assert isnan(loss_dict["Loss/DFLLoss"]) - assert isinf(loss_dict["Loss/BCELoss"]) + assert loss_dict["Loss/BoxLoss"] == 0 + assert loss_dict["Loss/DFLLoss"] == 0 + assert loss_dict["Loss/BCELoss"] >= 2e5