diff --git a/tests/test_tools/test_data_loader.py b/tests/test_tools/test_data_loader.py index d95d387..f40aa59 100644 --- a/tests/test_tools/test_data_loader.py +++ b/tests/test_tools/test_data_loader.py @@ -42,15 +42,16 @@ def test_training_data_loader_correctness(train_dataloader: DataLoader): def test_validation_data_loader_correctness(validation_dataloader: DataLoader): batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader)) - assert batch_size == 4 - assert images.shape == (4, 3, 512, 768) - assert targets.shape == (4, 18, 5) - assert reverse_tensors.shape == (4, 5) + assert batch_size == 5 + assert images.shape == (5, 3, 640, 640) + assert targets.shape == (5, 18, 5) + assert reverse_tensors.shape == (5, 5) expected_paths = [ - Path("tests/data/images/val/000000284106.jpg"), Path("tests/data/images/val/000000151480.jpg"), - Path("tests/data/images/val/000000570456.jpg"), + Path("tests/data/images/val/000000284106.jpg"), Path("tests/data/images/val/000000323571.jpg"), + Path("tests/data/images/val/000000556498.jpg"), + Path("tests/data/images/val/000000570456.jpg"), ] assert list(image_paths) == list(expected_paths) diff --git a/tests/test_tools/test_loss_functions.py b/tests/test_tools/test_loss_functions.py index f011d86..450361b 100644 --- a/tests/test_tools/test_loss_functions.py +++ b/tests/test_tools/test_loss_functions.py @@ -51,7 +51,6 @@ def data(): def test_yolo_loss(loss_function, data): predicts, targets = data loss, loss_dict = loss_function(predicts, predicts, targets) - assert torch.isnan(loss) - assert isnan(loss_dict["Loss/BoxLoss"]) - assert isnan(loss_dict["Loss/DFLLoss"]) - assert isinf(loss_dict["Loss/BCELoss"]) + assert loss_dict["Loss/BoxLoss"] == 0 + assert loss_dict["Loss/DFLLoss"] == 0 + assert loss_dict["Loss/BCELoss"] >= 2e5