From b86d98c95274b30bb6a355f622af2088da2a3345 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 22 Oct 2021 17:22:09 +0200 Subject: [PATCH 001/108] :art: added types --- src/datamodules/util/transformations/functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/datamodules/util/transformations/functional.py b/src/datamodules/util/transformations/functional.py index b17df76e..bf9e24de 100644 --- a/src/datamodules/util/transformations/functional.py +++ b/src/datamodules/util/transformations/functional.py @@ -1,10 +1,12 @@ +from typing import List + import numpy as np import torch from sklearn.preprocessing import OneHotEncoder -def gt_to_one_hot(matrix, class_encodings): +def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): """ Convert ground truth tensor or numpy matrix to one-hot encoded matrix @@ -48,7 +50,7 @@ def gt_to_one_hot(matrix, class_encodings): return torch.LongTensor(one_hot_matrix.transpose((2, 0, 1))) -def argmax_onehot(tensor): +def argmax_onehot(tensor: torch.Tensor): """ # TODO """ From c112dd82c03c0254bea127e185d3c8b96947eb2e Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 22 Oct 2021 18:19:24 +0200 Subject: [PATCH 002/108] :white_check_mark: added some more tests --- tests/tasks/utils/test_functional.py | 35 ++++++++++++++ tests/utils/__init__.py | 0 tests/utils/test_utils.py | 69 ++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 tests/tasks/utils/test_functional.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_utils.py diff --git a/tests/tasks/utils/test_functional.py b/tests/tasks/utils/test_functional.py new file mode 100644 index 00000000..10bbb596 --- /dev/null +++ b/tests/tasks/utils/test_functional.py @@ -0,0 +1,35 @@ +import pytest +import torch +from _pytest.fixtures import fixture + +from src.datamodules.util.transformations.functional import gt_to_one_hot + + +@fixture +def get_class_encodings(): + return [1, 2] + + +@fixture +def get_input_tensor(): + return torch.tensor( + [[[0.01, 0.1], [0.001, 0.01], [0.01, 0.1]], [[0.01, 0.1], [0.01, 0.1], [3.01, 0.1]], + [[0.01, 0.1], [0.01, 0.1], [3.01, 0.1]]], + dtype=torch.float) + + +def test_gt_to_one_hot_work(get_input_tensor, get_class_encodings): + result = gt_to_one_hot(get_input_tensor, get_class_encodings) + assert torch.equal(result, torch.tensor([[[1, 1], + [0, 1], + [1, 1]], + [[0, 0], + [1, 0], + [0, 0]]])) + + +def test_gt_to_one_hot_crash(get_input_tensor): + class_encodings = [1] + with pytest.raises(KeyError): + gt_to_one_hot(get_input_tensor, class_encodings) + diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py new file mode 100644 index 00000000..25fe6156 --- /dev/null +++ b/tests/utils/test_utils.py @@ -0,0 +1,69 @@ +import pytest +from _pytest.fixtures import fixture +from omegaconf import DictConfig + +from src.utils.utils import _check_if_in_config, REQUIRED_CONFIGS, check_config + + +@fixture +def get_dict(): + return DictConfig({'plugins': { + 'ddp_plugin': {'_target_': 'pytorch_lightning.plugins.DDPPlugin', 'find_unused_parameters': False}}, 'task': { + '_target_': 'src.tasks.semantic_segmentation.semantic_segmentation.SemanticSegmentation', + 'confusion_matrix_log_every_n_epoch': 1, 'confusion_matrix_val': True, 'confusion_matrix_test': True}, + 'loss': {'_target_': 'torch.nn.CrossEntropyLoss'}, + 'metric': {'_target_': 
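            # Hydra instantiates every block below from its '_target_' dotted path; values such as
            # '${datamodule:num_classes}' are interpolations filled in at runtime by a custom
            # OmegaConf resolver named 'datamodule' (assumed to be registered in src.utils.utils,
            # which is not shown in these patches).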
'src.metrics.divahisdb.HisDBIoU', + 'num_classes': '${datamodule:num_classes}'}, 'model': { + 'backbone': {'_target_': 'pl_bolts.models.vision.UNet', 'num_classes': '${datamodule:num_classes}', + 'num_layers': 2, 'features_start': 32}, 'header': {'_target_': 'torch.nn.Identity'}}, + 'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001, 'betas': [0.9, 0.999], 'eps': 1e-08, + 'weight_decay': 0, 'amsgrad': False}, 'callbacks': { + 'check_backbone_header_compatibility': { + '_target_': 'src.callbacks.model_callbacks.CheckBackboneHeaderCompatibility'}, + 'model_checkpoint': {'_target_': 'src.callbacks.model_callbacks.SaveModelStateDictAndTaskCheckpoint', + 'monitor': 'val/crossentropyloss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', + 'verbose': False, 'dirpath': 'checkpoints/', + 'filename': '${checkpoint_folder_name}dev-baby-unet-cb55-10', + 'backbone_filename': '${checkpoint_folder_name}backbone', + 'header_filename': '${checkpoint_folder_name}header'}, + 'watch_model': {'_target_': 'src.callbacks.wandb_callbacks.WatchModelWithWandb', 'log': 'all', + 'log_freq': 1}}, 'logger': { + 'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'unsupervised', + 'name': 'dev-baby-unet-cb55-10', 'offline': False, 'job_type': 'train', 'group': 'dev-runs', + 'tags': ['best_model', 'USL'], 'save_dir': '.', 'log_model': False, 'notes': 'Testing'}, + 'csv': {'_target_': 'pytorch_lightning.loggers.csv_logs.CSVLogger', 'save_dir': '.', 'name': 'csv/'}}, + 'seed': 42, 'train': True, 'test': True, + 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'gpus': -1, 'accelerator': 'ddp', + 'min_epochs': 1, 'max_epochs': 3, 'weights_summary': 'full', 'precision': 16}, + 'datamodule': { + '_target_': 'src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped', + 'data_dir': '/netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation', + 'crop_size': 256, 'num_workers': 4, 'batch_size': 16, 'shuffle': True, 'drop_last': True}, + 'save_config': True, 'checkpoint_folder_name': '{epoch}/', 'work_dir': '${hydra:runtime.cwd}', + 'debug': False, 'print_config': True, 'disable_warnings': True}) + + +def test_check_config_everything_good(get_dict): + check_config(get_dict) + + +def test_check_config_no_seed(get_dict): + del get_dict['seed'] + check_config(get_dict) + + +def test_check_config_no_plugins(get_dict): + del get_dict['plugins'] + check_config(get_dict) + + +def test__check_if_in_config_good_config(get_dict): + for cf in REQUIRED_CONFIGS: + _check_if_in_config(config=get_dict, name=cf) + + +def test__check_if_in_config_bad_config(get_dict): + del get_dict['datamodule'] + with pytest.raises(ValueError): + for cf in REQUIRED_CONFIGS: + _check_if_in_config(config=get_dict, name=cf) From 55e50d3bffe8ac04ea6b03bfa761fdb19f03151e Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Sat, 23 Oct 2021 12:28:14 +0200 Subject: [PATCH 003/108] :white_check_mark: better testing of the cropped dataset --- tests/datasets/test_cropped_hisdb_dataset.py | 140 +++++++++++-------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/tests/datasets/test_cropped_hisdb_dataset.py b/tests/datasets/test_cropped_hisdb_dataset.py index 9f3816cf..b95c57ed 100644 --- a/tests/datasets/test_cropped_hisdb_dataset.py +++ b/tests/datasets/test_cropped_hisdb_dataset.py @@ -1,5 +1,6 @@ from pathlib import PosixPath +import pytest import torch from pytest import fixture @@ -50,70 +51,21 @@ def test__get_train_val_items_test(dataset_test): assert index == 0 -def 
test_get_gt_data_paths_train(data_dir_cropped): +def test_dataset_train_selection_int_error(data_dir_cropped): + with pytest.raises(ValueError): + CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', selection=2) + + +def test_dataset_train_selection_int(data_dir_cropped, get_train_file_names): + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', selection=1) + assert len(files_from_method) == 12 + assert files_from_method == get_train_file_names + + +def test_get_gt_data_paths_train(data_dir_cropped, get_train_file_names): files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train') - expected_result = [(PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', (0, 0)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0150', (0, 150)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0187', (0, 187)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0000', (150, 0)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0150', (150, 150)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0187', (150, 187)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0000', (300, 0)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0150', (300, 150)), ( - PosixPath( 
- data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0187', (300, 187)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0000', (349, 0)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0150', (349, 150)), ( - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0187', (349, 187))] assert len(files_from_method) == 12 - assert files_from_method == expected_result + assert files_from_method == get_train_file_names def test_get_gt_data_paths_val(data_dir_cropped): @@ -261,3 +213,67 @@ def test_get_gt_data_paths_test(data_dir_cropped): 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0393_y0231', (393, 231))] assert len(files_from_method) == 15 assert files_from_method == expected_result + + +@fixture +def get_train_file_names(data_dir_cropped): + return [(PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', (0, 0)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0150', (0, 150)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0187', (0, 187)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0000', (150, 0)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + PosixPath( + data_dir_cropped / 
'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0150', (150, 150)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0187', (150, 187)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0000', (300, 0)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0150', (300, 150)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0187', (300, 187)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0000', (349, 0)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0150', (349, 150)), ( + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + PosixPath( + data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0187', (349, 187))] From 6fd5e9b702e3e3af6c276c1c2000d42785d5ce49 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 25 Oct 2021 14:40:14 +0200 Subject: [PATCH 004/108] :recycle: refactor datamodule files --- configs/datamodule/cb55_10_cropped_datamodule.yaml | 2 +- configs/datamodule/cb55_cropped_datamodule.yaml | 2 +- configs/experiment/cb55_full_run_unet.yaml | 2 +- .../experiment/development_baby_unet_cb55_10.yaml | 2 +- configs/task/semantic_segmentation_task.yaml | 3 ++- src/datamodules/{datasets => DivaHisDB}/__init__.py | 0 .../datamodule_cropped.py} | 12 ++++++------ .../datasets}/__init__.py | 0 .../datasets/cropped_hisdb_dataset.py | 2 +- .../{util => DivaHisDB/utils}/__init__.py | 0 .../utils}/functional.py | 0 .../analytics => DivaHisDB/utils}/image_analytics.py | 0 src/datamodules/{util => DivaHisDB/utils}/misc.py | 2 +- .../DivaHisDB}/utils/output_tools.py | 1 - 
.../utils}/single_transforms.py | 0 .../utils}/twin_transforms.py | 2 +- .../utils}/wrapper_transforms.py | 0 .../{util/analytics => utils}/__init__.py | 0 src/datamodules/{util => utils}/exceptions.py | 0 .../transformations => tasks/DivaHisDB}/__init__.py | 0 .../semantic_segmentation.py | 2 +- src/tasks/semantic_segmentation/utils/__init__.py | 0 .../datamodules/DivaHisDB}/__init__.py | 0 .../test_hisDBDataModule.py | 4 ++-- .../test_image_analytics.py | 4 ++-- .../{hisDBDataModule => DivaHisDB}/test_misc.py | 4 ++-- tests/datamodules/hisDBDataModule/__init__.py | 0 tests/datasets/test_cropped_hisdb_dataset.py | 2 +- tests/tasks/sem_seg/test_output_tools.py | 2 +- tests/tasks/sem_seg/test_semantic_segmentation.py | 7 +++---- tests/tasks/utils/test_functional.py | 2 +- tools/merge_cropped_output.py | 8 ++++---- 32 files changed, 32 insertions(+), 33 deletions(-) rename src/datamodules/{datasets => DivaHisDB}/__init__.py (100%) rename src/datamodules/{hisDBDataModule/DIVAHisDBDataModule.py => DivaHisDB/datamodule_cropped.py} (92%) rename src/datamodules/{hisDBDataModule => DivaHisDB/datasets}/__init__.py (100%) rename src/datamodules/{ => DivaHisDB}/datasets/cropped_hisdb_dataset.py (99%) rename src/datamodules/{util => DivaHisDB/utils}/__init__.py (100%) rename src/datamodules/{util/transformations => DivaHisDB/utils}/functional.py (100%) rename src/datamodules/{util/analytics => DivaHisDB/utils}/image_analytics.py (100%) rename src/datamodules/{util => DivaHisDB/utils}/misc.py (95%) rename src/{tasks/semantic_segmentation => datamodules/DivaHisDB}/utils/output_tools.py (99%) rename src/datamodules/{util/transformations => DivaHisDB/utils}/single_transforms.py (100%) rename src/datamodules/{util/transformations => DivaHisDB/utils}/twin_transforms.py (97%) rename src/datamodules/{util/transformations => DivaHisDB/utils}/wrapper_transforms.py (100%) rename src/datamodules/{util/analytics => utils}/__init__.py (100%) rename src/datamodules/{util => utils}/exceptions.py (100%) rename src/{datamodules/util/transformations => tasks/DivaHisDB}/__init__.py (100%) rename src/tasks/{semantic_segmentation => DivaHisDB}/semantic_segmentation.py (98%) delete mode 100644 src/tasks/semantic_segmentation/utils/__init__.py rename {src/tasks/semantic_segmentation => tests/datamodules/DivaHisDB}/__init__.py (100%) rename tests/datamodules/{hisDBDataModule => DivaHisDB}/test_hisDBDataModule.py (93%) rename tests/datamodules/{hisDBDataModule => DivaHisDB}/test_image_analytics.py (89%) rename tests/datamodules/{hisDBDataModule => DivaHisDB}/test_misc.py (89%) delete mode 100644 tests/datamodules/hisDBDataModule/__init__.py diff --git a/configs/datamodule/cb55_10_cropped_datamodule.yaml b/configs/datamodule/cb55_10_cropped_datamodule.yaml index 231ee851..44e7d090 100644 --- a/configs/datamodule/cb55_10_cropped_datamodule.yaml +++ b/configs/datamodule/cb55_10_cropped_datamodule.yaml @@ -1,4 +1,4 @@ -_target_: src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped +_target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation diff --git a/configs/datamodule/cb55_cropped_datamodule.yaml b/configs/datamodule/cb55_cropped_datamodule.yaml index f3b58ad1..fe0f1dcc 100644 --- a/configs/datamodule/cb55_cropped_datamodule.yaml +++ b/configs/datamodule/cb55_cropped_datamodule.yaml @@ -1,4 +1,4 @@ -_target_: src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped 
+_target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 diff --git a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index 8fc00033..bab651f9 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ b/configs/experiment/cb55_full_run_unet.yaml @@ -39,7 +39,7 @@ task: confusion_matrix_test: True datamodule: - _target_: src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped + _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index 00fc12b5..16c31c3a 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -43,7 +43,7 @@ task: confusion_matrix_test: True datamodule: - _target_: src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped + _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 diff --git a/configs/task/semantic_segmentation_task.yaml b/configs/task/semantic_segmentation_task.yaml index e5ff43cf..20f5560e 100644 --- a/configs/task/semantic_segmentation_task.yaml +++ b/configs/task/semantic_segmentation_task.yaml @@ -1 +1,2 @@ -_target_: src.tasks.semantic_segmentation.semantic_segmentation.SemanticSegmentation +_target_: src.tasks.DivaHisDB.semantic_segmentation.SemanticSegmentation + diff --git a/src/datamodules/datasets/__init__.py b/src/datamodules/DivaHisDB/__init__.py similarity index 100% rename from src/datamodules/datasets/__init__.py rename to src/datamodules/DivaHisDB/__init__.py diff --git a/src/datamodules/hisDBDataModule/DIVAHisDBDataModule.py b/src/datamodules/DivaHisDB/datamodule_cropped.py similarity index 92% rename from src/datamodules/hisDBDataModule/DIVAHisDBDataModule.py rename to src/datamodules/DivaHisDB/datamodule_cropped.py index 2d1f3215..22be939b 100644 --- a/src/datamodules/hisDBDataModule/DIVAHisDBDataModule.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -5,17 +5,17 @@ from torchvision import transforms from src.datamodules.base_datamodule import AbstractDatamodule -from src.datamodules.datasets.cropped_hisdb_dataset import CroppedHisDBDataset -from src.datamodules.util.analytics.image_analytics import get_analytics -from src.datamodules.util.misc import validate_path_for_segmentation -from src.datamodules.util.transformations.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling -from src.datamodules.util.transformations.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics +from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation +from src.datamodules.DivaHisDB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling +from src.datamodules.DivaHisDB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils log = utils.get_logger(__name__) -class DIVAHisDBDataModuleCropped(AbstractDatamodule): +class 
DivaHisDBDataModuleCropped(AbstractDatamodule): def __init__(self, data_dir: str = None, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, diff --git a/src/datamodules/hisDBDataModule/__init__.py b/src/datamodules/DivaHisDB/datasets/__init__.py similarity index 100% rename from src/datamodules/hisDBDataModule/__init__.py rename to src/datamodules/DivaHisDB/datasets/__init__.py diff --git a/src/datamodules/datasets/cropped_hisdb_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py similarity index 99% rename from src/datamodules/datasets/cropped_hisdb_dataset.py rename to src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py index 3965859a..dca49977 100644 --- a/src/datamodules/datasets/cropped_hisdb_dataset.py +++ b/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py @@ -12,7 +12,7 @@ from torch import is_tensor from torchvision.transforms import ToTensor -from src.datamodules.util.misc import has_extension, pil_loader +from src.datamodules.DivaHisDB.utils.misc import has_extension, pil_loader from src.utils import utils IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] diff --git a/src/datamodules/util/__init__.py b/src/datamodules/DivaHisDB/utils/__init__.py similarity index 100% rename from src/datamodules/util/__init__.py rename to src/datamodules/DivaHisDB/utils/__init__.py diff --git a/src/datamodules/util/transformations/functional.py b/src/datamodules/DivaHisDB/utils/functional.py similarity index 100% rename from src/datamodules/util/transformations/functional.py rename to src/datamodules/DivaHisDB/utils/functional.py diff --git a/src/datamodules/util/analytics/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py similarity index 100% rename from src/datamodules/util/analytics/image_analytics.py rename to src/datamodules/DivaHisDB/utils/image_analytics.py diff --git a/src/datamodules/util/misc.py b/src/datamodules/DivaHisDB/utils/misc.py similarity index 95% rename from src/datamodules/util/misc.py rename to src/datamodules/DivaHisDB/utils/misc.py index 7138d6d9..998b8578 100644 --- a/src/datamodules/util/misc.py +++ b/src/datamodules/DivaHisDB/utils/misc.py @@ -9,7 +9,7 @@ import numpy as np from PIL import Image -from src.datamodules.util.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir +from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir try: import accimage diff --git a/src/tasks/semantic_segmentation/utils/output_tools.py b/src/datamodules/DivaHisDB/utils/output_tools.py similarity index 99% rename from src/tasks/semantic_segmentation/utils/output_tools.py rename to src/datamodules/DivaHisDB/utils/output_tools.py index ef7385e1..06fe4441 100644 --- a/src/tasks/semantic_segmentation/utils/output_tools.py +++ b/src/datamodules/DivaHisDB/utils/output_tools.py @@ -1,7 +1,6 @@ from pathlib import Path from typing import Union -import numpy import numpy as np import torch from PIL import Image diff --git a/src/datamodules/util/transformations/single_transforms.py b/src/datamodules/DivaHisDB/utils/single_transforms.py similarity index 100% rename from src/datamodules/util/transformations/single_transforms.py rename to src/datamodules/DivaHisDB/utils/single_transforms.py diff --git a/src/datamodules/util/transformations/twin_transforms.py b/src/datamodules/DivaHisDB/utils/twin_transforms.py similarity index 97% rename from 
src/datamodules/util/transformations/twin_transforms.py rename to src/datamodules/DivaHisDB/utils/twin_transforms.py index 206706a7..5b193825 100644 --- a/src/datamodules/util/transformations/twin_transforms.py +++ b/src/datamodules/DivaHisDB/utils/twin_transforms.py @@ -2,7 +2,7 @@ from torchvision.transforms import functional as F -from src.datamodules.util.transformations import functional as F_custom +from src.datamodules.DivaHisDB.utils import functional as F_custom class TwinCompose(object): diff --git a/src/datamodules/util/transformations/wrapper_transforms.py b/src/datamodules/DivaHisDB/utils/wrapper_transforms.py similarity index 100% rename from src/datamodules/util/transformations/wrapper_transforms.py rename to src/datamodules/DivaHisDB/utils/wrapper_transforms.py diff --git a/src/datamodules/util/analytics/__init__.py b/src/datamodules/utils/__init__.py similarity index 100% rename from src/datamodules/util/analytics/__init__.py rename to src/datamodules/utils/__init__.py diff --git a/src/datamodules/util/exceptions.py b/src/datamodules/utils/exceptions.py similarity index 100% rename from src/datamodules/util/exceptions.py rename to src/datamodules/utils/exceptions.py diff --git a/src/datamodules/util/transformations/__init__.py b/src/tasks/DivaHisDB/__init__.py similarity index 100% rename from src/datamodules/util/transformations/__init__.py rename to src/tasks/DivaHisDB/__init__.py diff --git a/src/tasks/semantic_segmentation/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py similarity index 98% rename from src/tasks/semantic_segmentation/semantic_segmentation.py rename to src/tasks/DivaHisDB/semantic_segmentation.py index 2397f297..3fd21b3a 100644 --- a/src/tasks/semantic_segmentation/semantic_segmentation.py +++ b/src/tasks/DivaHisDB/semantic_segmentation.py @@ -7,7 +7,7 @@ import torchmetrics from src.tasks.base_task import AbstractTask -from src.tasks.semantic_segmentation.utils.output_tools import _get_argmax +from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax from src.utils import utils from src.tasks.utils.outputs import OutputKeys, reduce_dict diff --git a/src/tasks/semantic_segmentation/utils/__init__.py b/src/tasks/semantic_segmentation/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/tasks/semantic_segmentation/__init__.py b/tests/datamodules/DivaHisDB/__init__.py similarity index 100% rename from src/tasks/semantic_segmentation/__init__.py rename to tests/datamodules/DivaHisDB/__init__.py diff --git a/tests/datamodules/hisDBDataModule/test_hisDBDataModule.py b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py similarity index 93% rename from tests/datamodules/hisDBDataModule/test_hisDBDataModule.py rename to tests/datamodules/DivaHisDB/test_hisDBDataModule.py index c11925a2..e37d8bfa 100644 --- a/tests/datamodules/hisDBDataModule/test_hisDBDataModule.py +++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py @@ -4,7 +4,7 @@ from omegaconf import OmegaConf from pytest import fixture -from src.datamodules.hisDBDataModule.DIVAHisDBDataModule import DIVAHisDBDataModuleCropped +from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped from tests.datasets.test_cropped_hisdb_dataset import dataset_test @@ -14,7 +14,7 @@ @fixture def data_module_cropped(data_dir_cropped): OmegaConf.clear_resolvers() - datamodules = DIVAHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) + 
datamodules = DivaHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) return datamodules diff --git a/tests/datamodules/hisDBDataModule/test_image_analytics.py b/tests/datamodules/DivaHisDB/test_image_analytics.py similarity index 89% rename from tests/datamodules/hisDBDataModule/test_image_analytics.py rename to tests/datamodules/DivaHisDB/test_image_analytics.py index 72fbf96f..e9672a70 100644 --- a/tests/datamodules/hisDBDataModule/test_image_analytics.py +++ b/tests/datamodules/DivaHisDB/test_image_analytics.py @@ -2,8 +2,8 @@ import numpy as np -from src.datamodules.datasets.cropped_hisdb_dataset import CroppedHisDBDataset -from src.datamodules.util.analytics.image_analytics import get_analytics +from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped TEST_JSON = {'mean': [0.7050454974582426, 0.6503181590413943, 0.5567698583877997], diff --git a/tests/datamodules/hisDBDataModule/test_misc.py b/tests/datamodules/DivaHisDB/test_misc.py similarity index 89% rename from tests/datamodules/hisDBDataModule/test_misc.py rename to tests/datamodules/DivaHisDB/test_misc.py index a22302ee..f18382b8 100644 --- a/tests/datamodules/hisDBDataModule/test_misc.py +++ b/tests/datamodules/DivaHisDB/test_misc.py @@ -1,8 +1,8 @@ import pytest from pytest import fixture -from src.datamodules.util.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir -from src.datamodules.util.misc import validate_path_for_segmentation +from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir +from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation @fixture diff --git a/tests/datamodules/hisDBDataModule/__init__.py b/tests/datamodules/hisDBDataModule/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/datasets/test_cropped_hisdb_dataset.py b/tests/datasets/test_cropped_hisdb_dataset.py index b95c57ed..7a54c860 100644 --- a/tests/datasets/test_cropped_hisdb_dataset.py +++ b/tests/datasets/test_cropped_hisdb_dataset.py @@ -4,7 +4,7 @@ import torch from pytest import fixture -from src.datamodules.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped diff --git a/tests/tasks/sem_seg/test_output_tools.py b/tests/tasks/sem_seg/test_output_tools.py index dbcfc3b2..f7bca67a 100644 --- a/tests/tasks/sem_seg/test_output_tools.py +++ b/tests/tasks/sem_seg/test_output_tools.py @@ -2,7 +2,7 @@ from PIL import Image from torch import tensor, equal -from src.tasks.semantic_segmentation.utils.output_tools import _get_argmax, merge_patches, output_to_class_encodings, \ +from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax, merge_patches, output_to_class_encodings, \ save_output_page_image # batchsize (2) x classes (4) x W (2) x H (2) diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/sem_seg/test_semantic_segmentation.py index f9884c92..af225008 100644 --- a/tests/tasks/sem_seg/test_semantic_segmentation.py +++ b/tests/tasks/sem_seg/test_semantic_segmentation.py @@ -7,9 +7,8 @@ from pl_bolts.models.vision import UNet from pytorch_lightning import seed_everything -from 
src.datamodules.hisDBDataModule.DIVAHisDBDataModule import DIVAHisDBDataModuleCropped -from src.tasks.semantic_segmentation.semantic_segmentation import SemanticSegmentation - +from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped +from src.tasks.DivaHisDB.semantic_segmentation import SemanticSegmentation from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped @@ -18,7 +17,7 @@ def test_semantic_segmentation(data_dir_cropped, tmp_path): seed_everything(42) # datamodule - data_module = DIVAHisDBDataModuleCropped( + data_module = DivaHisDBDataModuleCropped( data_dir=str(data_dir_cropped), batch_size=2, num_workers=2) diff --git a/tests/tasks/utils/test_functional.py b/tests/tasks/utils/test_functional.py index 10bbb596..fb5b2a86 100644 --- a/tests/tasks/utils/test_functional.py +++ b/tests/tasks/utils/test_functional.py @@ -2,7 +2,7 @@ import torch from _pytest.fixtures import fixture -from src.datamodules.util.transformations.functional import gt_to_one_hot +from src.datamodules.DivaHisDB.utils.functional import gt_to_one_hot @fixture diff --git a/tools/merge_cropped_output.py b/tools/merge_cropped_output.py index 15ff6b17..c8f3cd8b 100644 --- a/tools/merge_cropped_output.py +++ b/tools/merge_cropped_output.py @@ -12,9 +12,9 @@ from PIL import Image from tqdm import tqdm -from src.datamodules.datasets.cropped_hisdb_dataset import CroppedHisDBDataset -from src.datamodules.hisDBDataModule.DIVAHisDBDataModule import DIVAHisDBDataModuleCropped -from src.tasks.semantic_segmentation.utils.output_tools import merge_patches, save_output_page_image +from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped +from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.utils.output_tools import merge_patches, save_output_page_image from tools.generate_cropped_dataset import pil_loader from tools.viz import visualize @@ -40,7 +40,7 @@ def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Pa self.prediction_path = prediction_path self.output_path = output_path - data_module = DIVAHisDBDataModuleCropped(data_dir=str(datamodule_path)) + data_module = DivaHisDBDataModuleCropped(data_dir=str(datamodule_path)) self.num_classes = data_module.num_classes self.class_encodings = data_module.class_encodings From ba9150ef7cb53c499d7f057faf9f9c03456ff919 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 25 Oct 2021 14:43:29 +0200 Subject: [PATCH 005/108] :recycle: refactor datamodule files --- tests/{ => datamodules/DivaHisDB}/datasets/__init__.py | 0 .../DivaHisDB}/datasets/test_cropped_hisdb_dataset.py | 0 tests/datamodules/DivaHisDB/test_hisDBDataModule.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename tests/{ => datamodules/DivaHisDB}/datasets/__init__.py (100%) rename tests/{ => datamodules/DivaHisDB}/datasets/test_cropped_hisdb_dataset.py (100%) diff --git a/tests/datasets/__init__.py b/tests/datamodules/DivaHisDB/datasets/__init__.py similarity index 100% rename from tests/datasets/__init__.py rename to tests/datamodules/DivaHisDB/datasets/__init__.py diff --git a/tests/datasets/test_cropped_hisdb_dataset.py b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py similarity index 100% rename from tests/datasets/test_cropped_hisdb_dataset.py rename to tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py diff --git a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py 
b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py index e37d8bfa..d4b6dc63 100644 --- a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py +++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py @@ -6,7 +6,7 @@ from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped -from tests.datasets.test_cropped_hisdb_dataset import dataset_test +from tests.datamodules.DivaHisDB.datasets.test_cropped_hisdb_dataset import dataset_test NUM_WORKERS = 4 From defe2f4ebd74974cf1aaf1ac57ac397750410b5b Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 25 Oct 2021 14:55:02 +0200 Subject: [PATCH 006/108] :sparkle: data and gt folder names are parameters now --- configs/experiment/cb55_full_run_unet.yaml | 2 ++ .../experiment/development_baby_unet_cb55_10.yaml | 2 ++ src/datamodules/DivaHisDB/datamodule_cropped.py | 7 +++++-- .../DivaHisDB/datasets/cropped_hisdb_dataset.py | 15 ++++++++++----- src/datamodules/DivaHisDB/utils/misc.py | 11 ++++++----- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index bab651f9..cc6e29de 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ b/configs/experiment/cb55_full_run_unet.yaml @@ -47,6 +47,8 @@ datamodule: batch_size: 16 shuffle: True drop_last: True + data_folder_name: data + gt_folder_name: gt callbacks: model_checkpoint: diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index 16c31c3a..ff8b5c1c 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -51,6 +51,8 @@ datamodule: batch_size: 16 shuffle: True drop_last: True + data_folder_name: data + gt_folder_name: gt callbacks: model_checkpoint: diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 22be939b..8a4f89b8 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -16,7 +16,7 @@ class DivaHisDBDataModuleCropped(AbstractDatamodule): - def __init__(self, data_dir: str = None, + def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_folder_name: str = 'gt', selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, @@ -48,7 +48,10 @@ def __init__(self, data_dir: str = None, self.shuffle = shuffle self.drop_last = drop_last - self.data_dir = validate_path_for_segmentation(data_dir) + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + self.data_dir = validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) self.selection_train = selection_train self.selection_val = selection_val diff --git a/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py index dca49977..16707640 100644 --- a/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py +++ b/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py @@ -31,7 +31,8 @@ class CroppedHisDBDataset(data.Dataset): root/data/xxz.png """ - def __init__(self, path: Path, selection: Optional[Union[int, List[str]]] = None, + def __init__(self, path: Path, data_folder_name: str = 'data', 
gt_folder_name: str = 'gt', + selection: Optional[Union[int, List[str]]] = None, is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): """ @@ -53,6 +54,8 @@ def __init__(self, path: Path, selection: Optional[Union[int, List[str]]] = None """ self.path = path + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name self.selection = selection # Init list @@ -67,7 +70,8 @@ def __init__(self, path: Path, selection: Optional[Union[int, List[str]]] = None self.is_test = is_test # List of tuples that contain the path to the gt and image that belong together - self.img_paths_per_page = self.get_gt_data_paths(path, selection=self.selection) + self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, selection=self.selection) # TODO: make more fanzy stuff here # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] @@ -144,7 +148,8 @@ def _apply_transformation(self, img, gt): return img, gt, border_mask @staticmethod - def get_gt_data_paths(directory: Path, selection: Optional[Union[int, List[str]]] = None) \ + def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + selection: Optional[Union[int, List[str]]] = None) \ -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: """ Structure of the folder @@ -161,8 +166,8 @@ def get_gt_data_paths(directory: Path, selection: Optional[Union[int, List[str]] paths = [] directory = directory.expanduser() - path_data_root = directory / 'data' - path_gt_root = directory / 'gt' + path_data_root = directory / data_folder_name + path_gt_root = directory / gt_folder_name if not (path_data_root.is_dir() or path_gt_root.is_dir()): log.error("folder data or gt not found in " + str(directory)) diff --git a/src/datamodules/DivaHisDB/utils/misc.py b/src/datamodules/DivaHisDB/utils/misc.py index 998b8578..3520992d 100644 --- a/src/datamodules/DivaHisDB/utils/misc.py +++ b/src/datamodules/DivaHisDB/utils/misc.py @@ -48,13 +48,13 @@ def convert_to_rgb(pic): return pic -def validate_path_for_segmentation(data_dir): +def validate_path_for_segmentation(data_dir, data_folder_name: str = 'data', gt_folder_name: str = 'gt'): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val/test folder)") else: split_names = ['train', 'val', 'test'] - type_names = ['data', 'gt'] + type_names = [data_folder_name, gt_folder_name] data_folder = Path(data_dir) if not data_folder.is_dir(): @@ -62,13 +62,14 @@ def validate_path_for_segmentation(data_dir): "(folder containing the train/val/test folder)") split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] if len(split_folders) != 3: - raise PathMissingSplitDir("Your path needs to contain train/val/test and each of them a folder data and gt") + raise PathMissingSplitDir(f'Your path needs to contain train/val/test and ' + f'each of them a folder {data_folder_name} and {gt_folder_name}') # check if we have train/test/val for split in split_folders: type_folders = [d for d in split.iterdir() if d.is_dir() and d.name in type_names] # check if we have data/gt if len(type_folders) != 2: - raise PathMissingDirinSplitDir(f"Folder {split.name} does not contain a gt and data folder") - + raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' + f'and {gt_folder_name} folder') return Path(data_dir) From 
95b4716f1b50f408bd0b7fd6a8c16fd4bb71fe0a Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 25 Oct 2021 15:01:49 +0200 Subject: [PATCH 007/108] :recycle: changed dataset name --- src/datamodules/DivaHisDB/datamodule_cropped.py | 2 +- .../datasets/{cropped_hisdb_dataset.py => cropped_dataset.py} | 0 .../DivaHisDB/datasets/test_cropped_hisdb_dataset.py | 2 +- tests/datamodules/DivaHisDB/test_image_analytics.py | 2 +- tools/merge_cropped_output.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename src/datamodules/DivaHisDB/datasets/{cropped_hisdb_dataset.py => cropped_dataset.py} (100%) diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 8a4f89b8..8f172007 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -5,7 +5,7 @@ from torchvision import transforms from src.datamodules.base_datamodule import AbstractDatamodule -from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation from src.datamodules.DivaHisDB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling diff --git a/src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py similarity index 100% rename from src/datamodules/DivaHisDB/datasets/cropped_hisdb_dataset.py rename to src/datamodules/DivaHisDB/datasets/cropped_dataset.py diff --git a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py index 7a54c860..03d95c9b 100644 --- a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py +++ b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py @@ -4,7 +4,7 @@ import torch from pytest import fixture -from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped diff --git a/tests/datamodules/DivaHisDB/test_image_analytics.py b/tests/datamodules/DivaHisDB/test_image_analytics.py index e9672a70..d00c3017 100644 --- a/tests/datamodules/DivaHisDB/test_image_analytics.py +++ b/tests/datamodules/DivaHisDB/test_image_analytics.py @@ -2,7 +2,7 @@ import numpy as np -from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped diff --git a/tools/merge_cropped_output.py b/tools/merge_cropped_output.py index c8f3cd8b..b5ac8943 100644 --- a/tools/merge_cropped_output.py +++ b/tools/merge_cropped_output.py @@ -13,7 +13,7 @@ from tqdm import tqdm from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped -from src.datamodules.DivaHisDB.datasets.cropped_hisdb_dataset import CroppedHisDBDataset +from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.output_tools import merge_patches, save_output_page_image from 
tools.generate_cropped_dataset import pil_loader from tools.viz import visualize From 530270222565ce21ac59d53f1a414a402410c3e1 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 25 Oct 2021 17:36:16 +0200 Subject: [PATCH 008/108] :construction: initial datamodule files for RGB gt --- src/datamodules/RGB/__init__.py | 0 src/datamodules/RGB/datamodule_cropped.py | 141 ++++++++ src/datamodules/RGB/datasets/__init__.py | 0 .../RGB/datasets/cropped_dataset.py | 250 +++++++++++++ src/datamodules/RGB/utils/__init__.py | 0 src/datamodules/RGB/utils/functional.py | 57 +++ src/datamodules/RGB/utils/image_analytics.py | 336 ++++++++++++++++++ src/datamodules/RGB/utils/misc.py | 75 ++++ src/datamodules/RGB/utils/output_tools.py | 118 ++++++ .../RGB/utils/single_transforms.py | 144 ++++++++ src/datamodules/RGB/utils/twin_transforms.py | 101 ++++++ .../RGB/utils/wrapper_transforms.py | 37 ++ 12 files changed, 1259 insertions(+) create mode 100644 src/datamodules/RGB/__init__.py create mode 100644 src/datamodules/RGB/datamodule_cropped.py create mode 100644 src/datamodules/RGB/datasets/__init__.py create mode 100644 src/datamodules/RGB/datasets/cropped_dataset.py create mode 100644 src/datamodules/RGB/utils/__init__.py create mode 100644 src/datamodules/RGB/utils/functional.py create mode 100644 src/datamodules/RGB/utils/image_analytics.py create mode 100644 src/datamodules/RGB/utils/misc.py create mode 100644 src/datamodules/RGB/utils/output_tools.py create mode 100644 src/datamodules/RGB/utils/single_transforms.py create mode 100644 src/datamodules/RGB/utils/twin_transforms.py create mode 100644 src/datamodules/RGB/utils/wrapper_transforms.py diff --git a/src/datamodules/RGB/__init__.py b/src/datamodules/RGB/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py new file mode 100644 index 00000000..50ad0c67 --- /dev/null +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -0,0 +1,141 @@ +from pathlib import Path +from typing import Union, List, Optional + +from torch.utils.data import DataLoader +from torchvision import transforms + +from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.RGB.datasets.cropped_dataset import CroppedHisDBDataset +from src.datamodules.RGB.utils.image_analytics import get_analytics +from src.datamodules.RGB.utils.misc import validate_path_for_segmentation +from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling +from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.utils import utils + +log = utils.get_logger(__name__) + + +class DivaHisDBDataModuleCropped(AbstractDatamodule): + def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + selection_train: Optional[Union[int, List[str]]] = None, + selection_val: Optional[Union[int, List[str]]] = None, + selection_test: Optional[Union[int, List[str]]] = None, + crop_size: int = 256, num_workers: int = 4, batch_size: int = 8, + shuffle: bool = True, drop_last: bool = True): + super().__init__() + + analytics = get_analytics(input_path=Path(data_dir), + get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) + + self.mean = analytics['mean'] + self.std = analytics['std'] + self.class_encodings = analytics['class_encodings'] + self.num_classes = len(self.class_encodings) + self.class_weights = analytics['class_weights'] + + self.image_transform = 
OnlyImage(transforms.Compose([transforms.ToTensor(),
+                                                      transforms.Normalize(mean=self.mean, std=self.std)]))
+        self.target_transform = OnlyTarget(transforms.Compose([
+            # transforms the gt image into a one-hot encoded matrix
+            OneHotEncoding(class_encodings=self.class_encodings),
+            # transforms the one hot encoding to argmax labels -> for the cross-entropy criterion
+            OneHotToPixelLabelling()]))
+        self.twin_transform = TwinRandomCrop(crop_size=crop_size)
+
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+
+        self.data_folder_name = data_folder_name
+        self.gt_folder_name = gt_folder_name
+        self.data_dir = validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name,
+                                                       gt_folder_name=self.gt_folder_name)
+
+        self.selection_train = selection_train
+        self.selection_val = selection_val
+        self.selection_test = selection_test
+
+        self.dims = (3, crop_size, crop_size)
+
+    def setup(self, stage: Optional[str] = None):
+        super().setup()
+        if stage == 'fit' or stage is None:
+            self.train = CroppedHisDBDataset(**self._create_dataset_parameters('train'), selection=self.selection_train)
+            self.val = CroppedHisDBDataset(**self._create_dataset_parameters('val'), selection=self.selection_val)
+
+            self._check_min_num_samples(num_samples=len(self.train), data_split='train',
+                                        drop_last=self.drop_last)
+            self._check_min_num_samples(num_samples=len(self.val), data_split='val',
+                                        drop_last=self.drop_last)
+
+        if stage == 'test' or stage is None:
+            self.test = CroppedHisDBDataset(**self._create_dataset_parameters('test'), selection=self.selection_test)
+            # self._check_min_num_samples(num_samples=len(self.test), data_split='test',
+            #                             drop_last=False)
+
+    def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool):
+        num_processes = self.trainer.num_processes
+        batch_size = self.batch_size
+        if drop_last:
+            if num_samples < (num_processes * batch_size):
+                log.error(
+                    f'#samples ({num_samples}) in "{data_split}" smaller than '
+                    f'#processes ({num_processes}) times batch size ({batch_size}). '
+                    f'This only works if drop_last is false!')
+                raise ValueError()
+        else:
+            if num_samples < (num_processes * batch_size):
+                log.warning(
+                    f'#samples ({num_samples}) in "{data_split}" smaller than '
+                    f'#processes ({num_processes}) times batch size ({batch_size}). '
+                    f'This works due to drop_last=False, however samples will occur multiple times. '
+                    f'Check if this behavior is intended!')
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(self.train,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.val,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.test,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=False,
+                          drop_last=False,
+                          pin_memory=True)
+
+    def _create_dataset_parameters(self, dataset_type: str = 'train'):
+        is_test = dataset_type == 'test'
+        return {'path': self.data_dir / dataset_type,
+                'image_transform': self.image_transform,
+                'target_transform': self.target_transform,
+                'twin_transform': self.twin_transform,
+                'classes': self.class_encodings,
+                'is_test': is_test}
+
+    def get_img_name_coordinates(self, index):
+        """
+        Returns the original filename of the crop and its coordinates based on the index.
+        This method can only be used during testing!
+        :param index: index of the crop
+        :return: original image name, crop name and crop coordinates
+        """
+        if not hasattr(self, 'test'):
+            raise Exception('This method can only be called during testing')
+
+        return self.test.img_paths_per_page[index][2:]
+
+
diff --git a/src/datamodules/RGB/datasets/__init__.py b/src/datamodules/RGB/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py
new file mode 100644
index 00000000..55e277df
--- /dev/null
+++ b/src/datamodules/RGB/datasets/cropped_dataset.py
@@ -0,0 +1,250 @@
+"""
+Load a dataset of historic documents by specifying the folder where it is located.
+"""
+
+# Utils
+import re
+from pathlib import Path
+from typing import List, Tuple, Union, Optional
+
+import torch.utils.data as data
+from omegaconf import ListConfig
+from torch import is_tensor
+from torchvision.transforms import ToTensor
+
+from src.datamodules.RGB.utils.misc import has_extension, pil_loader
+from src.utils import utils
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
+
+log = utils.get_logger(__name__)
+
+
+class CroppedHisDBDataset(data.Dataset):
+    """A generic data loader where the images are arranged in this way: ::
+
+        root/gt/xxx.png
+        root/gt/xxy.png
+        root/gt/xxz.png
+
+        root/data/xxx.png
+        root/data/xxy.png
+        root/data/xxz.png
+    """
+
+    def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt',
+                 selection: Optional[Union[int, List[str]]] = None,
+                 is_test=False, image_transform=None, target_transform=None, twin_transform=None,
+                 classes=None, **kwargs):
+        """
+        Parameters
+        ----------
+        path : Path
+            Path to the dataset folder (train / val / test)
+        data_folder_name : str
+            Name of the folder inside path that holds the image crops
+        gt_folder_name : str
+            Name of the folder inside path that holds the ground-truth crops
+        selection : int or list(str)
+            Number of pages to load, or an explicit list of page names
+        is_test : bool
+            Marks this dataset as the test split; test items also return the crop index
+        image_transform : callable
+            Transformation applied to the image only
+        target_transform : callable
+            Transformation applied to the ground truth only
+        twin_transform : callable
+            Transformation applied to image and ground truth together
+        classes :
+            Class encodings of the ground truth
+ """ + + self.path = path + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + self.selection = selection + + # Init list + self.classes = classes + # self.crops_per_image = crops_per_image + + # transformations + self.image_transform = image_transform + self.target_transform = target_transform + self.twin_transform = twin_transform + + self.is_test = is_test + + # List of tuples that contain the path to the gt and image that belong together + self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, selection=self.selection) + + # TODO: make more fanzy stuff here + # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + + self.num_samples = len(self.img_paths_per_page) + if self.num_samples == 0: + raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( + path, ",".join(IMG_EXTENSIONS))) + + def __len__(self): + """ + This function returns the length of an epoch so the data loader knows when to stop. + The length is different during train/val and test, because we process the whole image during testing, + and only sample from the images during train/val. + """ + return self.num_samples + + def __getitem__(self, index): + if self.is_test: + return self._get_test_items(index=index) + else: + return self._get_train_val_items(index=index) + + def _get_train_val_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt, boundary_mask = self._apply_transformation(data_img, gt_img) + return img, gt, boundary_mask + + def _get_test_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt, boundary_mask = self._apply_transformation(data_img, gt_img) + return img, gt, boundary_mask, index + + def _load_data_and_gt(self, index): + data_img = pil_loader(self.img_paths_per_page[index][0]) + gt_img = pil_loader(self.img_paths_per_page[index][1]) + + return data_img, gt_img + + def _apply_transformation(self, img, gt): + """ + Applies the transformations that have been defined in the setup (setup.py). If no transformations + have been defined, the PIL image is returned instead. 
+ + Parameters + ---------- + img: PIL image + image data + gt: PIL image + ground truth image + coordinates: tuple (int, int) + coordinates where the sliding window should be cropped + Returns + ------- + tuple + img and gt after transformations + """ + if self.twin_transform is not None and not self.is_test: + img, gt = self.twin_transform(img, gt) + + if self.image_transform is not None: + # perform transformations + img, gt = self.image_transform(img, gt) + + if not is_tensor(img): + img = ToTensor()(img) + if not is_tensor(gt): + gt = ToTensor()(gt) + + border_mask = gt[0, :, :] != 0 + if self.target_transform is not None: + img, gt = self.target_transform(img, gt) + + return img, gt, border_mask + + @staticmethod + def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + selection: Optional[Union[int, List[str]]] = None) \ + -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: + """ + Structure of the folder + + directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + + + :param directory: + :param selection: + :return: tuple + (path_data_file, path_gt_file, original_image_name, (x, y)) + """ + paths = [] + directory = directory.expanduser() + + path_data_root = directory / data_folder_name + path_gt_root = directory / gt_folder_name + + if not (path_data_root.is_dir() or path_gt_root.is_dir()): + log.error("folder data or gt not found in " + str(directory)) + + # get all subitems (and files) sorted + subitems = sorted(path_data_root.iterdir()) + + # check the selection parameter + if selection: + subdirectories = [x.name for x in subitems if x.is_dir()] + + if isinstance(selection, int): + if selection < 0: + msg = f'Parameter "selection" is a negative integer ({selection}). ' \ + f'Negative values are not supported!' + log.error(msg) + raise ValueError(msg) + + elif selection == 0: + selection = None + + elif selection > len(subdirectories): + msg = f'Parameter "selection" is larger ({selection}) than ' \ + f'number of subdirectories ({len(subdirectories)}).' 
+ log.error(msg) + raise ValueError(msg) + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if not all(x in subdirectories for x in selection): + msg = f'Parameter "selection" contains a non-existing subdirectory.)' + log.error(msg) + raise ValueError(msg) + + else: + msg = f'Parameter "selection" exists, but it is of unsupported type ({type(selection)})' + log.error(msg) + raise TypeError(msg) + + counter = 0 # Counter for subdirectories, needed for selection parameter + + for path_data_subdir in subitems: + if not path_data_subdir.is_dir(): + if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + log.warning("image file found in data root: " + str(path_data_subdir)) + continue + + counter += 1 + + if selection: + if isinstance(selection, int): + if counter > selection: + break + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if path_data_subdir.name not in selection: + continue + + path_gt_subdir = path_gt_root / path_data_subdir.stem + assert path_gt_subdir.is_dir() + + for path_data_file, path_gt_file in zip(sorted(path_data_subdir.iterdir()), + sorted(path_gt_subdir.iterdir())): + assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_extension(path_gt_file.name, IMG_EXTENSIONS), \ + 'get_gt_data_paths(): image file aligned with non-image file' + + if has_extension(path_data_file.name, IMG_EXTENSIONS) and has_extension(path_gt_file.name, + IMG_EXTENSIONS): + assert path_data_file.stem == path_gt_file.stem, \ + 'get_gt_data_paths(): mismatch between data filename and gt filename' + coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') + m = coordinates.match(path_data_file.name) + if m is None: + continue + x = int(m.group(1)) + y = int(m.group(2)) + # TODO check if we need x/y + paths.append((path_data_file, path_gt_file, path_data_subdir.stem, path_data_file.stem, (x, y))) + + return paths diff --git a/src/datamodules/RGB/utils/__init__.py b/src/datamodules/RGB/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RGB/utils/functional.py b/src/datamodules/RGB/utils/functional.py new file mode 100644 index 00000000..bf9e24de --- /dev/null +++ b/src/datamodules/RGB/utils/functional.py @@ -0,0 +1,57 @@ +from typing import List + +import numpy as np +import torch + +from sklearn.preprocessing import OneHotEncoder + + +def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): + """ + Convert ground truth tensor or numpy matrix to one-hot encoded matrix + + Parameters + ------- + matrix: float tensor from to_tensor() or numpy array + shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR + class_encodings: List of int + Blue channel values that encode the different classes + Returns + ------- + torch.LongTensor of size [#C x H x W] + sparse one-hot encoded multi-class matrix, where #C is the number of classes + """ + num_classes = len(class_encodings) + + if type(matrix).__module__ == np.__name__: + im_np = matrix[:, :, 2].astype(np.uint8) + border_mask = matrix[:, :, 0].astype(np.uint8) != 0 + else: + # TODO: ugly fix -> better to not normalize in the first place + np_array = (matrix * 255).numpy().astype(np.uint8) + im_np = np_array[2, :, :].astype(np.uint8) + border_mask = np_array[0, :, :].astype(np.uint8) != 0 + im_np[border_mask] = 1 + + integer_encoded = np.array([i for i in range(num_classes)]) + onehot_encoder = OneHotEncoder(sparse=False, categories='auto') + integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) + onehot_encoded = 
onehot_encoder.fit_transform(integer_encoded).astype(np.int8)
+
+    np.place(im_np, im_np == 0,
+             1)  # needed to deal with 0 fillers at the borders during testing (replace with background)
+    replace_dict = {k: v for k, v in zip(class_encodings, onehot_encoded)}
+
+    # create the one-hot matrix
+    one_hot_matrix = np.asanyarray(
+        [[replace_dict[im_np[i, j]] for j in range(im_np.shape[1])] for i in range(im_np.shape[0])]).astype(np.uint8)
+
+    return torch.LongTensor(one_hot_matrix.transpose((2, 0, 1)))
+
+
+def argmax_onehot(tensor: torch.Tensor):
+    """
+    Collapses a one-hot encoded tensor [#C x H x W] to pixel labels [H x W] via argmax over the class dimension.
+    """
+    return torch.LongTensor(torch.argmax(tensor, dim=0))
diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py
new file mode 100644
index 00000000..5dcc90c7
--- /dev/null
+++ b/src/datamodules/RGB/utils/image_analytics.py
@@ -0,0 +1,336 @@
+# Utils
+import errno
+import json
+import logging
+import os
+from multiprocessing import Pool
+from pathlib import Path
+from typing import List
+
+import numpy as np
+# Torch related stuff
+import torch
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs):
+    """
+    Computes (or loads from analytics.json) the mean/std of the training images
+    and the class weights/encodings of the ground truth.
+
+    Parameters
+    ----------
+    input_path: Path
+        Path to the dataset root
+    get_gt_data_paths_func: callable
+        Function that returns the list of (data, gt) path pairs for the train split
+
+    Returns
+    -------
+    dict with the keys 'mean', 'std', 'class_weights' and 'class_encodings'
+    """
+    analytics_file_path = input_path / 'analytics.json'
+    if analytics_file_path.exists():
+        with analytics_file_path.open(mode='r') as f:
+            analytics_dict = json.load(fp=f)
+    else:
+        train_path = input_path / 'train'
+        gt_data_path_list = get_gt_data_paths_func(train_path)
+        file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list])
+        file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list])
+        mean, std = compute_mean_std(file_names=file_names_data, **kwargs)
+
+        # Measure weights for class balancing
+        logging.info(f'Measuring class weights')
+        # create a list with all gt file paths
+        class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, **kwargs)
+        analytics_dict = {'mean': mean.tolist(),
+                          'std': std.tolist(),
+                          'class_weights': class_weights.tolist(),
+                          'class_encodings': class_encodings.tolist()}
+        # save json
+        try:
+            with analytics_file_path.open(mode='w') as f:
+                json.dump(obj=analytics_dict, fp=f)
+        except IOError as e:
+            if e.errno == errno.EACCES:
+                print(f'WARNING: No permissions to write analytics file ({analytics_file_path})')
+            else:
+                raise
+    # returns the 'mean[RGB]', 'std[RGB]', 'class_frequencies_weights[num_classes]', 'class_encodings'
+    return analytics_dict
+
+
+def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs):
+    """
+    Computes mean and std of all images given by their paths.
+
+    Parameters
+    ----------
+    file_names : list of paths
+        Paths to the images of the dataset
+    inmem : Boolean
+        Specifies whether it should be computed in an offline (in-memory) or an online fashion.
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : float
+        Mean value of all pixels of the images in the input folder
+    std : float
+        Standard deviation of all pixels of the images in the input folder
+    """
+    file_names_np = np.array(list(map(str, file_names)))
+    # Compute mean and std
+    mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers)
+    return mean, std
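
A quick, self-contained sanity check for compute_mean_std (the temporary images are created only for this sketch; the import path is the file introduced by this patch, so the repository must be on the Python path):

    import tempfile
    from pathlib import Path

    import numpy as np
    from PIL import Image

    from src.datamodules.RGB.utils.image_analytics import compute_mean_std

    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for name in ('a.png', 'b.png'):
            p = Path(tmp) / name
            # uniform gray image: every pixel is 128
            Image.fromarray(np.full((4, 4, 3), 128, dtype=np.uint8)).save(p)
            paths.append(p)
        mean, std = compute_mean_std(file_names=paths, inmem=True)
        print(mean, std)  # ~[0.502 0.502 0.502] and [0. 0. 0.]
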
+
+
+def _cms_online(file_names, workers=4):
+    """
+    Computes mean and standard deviation in an online fashion.
+    This is useful when the dataset is too big to be allocated in memory.
+
+    Parameters
+    ----------
+    file_names : List of String
+        List of file names of the dataset
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    logging.info('Begin computing the mean')
+
+    # Set up a pool of workers
+    pool = Pool(workers + 1)
+
+    # Online mean
+    results = pool.map(_return_mean, file_names)
+    mean_sum = np.sum(np.array(results), axis=0)
+
+    # Divide by number of samples in train set
+    mean = mean_sum / file_names.size
+
+    logging.info('Finished computing the mean')
+    logging.info('Begin computing the std')
+
+    # Online standard deviation
+    results = pool.starmap(_return_std, [[item, mean] for item in file_names])
+    std_sum = np.sum(np.array([item[0] for item in results]), axis=0)
+    total_pixel_count = np.sum(np.array([item[1] for item in results]))
+    std = np.sqrt(std_sum / total_pixel_count)
+    logging.info('Finished computing the std')
+
+    # Shut down the pool
+    pool.close()
+
+    return mean, std
+
+
+# Loads an image with PIL and returns the channel-wise means of the image.
+def _return_mean(image_path):
+    img = np.array(Image.open(image_path).convert('RGB'))
+    mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0
+    return mean
+
+
+# Loads an image with PIL and returns the per-channel sum of squared deviations from the mean
+# together with the pixel count per channel.
+def _return_std(image_path, mean):
+    img = np.array(Image.open(image_path).convert('RGB')) / 255.0
+    m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]]))
+    return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0
+
+
+def _cms_inmem(file_names):
+    """
+    Computes mean and standard deviation in an offline fashion. This is possible only when the dataset can
+    be allocated in memory.
+
+    Parameters
+    ----------
+    file_names: List of String
+        List of file names of the dataset
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape))
+
+    # Load all samples
+    for i, sample in enumerate(file_names):
+        img[i] = np.array(Image.open(sample).convert('RGB'))
+
+    mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0
+    std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0
+
+    return mean, std
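
The weighting helpers below all use plain inverse class frequency, renormalized to sum to 1. A tiny worked example with assumed toy frequencies:

    import numpy as np

    class_frequencies = np.array([0.9, 0.1])  # e.g. 90% background, 10% foreground pixels
    class_weights = 1 / class_frequencies     # [1.11..., 10.]
    class_weights /= class_weights.sum()      # [0.1, 0.9] -- the rare class gets most of the weight
    print(class_weights)
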
+
+
+def get_class_weights(input_folder, workers=4, **kwargs):
+    """
+    Get the weights proportional to the inverse of their class frequencies.
+    The vector sums up to 1.
+
+    Parameters
+    ----------
+    input_folder : String (path)
+        Path to the dataset folder (see above for details)
+    workers : int
+        Number of workers to use for the computation
+
+    Returns
+    -------
+    ndarray[double] of size (num_classes)
+        The weights vector as a 1D array normalized (sums up to 1)
+    """
+    # Sanity check on the folder
+    if not os.path.isdir(input_folder):
+        logging.error(f"Folder {input_folder} does not exist")
+        raise FileNotFoundError
+
+    # Load the dataset
+    ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()]))
+
+    logging.info('Begin computing class frequencies weights')
+
+    if hasattr(ds, 'targets'):
+        labels = ds.targets
+    elif hasattr(ds, 'labels'):
+        labels = ds.labels
+    else:
+        # This is a fail-safe net in case a custom dataset changed the name of the internal variables
+        data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers)
+        labels = []
+        for target, label in data_loader:
+            labels.append(label)
+        labels = np.concatenate(labels).reshape(len(ds))
+
+    class_support = np.unique(labels, return_counts=True)[1]
+    class_frequencies = class_support / len(labels)
+    # Class weights are the inverse of the class frequencies
+    class_weights = 1 / class_frequencies
+    # Normalize vector to sum up to 1.0 (in case the Loss function does not do it)
+    class_weights /= class_weights.sum()
+
+    logging.info('Finished computing class frequencies weights')
+    logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}')
+    logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}')
+
+    return class_weights
+
+
+def compute_mean_std_graphs(dataset, **kwargs):
+    """
+    Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py).
+
+    Parameters
+    ----------
+    dataset : ParsedGxlDataset
+        Dataset object (see above for details)
+
+    # TODO implement online version
+
+    Returns
+    -------
+    node_features : {"mean": list, "std": list}
+        Mean and std value of all node features in the input dataset
+    edge_features : {"mean": list, "std": list}
+        Mean and std value of all edge features in the input dataset
+    """
+    if dataset.data.x is not None:
+        logging.info('Begin computing the node feature mean and std')
+        nodes = _get_feature_mean_std(dataset.data.x)
+        logging.info('Finished computing the node feature mean and std')
+    else:
+        nodes = {}
+        logging.info('No node features present')
+
+    if dataset.data.edge_attr is not None:
+        logging.info('Begin computing the edge feature mean and std')
+        edges = _get_feature_mean_std(dataset.data.edge_attr)
+        logging.info('Finished computing the edge feature mean and std')
+    else:
+        edges = {}
+        logging.info('No edge features present')
+
+    return nodes, edges
+
+
+def _get_feature_mean_std(torch_array):
+    array = np.array(torch_array)
+    return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]}
+
+
+def get_class_weights_graphs(dataset, **kwargs):
+    """
+    Get the weights proportional to the inverse of their class frequencies.
+ The vector sums up to 1 + + Parameters + ---------- + input_folder : ParsedGxlDataset + Dataset object (see above for details) + + # TODO implement online version + + Returns + ------- + ndarray[double] of size (num_classes) + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + class_frequencies = np.array(dataset.config['class_freq'][1]) + # Class weights are the inverse of the class frequencies + class_weights = 1 / class_frequencies + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + class_weights /= class_weights.sum() + + logging.info('Finished computing class frequencies weights ') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + logging.info(f'Class weights (rounded): {np.around(class_weights)}') + + return class_weights + + +def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. + The vector sums up to 1 + + Parameters + ---------- + gt_images: list of strings + Path to all ground truth images, which contain the pixel-wise label + workers: int + Number of workers to use for the mean/std computation + + Returns + ------- + ndarray[double] of size (num_classes) and ints the classes are represented as + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + total_num_pixels = 0 + label_counter = {} + + for path in gt_images: + img = np.array(Image.open(path))[:, :, 2].flatten() + total_num_pixels += len(img) + for i, j in zip(*np.unique(img, return_counts=True)): + label_counter[i] = label_counter.get(i, 0) + j + + classes = np.array(sorted(label_counter.keys())) + num_samples_per_class = np.array([label_counter[k] for k in classes]) + class_frequencies = (num_samples_per_class / total_num_pixels) + logging.info('Finished computing class frequencies weights') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + return (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()), classes + + +if __name__ == '__main__': + # print(get_analytics(input_path=Path('/netscratch/datasets/semantic_segmentation/datasets/CB55/'), inmem=True, workers=16)) + print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RGB/utils/misc.py b/src/datamodules/RGB/utils/misc.py new file mode 100644 index 00000000..3520992d --- /dev/null +++ b/src/datamodules/RGB/utils/misc.py @@ -0,0 +1,75 @@ +""" +General purpose utility functions. 
+ +""" + +from pathlib import Path + +# Utils +import numpy as np +from PIL import Image + +from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir + +try: + import accimage +except ImportError: + accimage = None + + +def has_extension(filename, extensions): + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def pil_loader(path, to_rgb=True): + pic = Image.open(path) + if to_rgb: + pic = convert_to_rgb(pic) + return pic + + +def convert_to_rgb(pic): + if pic.mode == "RGB": + pass + elif pic.mode in ("CMYK", "RGBA", "P"): + pic = pic.convert('RGB') + elif pic.mode == "I": + img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "I;16": + img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "L": + img = np.array(pic).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + else: + raise TypeError(f"unsupported image type {pic.mode}") + return pic + + +def validate_path_for_segmentation(data_dir, data_folder_name: str = 'data', gt_folder_name: str = 'gt'): + if data_dir is None: + raise PathNone("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + else: + split_names = ['train', 'val', 'test'] + type_names = [data_folder_name, gt_folder_name] + + data_folder = Path(data_dir) + if not data_folder.is_dir(): + raise PathNotDir("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] + if len(split_folders) != 3: + raise PathMissingSplitDir(f'Your path needs to contain train/val/test and ' + f'each of them a folder {data_folder_name} and {gt_folder_name}') + + # check if we have train/test/val + for split in split_folders: + type_folders = [d for d in split.iterdir() if d.is_dir() and d.name in type_names] + # check if we have data/gt + if len(type_folders) != 2: + raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' + f'and {gt_folder_name} folder') + return Path(data_dir) diff --git a/src/datamodules/RGB/utils/output_tools.py b/src/datamodules/RGB/utils/output_tools.py new file mode 100644 index 00000000..06fe4441 --- /dev/null +++ b/src/datamodules/RGB/utils/output_tools.py @@ -0,0 +1,118 @@ +from pathlib import Path +from typing import Union + +import numpy as np +import torch +from PIL import Image + + +def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): + """ + takes the biggest value from a pixel across all classes + :param output: (Batch_size x num_classes x W x H) + matrix with the given attributes + :return: (Batch_size x W x H) + matrix with the hisdb class number for each pixel + """ + if isinstance(output, torch.Tensor): + return torch.argmax(output, dim=dim) + if isinstance(output, np.ndarray): + return np.argmax(output, axis=dim) + return output + + +def merge_patches(patch, coordinates, full_output): + """ + This function merges the patch into the full output image + Overlapping values are resolved by taking the max. 
+
+    Parameters
+    ----------
+    patch: numpy matrix of size [#classes x crop_size x crop_size]
+        a patch from the larger image
+    coordinates: tuple of ints
+        top left coordinates of the patch within the larger image for all patches in a batch
+    full_output: numpy matrix of size [#C x H x W]
+        output image at full size
+    Returns
+    -------
+    full_output: numpy matrix [#C x Htot x Wtot]
+    """
+    assert len(full_output.shape) == 3
+    assert full_output.size != 0
+
+    # Resolve patch coordinates
+    x1, y1 = coordinates
+    x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1]
+
+    # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen!
+    assert x2 <= full_output.shape[2]
+    assert y2 <= full_output.shape[1]
+
+    mask = np.isnan(full_output[:, y1:y2, x1:x2])
+    # if still NaN in full_output just insert value from crop, if there is a value then take max
+    full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2]))
+
+    return full_output
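
A minimal sketch of the intended merging flow; the NaN-initialized canvas is an assumption derived from the mask logic above, and the import path is the file introduced by this patch:

    import numpy as np

    from src.datamodules.RGB.utils.output_tools import merge_patches

    num_classes, height, width = 2, 4, 4
    canvas = np.full((num_classes, height, width), np.nan)  # NaN marks 'not yet covered'

    patch = np.ones((num_classes, 2, 2))
    canvas = merge_patches(patch, (0, 0), canvas)      # fills the top-left corner
    canvas = merge_patches(patch * 2, (1, 1), canvas)  # overlaps one pixel; the max (2) wins there
    print(canvas[0])
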
+
+
+def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding):
+    """
+    Helper function to save the output during testing in the DIVAHisDB format
+
+    Parameters
+    ----------
+    image_name: str
+        name of the image that is saved
+    output_image: numpy matrix of size [#C x H x W]
+        output image at full size
+    output_folder: Path
+        path to the output folder for the test data
+    class_encoding: list(int)
+        list with the class encodings
+    """
+
+    output_encoded = output_to_class_encodings(output_image, class_encoding)
+
+    dest_folder = output_folder
+    dest_folder.mkdir(parents=True, exist_ok=True)
+    dest_filename = dest_folder / image_name
+
+    # Save the output
+    Image.fromarray(output_encoded.astype(np.uint8)).save(str(dest_filename))
+
+
+def output_to_class_encodings(output, class_encodings, perform_argmax=True):
+    """
+    This function converts the output prediction matrix to an image like it was provided in the ground truth
+
+    Parameters
+    -------
+    output : np.array of size [#C x H x W]
+        output prediction of the network for a full-size image, where #C is the number of classes
+    class_encodings : List
+        Contains the range of encoded classes
+    perform_argmax : bool
+        perform argmax on input data
+    Returns
+    -------
+    numpy array of size [C x H x W] (BGR)
+    """
+
+    B = np.argmax(output, axis=0) if perform_argmax else output
+
+    class_to_B = {i: j for i, j in enumerate(class_encodings)}
+
+    masks = [B == old for old in class_to_B.keys()]
+
+    for mask, (old, new) in zip(masks, class_to_B.items()):
+        B = np.where(mask, new, B)
+
+    rgb = np.dstack((np.zeros(shape=(B.shape[0], B.shape[1], 2), dtype=np.int8), B))
+
+    return rgb
diff --git a/src/datamodules/RGB/utils/single_transforms.py b/src/datamodules/RGB/utils/single_transforms.py
new file mode 100644
index 00000000..dc48230b
--- /dev/null
+++ b/src/datamodules/RGB/utils/single_transforms.py
@@ -0,0 +1,144 @@
+import math
+import random
+
+import torch
+from PIL import Image
+from torchvision.transforms import Pad
+
+
+class ResizePad(object):
+    """
+    Perform resizing while keeping the aspect ratio of the image -- padding type: continuous (black).
+    Expects a PIL image and an int value as target_size.
+    (It can be extended to perform other transforms on both the PIL image and the object boxes.)
+
+    Example:
+        target_size = 200
+        # im: numpy array
+        img = Image.fromarray(im.astype('uint8'), 'RGB')
+        img = ResizePad(target_size)(img)
+    """
+
+    def __init__(self, target_size):
+        self.target_size = target_size
+        self.boxes = torch.Tensor([[0, 0, 0, 0]])
+
+    def resize(self, img, boxes, size, max_size=1000):
+        '''Resize the input PIL image to the given size.
+        Args:
+            img: (PIL.Image) image to be resized.
+            boxes: (tensor) object boxes, sized [#obj,4].
+            size: (tuple or int)
+                - if is tuple, resize image to the size.
+                - if is int, resize the shorter side to the size while maintaining the aspect ratio.
+            max_size: (int) when size is int, limit the image longer size to max_size.
+                This is essential to limit the usage of GPU memory.
+        Returns:
+            img: (PIL.Image) resized image.
+            boxes: (tensor) resized boxes.
+        '''
+        w, h = img.size
+        if isinstance(size, int):
+            size_min = min(w, h)
+            size_max = max(w, h)
+            sw = sh = float(size) / size_min
+            if sw * size_max > max_size:
+                sw = sh = float(max_size) / size_max
+            ow = int(w * sw + 0.5)
+            oh = int(h * sh + 0.5)
+        else:
+            ow, oh = size
+            sw = float(ow) / w
+            sh = float(oh) / h
+        return img.resize((ow, oh), Image.BILINEAR), \
+               boxes * torch.Tensor([sw, sh, sw, sh])
+
+    def random_crop(self, img, boxes):
+        '''Crop the given PIL image to a random size and aspect ratio.
+        A crop of random size of (0.56 to 1.0) of the original size and a random
+        aspect ratio of 3/4 to 4/3 of the original aspect ratio is made.
+        Args:
+            img: (PIL.Image) image to be cropped.
+            boxes: (tensor) object boxes, sized [#obj,4].
+        Returns:
+            img: (PIL.Image) randomly cropped image.
+            boxes: (tensor) randomly cropped boxes.
+        '''
+        success = False
+        for attempt in range(10):
+            area = img.size[0] * img.size[1]
+            target_area = random.uniform(0.56, 1.0) * area
+            aspect_ratio = random.uniform(3. / 4, 4. / 3)
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if random.random() < 0.5:
+                w, h = h, w
+
+            if w <= img.size[0] and h <= img.size[1]:
+                x = random.randint(0, img.size[0] - w)
+                y = random.randint(0, img.size[1] - h)
+                success = True
+                break
+
+        # Fallback
+        if not success:
+            w = h = min(img.size[0], img.size[1])
+            x = (img.size[0] - w) // 2
+            y = (img.size[1] - h) // 2
+
+        img = img.crop((x, y, x + w, y + h))
+        boxes -= torch.Tensor([x, y, x, y])
+        boxes[:, 0::2].clamp_(min=0, max=w - 1)
+        boxes[:, 1::2].clamp_(min=0, max=h - 1)
+        return img, boxes
+
+    def center_crop(self, img, boxes, size):
+        '''Crops the given PIL Image at the center.
+        Args:
+            img: (PIL.Image) image to be cropped.
+            boxes: (tensor) object boxes, sized [#obj,4].
+            size (tuple): desired output size of (w,h).
+        Returns:
+            img: (PIL.Image) center cropped image.
+            boxes: (tensor) center cropped boxes.
+        '''
+        w, h = img.size
+        ow, oh = size
+        i = int(round((h - oh) / 2.))
+        j = int(round((w - ow) / 2.))
+        img = img.crop((j, i, j + ow, i + oh))
+        boxes -= torch.Tensor([j, i, j, i])
+        boxes[:, 0::2].clamp_(min=0, max=ow - 1)
+        boxes[:, 1::2].clamp_(min=0, max=oh - 1)
+        return img, boxes
+
+    def random_flip(self, img, boxes):
+        '''Randomly flip the given PIL Image.
+        Args:
+            img: (PIL Image) image to be flipped.
+            boxes: (tensor) object boxes, sized [#obj,4].
+        Returns:
+            img: (PIL.Image) randomly flipped image.
+            boxes: (tensor) randomly flipped boxes.
+ ''' + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + w = img.width + xmin = w - boxes[:, 2] + xmax = w - boxes[:, 0] + boxes[:, 0] = xmin + boxes[:, 2] = xmax + return img, boxes + + def resize_with_padding(self, img, target_size): + img, boxes = self.resize(img, self.boxes, target_size, max_size=target_size) + padding = (max(0, target_size - img.size[0]) // 2, max(0, target_size - img.size[1]) // 2) + img = Pad(padding)(img) + + return img + + def __call__(self, img): + img = self.resize_with_padding(img, self.target_size) + return img \ No newline at end of file diff --git a/src/datamodules/RGB/utils/twin_transforms.py b/src/datamodules/RGB/utils/twin_transforms.py new file mode 100644 index 00000000..b0382c5c --- /dev/null +++ b/src/datamodules/RGB/utils/twin_transforms.py @@ -0,0 +1,101 @@ +import random + +from torchvision.transforms import functional as F + +from src.datamodules.RGB.utils import functional as F_custom + + +class TwinCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, gt): + for t in self.transforms: + img, gt = t(img, gt) + return img, gt + + +class TwinRandomCrop(object): + """Crop the given PIL Images at the same random location""" + + def __init__(self, crop_size): + self.crop_size = crop_size + + def get_params(self, img_size): + """Get parameters for ``crop`` for a random crop""" + w, h = img_size + th = self.crop_size + tw = self.crop_size + + assert w >= tw and h >= th + + if w == tw and h == th: + return 0, 0, h, w + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img, gt): + i, j, h, w = self.get_params(img.size) + return F.crop(img, i, j, h, w), F.crop(gt, i, j, h, w) + + +class TwinImageToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + Converts a PIL Image or numpy.ndarray (W x H x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + + def __call__(self, img, gt): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + return F.to_tensor(img), F.to_tensor(gt) + + +class ToTensorSlidingWindowCrop(object): + """ + Crop the data and ground truth image at the specified coordinates to the specified size and convert + them to a tensor. + """ + + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, img, gt, coordinates): + """ + Args: + img (PIL Image): Data image to be cropped and converted to tensor. + gt (PIL Image): Ground truth image to be cropped and converted to tensor. 
+ + Returns: + Data tensor, gt tensor (tuple of tensors): cropped and converted images + + """ + x_position = coordinates[0] + y_position = coordinates[1] + + return F.to_tensor(F.crop(img, x_position, y_position, self.crop_size, self.crop_size)), \ + F.to_tensor(F.crop(gt, x_position, y_position, self.crop_size, self.crop_size)) + + +class OneHotToPixelLabelling(object): + def __call__(self, tensor): + return F_custom.argmax_onehot(tensor) + + +class OneHotEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_one_hot(gt, self.class_encodings) \ No newline at end of file diff --git a/src/datamodules/RGB/utils/wrapper_transforms.py b/src/datamodules/RGB/utils/wrapper_transforms.py new file mode 100644 index 00000000..eaa5e437 --- /dev/null +++ b/src/datamodules/RGB/utils/wrapper_transforms.py @@ -0,0 +1,37 @@ +from typing import Callable + + +class OnlyImage(object): + """Wrapper function around a single parameter transform. It will be cast only on image""" + + def __init__(self, transform: Callable): + """Initialize the transformation with the transformation to be called. + Could be a compose. + + Parameters + ---------- + transform : torchvision.transforms.transforms + Transformation to wrap + """ + self.transform = transform + + def __call__(self, image, target): + return self.transform(image), target + + +class OnlyTarget(object): + """Wrapper function around a single parameter transform. It will be cast only on target""" + + def __init__(self, transform: Callable): + """Initialize the transformation with the transformation to be called. + Could be a compose. + + Parameters + ---------- + transform : torchvision.transforms.transforms + Transformation to wrap + """ + self.transform = transform + + def __call__(self, image, target): + return image, self.transform(target) \ No newline at end of file From ece230000299184d7c6e3bf3e9fa6d4096ecbff0 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 16:33:08 +0200 Subject: [PATCH 009/108] :sparkle: new dataset for rotnet based on the cropped version of divahisdb --- src/datamodules/RotNet/datasets/__init__.py | 0 .../RotNet/datasets/cropped_dataset.py | 194 ++++++++++++++++++ tests/datamodules/RotNet/__init__.py | 0 tests/datamodules/RotNet/datasets/__init__.py | 0 .../RotNet/datasets/test_cropped_dataset.py | 154 ++++++++++++++ 5 files changed, 348 insertions(+) create mode 100644 src/datamodules/RotNet/datasets/__init__.py create mode 100644 src/datamodules/RotNet/datasets/cropped_dataset.py create mode 100644 tests/datamodules/RotNet/__init__.py create mode 100644 tests/datamodules/RotNet/datasets/__init__.py create mode 100644 tests/datamodules/RotNet/datasets/test_cropped_dataset.py diff --git a/src/datamodules/RotNet/datasets/__init__.py b/src/datamodules/RotNet/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py new file mode 100644 index 00000000..dd9377b7 --- /dev/null +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -0,0 +1,194 @@ +""" +Load a dataset of historic documents by specifying the folder where its located. 
+""" + +# Utils +import random +from pathlib import Path +from typing import List, Union, Optional + +import torchvision.transforms.functional +from omegaconf import ListConfig +from torch import is_tensor +from torchvision.transforms import ToTensor + +from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset +from src.datamodules.RotNet.utils.misc import has_extension, pil_loader +from src.utils import utils + +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] +ROTATION_ANGELS = [0, 90, 180, 270] + +log = utils.get_logger(__name__) + + +class CroppedRotNet(CroppedHisDBDataset): + """A generic data loader where the images are arranged in this way: :: + + root/gt/xxx.png + root/gt/xxy.png + root/gt/xxz.png + + root/data/xxx.png + root/data/xxy.png + root/data/xxz.png + """ + + def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + selection: Optional[Union[int, List[str]]] = None, + is_test=False, image_transform=None, twin_transform=None, + classes=None, **kwargs): + """ + #TODO doc + Parameters + ---------- + path : string + Path to dataset folder (train / val / test) + classes : + workers : int + imgs_in_memory : + crops_per_image : int + crop_size : int + image_transform : callable + target_transform : callable + twin_transform : callable + loader : callable + A function to load an image given its path. + """ + + super(CroppedRotNet, self).__init__(path=path, data_folder_name=data_folder_name, gt_folder_name=gt_folder_name, + selection=selection, + is_test=is_test, image_transform=image_transform, + target_transform=None, twin_transform=twin_transform, + classes=classes, **kwargs) + + def __getitem__(self, index): + data_img = self._load_data_and_gt(index=index) + img, gt = self._apply_transformation(data_img, index=index) + return img, gt + + def _load_data_and_gt(self, index): + data_img = pil_loader(self.img_paths_per_page[index]) + return data_img + + def _apply_transformation(self, img, index): + """ + Applies the transformations that have been defined in the setup (setup.py). If no transformations + have been defined, the PIL image is returned instead. 
+ + Parameters + ---------- + img: PIL image + image data + gt: PIL image + ground truth image + coordinates: tuple (int, int) + coordinates where the sliding window should be cropped + Returns + ------- + tuple + img and gt after transformations + """ + if self.twin_transform is not None and not self.is_test: + img, gt = self.twin_transform(img, None) + + if self.image_transform is not None: + # perform transformations + img = self.image_transform(img) + + if not is_tensor(img): + img = ToTensor()(img) + + target_class = index % len(ROTATION_ANGELS) + rotation_angle = ROTATION_ANGELS[target_class] + + img = torchvision.transforms.functional.rotate(img=img, angle=rotation_angle) + + return img, target_class + + @staticmethod + def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + selection: Optional[Union[int, List[str]]] = None) \ + -> List[Path]: + """ + Structure of the folder + + directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + + + :param directory: + :param selection: + :return: tuple + (path_data_file, path_gt_file, original_image_name, (x, y)) + """ + paths = [] + directory = directory.expanduser() + + path_data_root = directory / data_folder_name + path_gt_root = directory / gt_folder_name + + if not (path_data_root.is_dir() or path_gt_root.is_dir()): + log.error("folder data or gt not found in " + str(directory)) + + # get all subitems (and files) sorted + subitems = sorted(path_data_root.iterdir()) + + # check the selection parameter + if selection: + subdirectories = [x.name for x in subitems if x.is_dir()] + + if isinstance(selection, int): + if selection < 0: + msg = f'Parameter "selection" is a negative integer ({selection}). ' \ + f'Negative values are not supported!' + log.error(msg) + raise ValueError(msg) + + elif selection == 0: + selection = None + + elif selection > len(subdirectories): + msg = f'Parameter "selection" is larger ({selection}) than ' \ + f'number of subdirectories ({len(subdirectories)}).' 
+ log.error(msg) + raise ValueError(msg) + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if not all(x in subdirectories for x in selection): + msg = f'Parameter "selection" contains a non-existing subdirectory.)' + log.error(msg) + raise ValueError(msg) + + else: + msg = f'Parameter "selection" exists, but it is of unsupported type ({type(selection)})' + log.error(msg) + raise TypeError(msg) + + counter = 0 # Counter for subdirectories, needed for selection parameter + + for path_data_subdir in subitems: + if not path_data_subdir.is_dir(): + if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + log.warning("image file found in data root: " + str(path_data_subdir)) + continue + + counter += 1 + + if selection: + if isinstance(selection, int): + if counter > selection: + break + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if path_data_subdir.name not in selection: + continue + + for path_data_file in sorted(path_data_subdir.iterdir()): + if has_extension(path_data_file.name, IMG_EXTENSIONS): + paths.append(path_data_file) + paths.append(path_data_file) + paths.append(path_data_file) + paths.append(path_data_file) + + return paths diff --git a/tests/datamodules/RotNet/__init__.py b/tests/datamodules/RotNet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/RotNet/datasets/__init__.py b/tests/datamodules/RotNet/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py new file mode 100644 index 00000000..83ab5076 --- /dev/null +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -0,0 +1,154 @@ +from pathlib import PosixPath + +import numpy as np +import torch +from _pytest.fixtures import fixture +from torchvision.transforms import ToTensor +from torchvision.transforms.functional import rotate + +from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGELS +from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped + + +@fixture +def dataset_train(data_dir_cropped): + return CroppedRotNet(path=data_dir_cropped / 'train') + + +def test__load_data_and_gt(dataset_train): + img = dataset_train._load_data_and_gt(0) + assert img.size == (300, 300) + assert img.format == 'PNG' + assert np.all(np.array(img)[150][150] == np.array([97, 72, 32])) + + +def test__apply_transformation(dataset_train): + img0_o = dataset_train._load_data_and_gt(0) + img1_o = dataset_train._load_data_and_gt(1) + img2_o = dataset_train._load_data_and_gt(2) + img3_o = dataset_train._load_data_and_gt(3) + img4_o = dataset_train._load_data_and_gt(4) + + img0, gt0 = dataset_train._apply_transformation(img0_o, 0) + assert torch.equal(img0, ToTensor()(img0_o)) + assert gt0 == 0 + + img1, gt1 = dataset_train._apply_transformation(img1_o, 1) + assert not torch.equal(ToTensor()(img1_o), img1) + assert torch.equal(img1, rotate(img=ToTensor()(img1_o), angle=ROTATION_ANGELS[1])) + assert gt1 == 1 + + img2, gt2 = dataset_train._apply_transformation(img2_o, 2) + assert not torch.equal(ToTensor()(img2_o), img2) + assert torch.equal(img2, rotate(img=ToTensor()(img2_o), angle=ROTATION_ANGELS[2])) + assert gt2 == 2 + + img3, gt3 = dataset_train._apply_transformation(img3_o, 3) + assert not torch.equal(ToTensor()(img3_o), img3) + assert torch.equal(img3, rotate(img=ToTensor()(img3_o), angle=ROTATION_ANGELS[3])) + assert gt3 == 3 + + img4, gt4 = 
dataset_train._apply_transformation(img4_o, 0) + assert torch.equal(img4, ToTensor()(img4_o)) + assert gt4 == 0 + + +def test_get_gt_data_paths(data_dir_cropped): + file_paths = CroppedRotNet.get_gt_data_paths(directory=data_dir_cropped / 'train') + assert len(file_paths) == 48 + assert file_paths == [PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + PosixPath( + 
data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + PosixPath( + data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png')] From bd4a441acad2e2a943038731199da78d5f060a4d Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 17:07:34 +0200 Subject: [PATCH 010/108] :art: now returning a one hot encoded gt --- .../RotNet/datasets/cropped_dataset.py | 20 ++++++++++--------- .../RotNet/datasets/test_cropped_dataset.py | 18 ++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index dd9377b7..d116d198 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ 
b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -3,10 +3,11 @@ """ # Utils -import random from pathlib import Path from typing import List, Union, Optional +import numpy as np +import torch import torchvision.transforms.functional from omegaconf import ListConfig from torch import is_tensor @@ -17,7 +18,7 @@ from src.utils import utils IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] -ROTATION_ANGELS = [0, 90, 180, 270] +ROTATION_ANGLES = [0, 90, 180, 270] log = utils.get_logger(__name__) @@ -36,8 +37,7 @@ class CroppedRotNet(CroppedHisDBDataset): def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', selection: Optional[Union[int, List[str]]] = None, - is_test=False, image_transform=None, twin_transform=None, - classes=None, **kwargs): + is_test=False, image_transform=None, **kwargs): """ #TODO doc Parameters @@ -59,8 +59,8 @@ def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: s super(CroppedRotNet, self).__init__(path=path, data_folder_name=data_folder_name, gt_folder_name=gt_folder_name, selection=selection, is_test=is_test, image_transform=image_transform, - target_transform=None, twin_transform=twin_transform, - classes=classes, **kwargs) + target_transform=None, twin_transform=None, + classes=None, **kwargs) def __getitem__(self, index): data_img = self._load_data_and_gt(index=index) @@ -99,12 +99,14 @@ def _apply_transformation(self, img, index): if not is_tensor(img): img = ToTensor()(img) - target_class = index % len(ROTATION_ANGELS) - rotation_angle = ROTATION_ANGELS[target_class] + target_class = index % len(ROTATION_ANGLES) + rotation_angle = ROTATION_ANGLES[target_class] + hot_hot_encoded = np.zeros(len(ROTATION_ANGLES)) + hot_hot_encoded[target_class] = 1 img = torchvision.transforms.functional.rotate(img=img, angle=rotation_angle) - return img, target_class + return img, torch.LongTensor(hot_hot_encoded) @staticmethod def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index 83ab5076..ee658dd3 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -6,7 +6,7 @@ from torchvision.transforms import ToTensor from torchvision.transforms.functional import rotate -from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGELS +from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped @@ -31,26 +31,26 @@ def test__apply_transformation(dataset_train): img0, gt0 = dataset_train._apply_transformation(img0_o, 0) assert torch.equal(img0, ToTensor()(img0_o)) - assert gt0 == 0 + assert torch.equal(gt0, torch.LongTensor([1, 0, 0, 0])) img1, gt1 = dataset_train._apply_transformation(img1_o, 1) assert not torch.equal(ToTensor()(img1_o), img1) - assert torch.equal(img1, rotate(img=ToTensor()(img1_o), angle=ROTATION_ANGELS[1])) - assert gt1 == 1 + assert torch.equal(img1, rotate(img=ToTensor()(img1_o), angle=ROTATION_ANGLES[1])) + assert torch.equal(gt1, torch.LongTensor([0, 1, 0, 0])) img2, gt2 = dataset_train._apply_transformation(img2_o, 2) assert not torch.equal(ToTensor()(img2_o), img2) - assert torch.equal(img2, rotate(img=ToTensor()(img2_o), angle=ROTATION_ANGELS[2])) - assert gt2 == 2 + assert torch.equal(img2, 
rotate(img=ToTensor()(img2_o), angle=ROTATION_ANGLES[2])) + assert torch.equal(gt2, torch.LongTensor([0, 0, 1, 0])) img3, gt3 = dataset_train._apply_transformation(img3_o, 3) assert not torch.equal(ToTensor()(img3_o), img3) - assert torch.equal(img3, rotate(img=ToTensor()(img3_o), angle=ROTATION_ANGELS[3])) - assert gt3 == 3 + assert torch.equal(img3, rotate(img=ToTensor()(img3_o), angle=ROTATION_ANGLES[3])) + assert torch.equal(gt3, torch.LongTensor([0, 0, 0, 1])) img4, gt4 = dataset_train._apply_transformation(img4_o, 0) assert torch.equal(img4, ToTensor()(img4_o)) - assert gt4 == 0 + assert torch.equal(gt4, torch.LongTensor([1, 0, 0, 0])) def test_get_gt_data_paths(data_dir_cropped): From cea6532042fe4ed392f80b657703cd9cbcd626e0 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 17:16:15 +0200 Subject: [PATCH 011/108] :sparkle: misc for self_supervised training plus test --- src/datamodules/RotNet/utils/__init__.py | 0 src/datamodules/RotNet/utils/misc.py | 74 ++++++++++++++++++++++++ tests/datamodules/RotNet/test_misc.py | 6 ++ 3 files changed, 80 insertions(+) create mode 100644 src/datamodules/RotNet/utils/__init__.py create mode 100644 src/datamodules/RotNet/utils/misc.py create mode 100644 tests/datamodules/RotNet/test_misc.py diff --git a/src/datamodules/RotNet/utils/__init__.py b/src/datamodules/RotNet/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RotNet/utils/misc.py b/src/datamodules/RotNet/utils/misc.py new file mode 100644 index 00000000..78688046 --- /dev/null +++ b/src/datamodules/RotNet/utils/misc.py @@ -0,0 +1,74 @@ +""" +General purpose utility functions. + +""" + +from pathlib import Path + +# Utils +import numpy as np +from PIL import Image + +from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir + +try: + import accimage +except ImportError: + accimage = None + + +def has_extension(filename, extensions): + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def pil_loader(path, to_rgb=True): + pic = Image.open(path) + if to_rgb: + pic = convert_to_rgb(pic) + return pic + + +def convert_to_rgb(pic): + if pic.mode == "RGB": + pass + elif pic.mode in ("CMYK", "RGBA", "P"): + pic = pic.convert('RGB') + elif pic.mode == "I": + img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "I;16": + img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "L": + img = np.array(pic).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + else: + raise TypeError(f"unsupported image type {pic.mode}") + return pic + + +def validate_path_for_self_supervised(data_dir, data_folder_name: str = 'data'): + if data_dir is None: + raise PathNone("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + else: + split_names = ['train', 'val', 'test'] + type_names = [data_folder_name] + + data_folder = Path(data_dir) + if not data_folder.is_dir(): + raise PathNotDir("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] + if len(split_folders) != 3: + raise PathMissingSplitDir(f'Your path needs to contain 
train/val/test and '
+                                                f'each of them must contain a folder {data_folder_name}')
+
+    # check if we have train/test/val
+    for split in split_folders:
+        type_folders = [d for d in split.iterdir() if d.is_dir() and d.name in type_names]
+        # check if we have data/gt
+        if len(type_folders) != 1:
+            raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name}')
+    return Path(data_dir)
diff --git a/tests/datamodules/RotNet/test_misc.py b/tests/datamodules/RotNet/test_misc.py
new file mode 100644
index 00000000..d9d163d4
--- /dev/null
+++ b/tests/datamodules/RotNet/test_misc.py
@@ -0,0 +1,6 @@
+from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised
+from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped
+
+
+def test_validate_path_for_self_supervised(data_dir_cropped):
+    assert data_dir_cropped == validate_path_for_self_supervised(data_dir_cropped)

From 60c243a8360733a9b83a8eb0d6cb05f130d9062b Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 26 Oct 2021 17:26:19 +0200
Subject: [PATCH 012/108] :art: replaced fixture with pytest.fixture

---
 .../DivaHisDB/datasets/test_cropped_hisdb_dataset.py    | 9 ++++-----
 tests/datamodules/DivaHisDB/test_hisDBDataModule.py     | 3 +--
 tests/datamodules/DivaHisDB/test_misc.py                | 5 ++---
 .../datamodules/RotNet/datasets/test_cropped_dataset.py | 4 ++--
 tests/tasks/utils/test_functional.py                    | 5 ++---
 tests/tasks/utils/test_outputs.py                       | 4 ++--
 tests/test_data/dummy_data_hisdb/dummy_data.py          | 6 +++---
 tests/utils/test_utils.py                               | 3 +--
 8 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py
index 03d95c9b..cb9d721f 100644
--- a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py
+++ b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py
@@ -2,23 +2,22 @@

 import pytest
 import torch
-from pytest import fixture

 from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset
 from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped


-@fixture
+@pytest.fixture
 def dataset_train(data_dir_cropped):
     return CroppedHisDBDataset(path=data_dir_cropped / 'train')


-@fixture
+@pytest.fixture
 def dataset_val(data_dir_cropped):
     return CroppedHisDBDataset(path=data_dir_cropped / 'val')


-@fixture
+@pytest.fixture
 def dataset_test(data_dir_cropped):
     return CroppedHisDBDataset(path=data_dir_cropped / 'test')

@@ -215,7 +214,7 @@ def test_get_gt_data_paths_test(data_dir_cropped):
     assert files_from_method == expected_result


-@fixture
+@pytest.fixture
 def get_train_file_names(data_dir_cropped):
     return [(PosixPath(
         data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'),
diff --git a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
index d4b6dc63..ae6c55d3 100644
--- a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
+++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
@@ -2,7 +2,6 @@
 import pytest
 from numpy import uint8
 from omegaconf import OmegaConf
-from pytest import fixture

 from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped
 from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped
@@ -11,7 +10,7 @@
 NUM_WORKERS = 4


-@fixture
+@pytest.fixture
 def data_module_cropped(data_dir_cropped):
     OmegaConf.clear_resolvers()
     datamodules = 
DivaHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) diff --git a/tests/datamodules/DivaHisDB/test_misc.py b/tests/datamodules/DivaHisDB/test_misc.py index f18382b8..34c8c445 100644 --- a/tests/datamodules/DivaHisDB/test_misc.py +++ b/tests/datamodules/DivaHisDB/test_misc.py @@ -1,11 +1,10 @@ import pytest -from pytest import fixture from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation -@fixture +@pytest.fixture def path_missing_split(tmp_path): list_splits = ['train', 'test'] @@ -16,7 +15,7 @@ def path_missing_split(tmp_path): return tmp_path -@fixture +@pytest.fixture def path_missing_subfolder(tmp_path): list_splits_good = ['train', 'val'] list_types_good = ['data', 'gt'] diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index ee658dd3..aebf958a 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -1,8 +1,8 @@ from pathlib import PosixPath import numpy as np +import pytest import torch -from _pytest.fixtures import fixture from torchvision.transforms import ToTensor from torchvision.transforms.functional import rotate @@ -10,7 +10,7 @@ from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped -@fixture +@pytest.fixture def dataset_train(data_dir_cropped): return CroppedRotNet(path=data_dir_cropped / 'train') diff --git a/tests/tasks/utils/test_functional.py b/tests/tasks/utils/test_functional.py index fb5b2a86..f4b91c37 100644 --- a/tests/tasks/utils/test_functional.py +++ b/tests/tasks/utils/test_functional.py @@ -1,16 +1,15 @@ import pytest import torch -from _pytest.fixtures import fixture from src.datamodules.DivaHisDB.utils.functional import gt_to_one_hot -@fixture +@pytest.fixture def get_class_encodings(): return [1, 2] -@fixture +@pytest.fixture def get_input_tensor(): return torch.tensor( [[[0.01, 0.1], [0.001, 0.01], [0.01, 0.1]], [[0.01, 0.1], [0.01, 0.1], [3.01, 0.1]], diff --git a/tests/tasks/utils/test_outputs.py b/tests/tasks/utils/test_outputs.py index 0cdeb483..eb35a3e7 100644 --- a/tests/tasks/utils/test_outputs.py +++ b/tests/tasks/utils/test_outputs.py @@ -1,9 +1,9 @@ -from _pytest.fixtures import fixture +import pytest from src.tasks.utils.outputs import OutputKeys, reduce_dict -@fixture +@pytest.fixture def get_dict(): return {OutputKeys.PREDICTION: [1, 2, 3, 3], OutputKeys.TARGET: [1, 2, 3, 4], diff --git a/tests/test_data/dummy_data_hisdb/dummy_data.py b/tests/test_data/dummy_data_hisdb/dummy_data.py index df445215..592dee01 100644 --- a/tests/test_data/dummy_data_hisdb/dummy_data.py +++ b/tests/test_data/dummy_data_hisdb/dummy_data.py @@ -1,10 +1,10 @@ import os from distutils import dir_util -from pytest import fixture +import pytest -@fixture +@pytest.fixture def data_dir(tmp_path): """ Moves the test data into the tmp path of the testing environment. @@ -21,7 +21,7 @@ def data_dir(tmp_path): return tmp_path -@fixture +@pytest.fixture def data_dir_cropped(tmp_path): """ Moves the test data into the tmp path of the testing environment. 
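# A minimal, self-contained sketch of the fixture pattern this patch converges on
# (the fixture name and values below are illustrative, not part of the repository):
# the public `pytest.fixture` decorator replaces imports from the private
# `_pytest.fixtures` module, whose internal layout may change between pytest releases.
import pytest


@pytest.fixture
def sample_config():
    # pytest evaluates this once for every test that lists `sample_config` as an argument
    return {'seed': 42}


def test_seed_is_set(sample_config):
    # the test receives the fixture's return value, not the function itself
    assert sample_config['seed'] == 42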
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 25fe6156..aba67e62 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,11 +1,10 @@ import pytest -from _pytest.fixtures import fixture from omegaconf import DictConfig from src.utils.utils import _check_if_in_config, REQUIRED_CONFIGS, check_config -@fixture +@pytest.fixture def get_dict(): return DictConfig({'plugins': { 'ddp_plugin': {'_target_': 'pytorch_lightning.plugins.DDPPlugin', 'find_unused_parameters': False}}, 'task': { From d5d29f754736e5a9c873a2abb95879f3473a0f91 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 17:35:14 +0200 Subject: [PATCH 013/108] :bug: replaced some old code --- src/datamodules/DivaHisDB/utils/image_analytics.py | 12 ++++-------- src/datamodules/RotNet/datasets/cropped_dataset.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index 5dcc90c7..c95899b7 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -15,7 +15,7 @@ from PIL import Image -def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): +def get_analytics(input_path: Path, get_data_paths_func, **kwargs): """ Parameters ---------- @@ -30,19 +30,15 @@ def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): analytics_dict = json.load(fp=f) else: train_path = input_path / 'train' - gt_data_path_list = get_gt_data_paths_func(train_path) - file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) - file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) + data_path_list = get_data_paths_func(train_path) + file_names_data = np.asarray([str(item) for item in data_path_list]) mean, std = compute_mean_std(file_names=file_names_data, **kwargs) # Measure weights for class balancing logging.info(f'Measuring class weights') # create a list with all gt file paths - class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, **kwargs) analytics_dict = {'mean': mean.tolist(), - 'std': std.tolist(), - 'class_weights': class_weights.tolist(), - 'class_encodings': class_encodings.tolist()} + 'std': std.tolist()} # save json try: with analytics_file_path.open(mode='w') as f: diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index d116d198..ea790c14 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -109,7 +109,7 @@ def _apply_transformation(self, img, index): return img, torch.LongTensor(hot_hot_encoded) @staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def get_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', selection: Optional[Union[int, List[str]]] = None) \ -> List[Path]: """ From c585d32bd29a6380161e1985795082d1e2c8a981 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 17:44:12 +0200 Subject: [PATCH 014/108] :bug: renamed it to the parent name --- src/datamodules/RotNet/datasets/cropped_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index ea790c14..d116d198 100644 --- 
a/src/datamodules/RotNet/datasets/cropped_dataset.py
+++ b/src/datamodules/RotNet/datasets/cropped_dataset.py
@@ -109,7 +109,7 @@ def _apply_transformation(self, img, index):
         return img, torch.LongTensor(hot_hot_encoded)

     @staticmethod
-    def get_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt',
+    def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt',
                       selection: Optional[Union[int, List[str]]] = None) \
             -> List[Path]:
         """

From fff07ab67f5d4cc1c160e49ec59ee425d7731c22 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 26 Oct 2021 17:44:45 +0200
Subject: [PATCH 015/108] :art: now use np.array_equal

---
 tests/datamodules/RotNet/datasets/test_cropped_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py
index aebf958a..dced31f9 100644
--- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py
+++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py
@@ -19,7 +19,7 @@ def test__load_data_and_gt(dataset_train):
     img = dataset_train._load_data_and_gt(0)
     assert img.size == (300, 300)
     assert img.format == 'PNG'
-    assert np.all(np.array(img)[150][150] == np.array([97, 72, 32]))
+    assert np.array_equal(np.array(img)[150][150], np.array([97, 72, 32]))


 def test__apply_transformation(dataset_train):

From 8e1aa91ac22942a5e648e7acf934ea2042f7e924 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 26 Oct 2021 17:45:39 +0200
Subject: [PATCH 016/108] :bug: refactoring

---
 .../RotNet/utils/wrapper_transforms.py        | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 src/datamodules/RotNet/utils/wrapper_transforms.py

diff --git a/src/datamodules/RotNet/utils/wrapper_transforms.py b/src/datamodules/RotNet/utils/wrapper_transforms.py
new file mode 100644
index 00000000..eaa5e437
--- /dev/null
+++ b/src/datamodules/RotNet/utils/wrapper_transforms.py
@@ -0,0 +1,37 @@
+from typing import Callable
+
+
+class OnlyImage(object):
+    """Wrapper around a single-parameter transform. It is applied only to the image."""
+
+    def __init__(self, transform: Callable):
+        """Initialize the wrapper with the transformation to be called.
+        Could be a compose.
+
+        Parameters
+        ----------
+        transform : torchvision.transforms.transforms
+            Transformation to wrap
+        """
+        self.transform = transform
+
+    def __call__(self, image, target):
+        return self.transform(image), target
+
+
+class OnlyTarget(object):
+    """Wrapper around a single-parameter transform. It is applied only to the target."""
+
+    def __init__(self, transform: Callable):
+        """Initialize the wrapper with the transformation to be called.
+        Could be a compose. 
+ + Parameters + ---------- + transform : torchvision.transforms.transforms + Transformation to wrap + """ + self.transform = transform + + def __call__(self, image, target): + return image, self.transform(target) \ No newline at end of file From a505bdc0c4dd999bbd9b680423ba61fb846aa8c4 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 26 Oct 2021 17:53:47 +0200 Subject: [PATCH 017/108] :sparkle: rotnet datamodule --- src/datamodules/RotNet/datamodule_cropped.py | 119 ++++++++++++++++++ .../RotNet/test_datamodule_cropped.py | 36 ++++++ 2 files changed, 155 insertions(+) create mode 100644 src/datamodules/RotNet/datamodule_cropped.py create mode 100644 tests/datamodules/RotNet/test_datamodule_cropped.py diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py new file mode 100644 index 00000000..3c6b4169 --- /dev/null +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Union, List, Optional + +import numpy as np +from torch.utils.data import DataLoader +from torchvision import transforms + +from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics +from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES +from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised +from src.datamodules.RotNet.utils.wrapper_transforms import OnlyImage +from src.datamodules.base_datamodule import AbstractDatamodule +from src.utils import utils + +log = utils.get_logger(__name__) + + +class RotNetDivaHisDBDataModuleCropped(AbstractDatamodule): + def __init__(self, data_dir: str = None, data_folder_name: str = 'data', + selection_train: Optional[Union[int, List[str]]] = None, + selection_val: Optional[Union[int, List[str]]] = None, + selection_test: Optional[Union[int, List[str]]] = None, + crop_size: int = 256, num_workers: int = 4, batch_size: int = 8, + shuffle: bool = True, drop_last: bool = True): + super().__init__() + + analytics = get_analytics(input_path=Path(data_dir), + get_data_paths_func=CroppedRotNet.get_gt_data_paths) + + self.mean = analytics['mean'] + self.std = analytics['std'] + self.class_encodings = np.array(ROTATION_ANGLES) + self.num_classes = len(self.class_encodings) + self.class_weights = np.ones(self.num_classes) + + self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + transforms.RandomCrop(size=crop_size)])) + + self.num_workers = num_workers + self.batch_size = batch_size + + self.shuffle = shuffle + self.drop_last = drop_last + + self.data_folder_name = data_folder_name + self.data_dir = validate_path_for_self_supervised(data_dir=data_dir, data_folder_name=self.data_folder_name) + + self.selection_train = selection_train + self.selection_val = selection_val + self.selection_test = selection_test + + self.dims = (3, crop_size, crop_size) + + def setup(self, stage: Optional[str] = None): + super().setup() + if stage == 'fit' or stage is None: + self.train = CroppedRotNet(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.val = CroppedRotNet(**self._create_dataset_parameters('val'), selection=self.selection_val) + + self._check_min_num_samples(num_samples=len(self.train), data_split='train', + drop_last=self.drop_last) + self._check_min_num_samples(num_samples=len(self.val), data_split='val', + drop_last=self.drop_last) + + if stage == 'test' or stage is not None: + self.test = 
CroppedRotNet(**self._create_dataset_parameters('test'), selection=self.selection_test) + # self._check_min_num_samples(num_samples=len(self.test), data_split='test', + # drop_last=False) + + def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): + num_processes = self.trainer.num_processes + batch_size = self.batch_size + if drop_last: + if num_samples < (self.trainer.num_processes * self.batch_size): + log.error( + f'#samples ({num_samples}) in "{data_split}" smaller than ' + f'#processes({num_processes}) times batch size ({batch_size}). ' + f'This only works if drop_last is false!') + raise ValueError() + else: + if num_samples < (self.trainer.num_processes * self.batch_size): + log.warning( + f'#samples ({num_samples}) in "{data_split}" smaller than ' + f'#processes ({num_processes}) times batch size ({batch_size}). ' + f'This works due to drop_last=False, however samples will occur multiple times. ' + f'Check if this behavior is intended!') + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(self.train, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=self.shuffle, + drop_last=self.drop_last, + pin_memory=True) + + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + return DataLoader(self.val, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=self.shuffle, + drop_last=self.drop_last, + pin_memory=True) + + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + return DataLoader(self.test, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + drop_last=False, + pin_memory=True) + + def _create_dataset_parameters(self, dataset_type: str = 'train'): + is_test = dataset_type == 'test' + return {'path': self.data_dir / dataset_type, + 'image_transform': self.image_transform, + 'classes': self.class_encodings, + 'is_test': is_test} + diff --git a/tests/datamodules/RotNet/test_datamodule_cropped.py b/tests/datamodules/RotNet/test_datamodule_cropped.py new file mode 100644 index 00000000..3b5793f3 --- /dev/null +++ b/tests/datamodules/RotNet/test_datamodule_cropped.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest +from omegaconf import OmegaConf + +from src.datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped +from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped + +NUM_WORKERS = 4 + + +@pytest.fixture +def data_module_cropped(data_dir_cropped): + OmegaConf.clear_resolvers() + return RotNetDivaHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) + + +def test_init_datamodule(data_module_cropped): + assert data_module_cropped.batch_size == 8 + assert data_module_cropped.num_workers == NUM_WORKERS + assert data_module_cropped.dims == (3, 256, 256) + assert data_module_cropped.num_classes == 4 + assert data_module_cropped.class_encodings == [0, 90, 180, 270] + assert np.array_equal(data_module_cropped.class_weights, [1., 1., 1., 1.]) + assert data_module_cropped.mean == [0.7050454974582423, 0.6503181590413941, 0.5567698583877997] + assert data_module_cropped.std == [0.3104060859619884, 0.3053311838884033, 0.28919611393432726] + with pytest.raises(AttributeError): + getattr(data_module_cropped, 'train') + getattr(data_module_cropped, 'val') + getattr(data_module_cropped, 'test') + + +def test__create_dataset_parameters_cropped(data_module_cropped): + parameters = data_module_cropped._create_dataset_parameters() + assert 'train' in 
str(parameters['path'])
+    assert not parameters['is_test']
+    assert np.array_equal(parameters['classes'], np.array([0, 90, 180, 270], dtype=np.int64))

From 49bd8b33c26820994905eca3b4314035e71c33fe Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Wed, 27 Oct 2021 10:36:10 +0200
Subject: [PATCH 018/108] :bug: forgot init file

---
 src/datamodules/RotNet/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 src/datamodules/RotNet/__init__.py

diff --git a/src/datamodules/RotNet/__init__.py b/src/datamodules/RotNet/__init__.py
new file mode 100644
index 00000000..e69de29b

From 3bb0c8a95f071f90e897f350c204c2b4364fdca0 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Wed, 27 Oct 2021 12:13:51 +0200
Subject: [PATCH 019/108] :art: optimized list of paths

---
 .../RotNet/datasets/cropped_dataset.py        | 17 ++--
 .../RotNet/datasets/test_cropped_dataset.py   | 77 +------------------
 .../RotNet/test_datamodule_cropped.py         |  6 +-
 3 files changed, 18 insertions(+), 82 deletions(-)

diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py
index d116d198..c6ba78bf 100644
--- a/src/datamodules/RotNet/datasets/cropped_dataset.py
+++ b/src/datamodules/RotNet/datasets/cropped_dataset.py
@@ -60,13 +60,21 @@ def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: s
         super(CroppedRotNet, self).__init__(path=path, data_folder_name=data_folder_name,
                                             gt_folder_name=gt_folder_name,
                                             selection=selection,
                                             is_test=is_test, image_transform=image_transform,
                                             target_transform=None, twin_transform=None,
-                                            classes=None, **kwargs)
+                                            **kwargs)

     def __getitem__(self, index):
-        data_img = self._load_data_and_gt(index=index)
+        data_img = self._load_data_and_gt(index=int(index/len(ROTATION_ANGLES)))
         img, gt = self._apply_transformation(data_img, index=index)
         return img, gt

+    def __len__(self):
+        """
+        This function returns the length of an epoch so the data loader knows when to stop.
+        Every image is served once per rotation angle, so the epoch length is the number of
+        images in the split times len(ROTATION_ANGLES). 
+ """ + return self.num_samples * len(ROTATION_ANGLES) + def _load_data_and_gt(self, index): data_img = pil_loader(self.img_paths_per_page[index]) return data_img @@ -94,7 +102,7 @@ def _apply_transformation(self, img, index): if self.image_transform is not None: # perform transformations - img = self.image_transform(img) + img = self.image_transform(img, None) if not is_tensor(img): img = ToTensor()(img) @@ -189,8 +197,5 @@ def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder for path_data_file in sorted(path_data_subdir.iterdir()): if has_extension(path_data_file.name, IMG_EXTENSIONS): paths.append(path_data_file) - paths.append(path_data_file) - paths.append(path_data_file) - paths.append(path_data_file) return paths diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index dced31f9..4c3b9217 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -55,100 +55,31 @@ def test__apply_transformation(dataset_train): def test_get_gt_data_paths(data_dir_cropped): file_paths = CroppedRotNet.get_gt_data_paths(directory=data_dir_cropped / 'train') - assert len(file_paths) == 48 - assert file_paths == [PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + expected_result = [ PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - PosixPath( - data_dir_cropped / 
'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - 
PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), PosixPath( data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png')] + ] + assert len(file_paths) == len(expected_result) + assert file_paths == expected_result diff --git a/tests/datamodules/RotNet/test_datamodule_cropped.py b/tests/datamodules/RotNet/test_datamodule_cropped.py index 3b5793f3..4204503b 100644 --- a/tests/datamodules/RotNet/test_datamodule_cropped.py +++ b/tests/datamodules/RotNet/test_datamodule_cropped.py @@ -19,10 +19,10 @@ def test_init_datamodule(data_module_cropped): assert data_module_cropped.num_workers == NUM_WORKERS assert data_module_cropped.dims == (3, 256, 256) assert data_module_cropped.num_classes == 4 - assert data_module_cropped.class_encodings == [0, 90, 180, 270] + assert np.array_equal(data_module_cropped.class_encodings, [0, 90, 180, 270]) assert np.array_equal(data_module_cropped.class_weights, [1., 1., 1., 1.]) - assert data_module_cropped.mean == [0.7050454974582423, 0.6503181590413941, 0.5567698583877997] - assert data_module_cropped.std == [0.3104060859619884, 0.3053311838884033, 0.28919611393432726] + assert data_module_cropped.mean == [0.7050454974582426, 0.6503181590413943, 0.5567698583877997] + assert data_module_cropped.std == [0.3104060859619883, 0.3053311838884032, 0.28919611393432726] with pytest.raises(AttributeError): getattr(data_module_cropped, 'train') getattr(data_module_cropped, 'val') From b82c63979ace4c6be408ee7e3fc9b990f1e946a7 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 27 Oct 2021 13:42:18 +0200 Subject: [PATCH 020/108] :wrench: added two experiment configs --- .../experiment/cb55_select_train15_unet.yaml | 66 ++++++++++++++++++ .../cb55_select_train1_val1_unet.yaml | 67 +++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 configs/experiment/cb55_select_train15_unet.yaml create mode 100644 configs/experiment/cb55_select_train1_val1_unet.yaml diff --git a/configs/experiment/cb55_select_train15_unet.yaml b/configs/experiment/cb55_select_train15_unet.yaml new file mode 100644 index 00000000..5dc9bd32 --- /dev/null +++ b/configs/experiment/cb55_select_train15_unet.yaml @@ -0,0 +1,66 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_task.yaml + - /loss: crossentropyloss.yaml + - /metric: hisdbiou.yaml + - /model/backbone: unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: True + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 50 + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 10 + confusion_matrix_val: True + confusion_matrix_test: True + +datamodule: + _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + gt_folder_name: gt + selection_train: 15 + +callbacks: + model_checkpoint: + monitor: "val/hisdbiou" + mode: "max" + filename: ${checkpoint_folder_name}cb55-full-unet +# watch_model: +# log_freq: 1000 + +logger: + wandb: + name: 'cb55-select-train15-unet' + tags: ["best_model", "USL", "baseline"] + group: 'baseline' diff --git a/configs/experiment/cb55_select_train1_val1_unet.yaml b/configs/experiment/cb55_select_train1_val1_unet.yaml new file mode 100644 index 00000000..66b97e4a --- /dev/null +++ b/configs/experiment/cb55_select_train1_val1_unet.yaml @@ -0,0 +1,67 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_task.yaml + - /loss: crossentropyloss.yaml + - /metric: hisdbiou.yaml + - /model/backbone: unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: True + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 200 + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 10 + confusion_matrix_val: True + confusion_matrix_test: True + +datamodule: + _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + gt_folder_name: gt + selection_train: 1 + selection_val: 1 + +callbacks: + model_checkpoint: + monitor: "val/hisdbiou" + mode: "max" + filename: ${checkpoint_folder_name}cb55-full-unet +# watch_model: +# log_freq: 1000 + +logger: + wandb: + name: 'cb55-select-train1-val1-unet' + tags: ["best_model", "USL", "baseline"] + group: 'baseline' From 16c051361aeee8f21ed8cb4dc3107da7a4e20c5b Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 27 Oct 2021 13:56:41 +0200 Subject: [PATCH 021/108] :sparkle: added new task classification with a new network --- .../RotNet/datasets/cropped_dataset.py | 8 +- src/models/backbones/baby_cnn.py | 92 +++++++++++++++++++ src/tasks/classification/__init__.py | 0 src/tasks/classification/classification.py | 83 +++++++++++++++++ tests/tasks/classification/__init__.py | 0 .../classification/test_classification.py | 39 ++++++++ 6 files changed, 217 insertions(+), 5 deletions(-) create mode 100644 src/models/backbones/baby_cnn.py create mode 100644 src/tasks/classification/__init__.py create mode 100644 src/tasks/classification/classification.py create mode 100644 tests/tasks/classification/__init__.py create mode 100644 tests/tasks/classification/test_classification.py diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index c6ba78bf..b53a3fb4 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -98,23 +98,21 @@ def _apply_transformation(self, img, index): img and gt after transformations """ if self.twin_transform is not None and not self.is_test: - img, gt = self.twin_transform(img, None) + img, _ = self.twin_transform(img, None) if self.image_transform is not None: # perform transformations - img = self.image_transform(img, None) + img, _ = self.image_transform(img, None) if not is_tensor(img): img = ToTensor()(img) target_class = index % len(ROTATION_ANGLES) rotation_angle = ROTATION_ANGLES[target_class] - hot_hot_encoded = np.zeros(len(ROTATION_ANGLES)) - hot_hot_encoded[target_class] = 1 img = torchvision.transforms.functional.rotate(img=img, angle=rotation_angle) - return img, torch.LongTensor(hot_hot_encoded) + return img, target_class @staticmethod def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', diff --git a/src/models/backbones/baby_cnn.py b/src/models/backbones/baby_cnn.py new file mode 100644 index 00000000..55ab25d3 --- /dev/null +++ b/src/models/backbones/baby_cnn.py @@ -0,0 +1,92 @@ +""" +CNN with 3 conv layers and a fully connected classification layer +""" +import torch.nn as nn + + +class Flatten(nn.Module): + """ + Flatten a convolution block into a 
simple vector.
+
+    Replaces the flattening line (view) often found in forward() methods of networks. This makes it
+    easier to navigate the network with introspection.
+    """
+
+    def forward(self, x):
+        x = x.view(x.size()[0], -1)
+        return x
+
+
+class CNN_basic(nn.Module):
+    """
+    Simple feed forward convolutional neural network
+
+    Attributes
+    ----------
+    expected_input_size : tuple(int,int)
+        Expected input size (width, height)
+    conv1 : torch.nn.Sequential
+    conv2 : torch.nn.Sequential
+    conv3 : torch.nn.Sequential
+        Convolutional layers of the network
+    fc : torch.nn.Sequential
+        Final classification fully connected layer
+
+    """
+
+    def __init__(self, num_classes=10, input_channels=3, **kwargs):
+        """
+        Creates a CNN_basic model from scratch.
+
+        Parameters
+        ----------
+        num_classes : int
+            Number of neurons in the last layer
+        input_channels : int
+            Dimensionality of the input, typically 3 for RGB
+        """
+        super(CNN_basic, self).__init__()
+
+        self.expected_input_size = (32, 32)
+
+        # First layer
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(input_channels, 24, kernel_size=5, stride=3),
+            nn.LeakyReLU()
+        )
+        # Second layer
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(24, 48, kernel_size=3, stride=2),
+            nn.LeakyReLU()
+        )
+        # Third layer
+        self.conv3 = nn.Sequential(
+            nn.Conv2d(48, 72, kernel_size=3, stride=1),
+            nn.LeakyReLU()
+        )
+
+        # Classification layer
+        self.fc = nn.Sequential(
+            Flatten(),
+            nn.Linear(109512, num_classes)
+        )
+
+    def forward(self, x):
+        """
+        Computes forward pass on the network
+
+        Parameters
+        ----------
+        x : Variable
+            Sample to run forward pass on. (input to the model)
+
+        Returns
+        -------
+        Variable
+            Activations of the fully connected layer
+        """
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.fc(x)
+        return x
diff --git a/src/tasks/classification/__init__.py b/src/tasks/classification/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/tasks/classification/classification.py b/src/tasks/classification/classification.py
new file mode 100644
index 00000000..7db4094c
--- /dev/null
+++ b/src/tasks/classification/classification.py
@@ -0,0 +1,83 @@
+from typing import Optional, Callable
+
+import torch.nn as nn
+import torch.optim
+import torchmetrics
+
+from src.tasks.base_task import AbstractTask
+from src.utils import utils
+from src.tasks.utils.outputs import OutputKeys, reduce_dict
+
+log = utils.get_logger(__name__)
+
+
+class Classification(AbstractTask):
+
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss_fn: Optional[Callable] = None,
+                 metric_train: Optional[torchmetrics.Metric] = None,
+                 metric_val: Optional[torchmetrics.Metric] = None,
+                 metric_test: Optional[torchmetrics.Metric] = None,
+                 confusion_matrix_val: Optional[bool] = False,
+                 confusion_matrix_test: Optional[bool] = False,
+                 confusion_matrix_log_every_n_epoch: Optional[int] = 1,
+                 lr: float = 1e-3
+                 ) -> None:
+        """
+        Image classification task; used here for self-supervised training, where the network
+        predicts one class label (e.g. a rotation angle) per input image.
+
+        :param model: torch.nn.Module
+            The classification backbone, e.g. 
unet + :param test_output_path: str + String with a path to the output folder of the testing + """ + super().__init__( + model=model, + optimizer=optimizer, + loss_fn=loss_fn, + metric_train=metric_train, + metric_val=metric_val, + metric_test=metric_test, + lr=lr, + confusion_matrix_val=confusion_matrix_val, + confusion_matrix_test=confusion_matrix_test, + confusion_matrix_log_every_n_epoch=confusion_matrix_log_every_n_epoch, + ) + self.save_hyperparameters() + + def setup(self, stage: str) -> None: + super().setup(stage) + + log.info("Setup done!") + + def forward(self, x): + return self.model(x) + + ############################################################################################# + ########################################### TRAIN ########################################### + ############################################################################################# + def training_step(self, batch, batch_idx, **kwargs): + input_batch, target_batch = batch + output = super().training_step(batch=(input_batch, target_batch), batch_idx=batch_idx) + return reduce_dict(input_dict=output, key_list=[OutputKeys.LOSS]) + + ############################################################################################# + ############################################ VAL ############################################ + ############################################################################################# + + def validation_step(self, batch, batch_idx, **kwargs): + input_batch, target_batch = batch + output = super().validation_step(batch=(input_batch, target_batch), batch_idx=batch_idx) + return reduce_dict(input_dict=output, key_list=[]) + + ############################################################################################# + ########################################### TEST ############################################ + ############################################################################################# + + def test_step(self, batch, batch_idx, **kwargs): + input_batch, target_batch = batch + output = super().test_step(batch=(input_batch, target_batch), batch_idx=batch_idx) + return reduce_dict(input_dict=output, key_list=[]) + diff --git a/tests/tasks/classification/__init__.py b/tests/tasks/classification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tasks/classification/test_classification.py b/tests/tasks/classification/test_classification.py new file mode 100644 index 00000000..ac1f9879 --- /dev/null +++ b/tests/tasks/classification/test_classification.py @@ -0,0 +1,39 @@ +import os + +import numpy as np +import pytorch_lightning as pl +import torch.optim.optimizer +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from src.datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped + +from src.models.backbones.baby_cnn import CNN_basic +from src.tasks.classification.classification import Classification +from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped + + +def test_classification(data_dir_cropped): + OmegaConf.clear_resolvers() + seed_everything(42) + + # datamodule + data_module = RotNetDivaHisDBDataModuleCropped( + data_dir=str(data_dir_cropped), + batch_size=2, num_workers=2) + + model = CNN_basic(num_classes=data_module.num_classes) + task = Classification(model=model, + optimizer=torch.optim.Adam(params=model.parameters()), + loss_fn=torch.nn.CrossEntropyLoss()) + + os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' + trainer = pl.Trainer(max_epochs=2, 
precision=32, default_root_dir=task.test_output_path, accelerator='ddp_cpu', + num_processes=1) + + trainer.fit(task, datamodule=data_module) + + results = trainer.test() + print(results) + assert np.isclose(results[0]['test/crossentropyloss'], 1.5777363777160645, rtol=2e-03) + assert np.isclose(results[0]['test/crossentropyloss_epoch'], 1.5777363777160645, rtol=2e-03) From b8077180cd585f6d3b1b7338621d32ec49ac7fbd Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 27 Oct 2021 15:17:23 +0200 Subject: [PATCH 022/108] :sparkles: generate_cropped_dataset.py supports gif and can handle multiple gt and data folder. --- tools/generate_cropped_dataset.py | 92 ++++++++++++++----------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/tools/generate_cropped_dataset.py b/tools/generate_cropped_dataset.py index 1cd01f92..ae41c2dd 100644 --- a/tools/generate_cropped_dataset.py +++ b/tools/generate_cropped_dataset.py @@ -16,7 +16,7 @@ from torchvision.utils import save_image from tqdm import tqdm -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] def has_extension(filename, extensions): @@ -65,41 +65,37 @@ def convert_to_rgb(pic): return pic -def get_gt_data_paths_uncropped(directory): +def get_img_paths_uncropped(directory): """ Parameters ---------- directory: string - parent directory with gt and data folder inside + parent directory with images inside Returns ------- - paths: list of tuples + paths: list of paths """ paths = [] directory = Path(directory).expanduser() - path_imgs = Path(directory) / "data" - path_gts = Path(directory) / "gt" + if not directory.is_dir(): + logging.error(f'Directory not found ({directory})') - if not (path_imgs.is_dir() or path_gts.is_dir()): - logging.error("folder data or gt not found in " + str(directory)) + for subdir in sorted(directory.iterdir()): + if not subdir.is_dir(): + continue - for img_name, gt_name in zip(sorted(path_imgs.iterdir()), sorted(path_gts.iterdir())): - assert has_extension(str(img_name), IMG_EXTENSIONS) == has_extension(str(gt_name), IMG_EXTENSIONS), \ - 'get_gt_data_paths_uncropped(): image file aligned with non-image file' - - if has_extension(str(img_name), IMG_EXTENSIONS) and has_extension(str(gt_name), IMG_EXTENSIONS): - assert img_name.suffix[0] == gt_name.suffix[0], \ - 'get_gt_data_paths_uncropped(): mismatch between data filename and gt filename' - paths.append((path_imgs / img_name, path_gts / gt_name)) + for img_name in sorted(subdir.iterdir()): + if has_extension(str(img_name), IMG_EXTENSIONS): + paths.append((subdir / img_name, str(subdir.stem))) return paths -class ToTensorSlidingWindowCrop(object): +class ImageCrop(object): """ Crop the data and ground truth image at the specified coordinates to the specified size and convert them to a tensor. @@ -108,7 +104,7 @@ class ToTensorSlidingWindowCrop(object): def __init__(self, crop_size): self.crop_size = crop_size - def __call__(self, img, gt, coordinates): + def __call__(self, img, coordinates): """ Args: img (PIL Image): Data image to be cropped and converted to tensor. 
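# A minimal usage sketch of the refactored single-image crop (assuming the ImageCrop
# class and pil_loader function from this file, with the repository root on sys.path;
# the page path below is made up for illustration). Unlike the old
# ToTensorSlidingWindowCrop, the call takes one image and returns one tensor instead
# of a (data, gt) pair.
from pathlib import Path

from tools.generate_cropped_dataset import ImageCrop, pil_loader

crop = ImageCrop(crop_size=256)
img = pil_loader(Path('pages/page_0001.png'))  # hypothetical input page
# crop the 256x256 patch whose top-left corner sits at x=0, y=0; the result is a
# float tensor of shape (C, 256, 256) with values in [0, 1]
patch = crop(img, coordinates=(0, 0))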
@@ -123,10 +119,8 @@ def __call__(self, img, gt, coordinates): img_crop = F.to_tensor( F.crop(img=img, left=x_position, top=y_position, width=self.crop_size, height=self.crop_size)) - gt_crop = F.to_tensor( - F.crop(img=gt, left=x_position, top=y_position, width=self.crop_size, height=self.crop_size)) - return img_crop, gt_crop + return img_crop class CroppedDatasetGenerator: @@ -211,49 +205,46 @@ def __init__(self, input_path, output_path, crop_size, overlap=0.5, leading_zero self.step_size = int(self.crop_size * (1 - self.overlap)) # List of tuples that contain the path to the gt and image that belong together - self.img_paths = get_gt_data_paths_uncropped(input_path) + self.img_paths = get_img_paths_uncropped(input_path) self.num_imgs_in_set = len(self.img_paths) if self.num_imgs_in_set == 0: raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( input_path, ",".join(IMG_EXTENSIONS))) + self.current_split = '' self.current_img_index = -1 self.img_names_sizes, self.num_horiz_crops, self.num_vert_crops = self._get_img_size_and_crop_numbers() self.crop_list = self._get_crop_list() def write_crops(self): - crop_function = ToTensorSlidingWindowCrop(self.crop_size) + crop_function = ImageCrop(self.crop_size) for img_index, x, y in tqdm(self.crop_list, desc=self.progress_title): - self._load_image_and_var(img_index=img_index) + self._load_image(img_index=img_index) coordinates = (x, y) - img_full_name = self.img_names_sizes[img_index][0] + split_name = self.img_names_sizes[img_index][0] + img_full_name = self.img_names_sizes[img_index][1] img_full_name = Path(img_full_name) img_name = img_full_name.stem - dest_folder_data = self.output_path / 'data' / img_name - dest_folder_gt = self.output_path / 'gt' / img_name - dest_folder_data.mkdir(parents=True, exist_ok=True) - dest_folder_gt.mkdir(parents=True, exist_ok=True) + dest_folder = self.output_path / split_name / img_name + dest_folder.mkdir(parents=True, exist_ok=True) extension = img_full_name.suffix filename = f'{img_name}_x{x:0{self.leading_zeros_length}d}_y{y:0{self.leading_zeros_length}d}{extension}' - dest_filename_data = dest_folder_data / filename - dest_filename_gt = dest_folder_gt / filename + dest_filename = dest_folder / filename if not self.override_existing: - if dest_filename_data.exists() and dest_filename_gt.exists(): + if dest_filename.exists(): continue - img, gt = self.get_crops(self.current_data_img, self.current_gt_img, - coordinates=coordinates, crop_function=crop_function) + img = self.get_crop(self.current_img, coordinates=coordinates, crop_function=crop_function) - save_image(img, dest_filename_data) - save_image(gt, dest_filename_gt) + save_image(img, dest_filename) - def _load_image_and_var(self, img_index): + def _load_image(self, img_index): """ Inits the variables responsible of tracking which crop should be taken next, the current images and the like. 
This should be run every time a new page gets loaded for the test-set @@ -263,28 +254,25 @@ def _load_image_and_var(self, img_index): return # Load image - self.current_data_img = pil_loader(self.img_paths[img_index][0]) - self.current_gt_img = pil_loader(self.img_paths[img_index][1]) + self.current_img = pil_loader(self.img_paths[img_index][0]) # Update pointer to current image self.current_img_index = img_index + self.current_split = self.img_paths[img_index][1] - def get_crops(self, img, gt, coordinates, crop_function): - img, gt = crop_function(img, gt, coordinates) - return img, gt + def get_crop(self, img, coordinates, crop_function): + img = crop_function(img, coordinates) + return img def _get_img_size_and_crop_numbers(self): # TODO documentation - img_names_sizes = [] # list of tuples -> (gt_img_name, img_size (H, W)) + img_names_sizes = [] # list of tuples -> (split_name, img_name, img_size (H, W)) num_horiz_crops = [] num_vert_crops = [] - for img_path, gt_path in self.img_paths: + for img_path, split_name in self.img_paths: data_img = pil_loader(img_path) - gt_img = pil_loader(gt_path) - # Ensure that data and gt image are of the same size - assert gt_img.size == data_img.size - img_names_sizes.append((gt_path.name, data_img.size)) + img_names_sizes.append((split_name, img_path.name, data_img.size)) num_horiz_crops.append(math.ceil((data_img.size[0] - self.crop_size) / self.step_size + 1)) num_vert_crops.append(math.ceil((data_img.size[1] - self.crop_size) / self.step_size + 1)) @@ -300,18 +288,18 @@ def _convert_crop_id_to_coordinates(self, img_index, hcrop_index, vcrop_index): # X coordinate if hcrop_index == self.num_horiz_crops[img_index] - 1: # We are at the end of a line - x_position = self.img_names_sizes[img_index][1][0] - self.crop_size + x_position = self.img_names_sizes[img_index][2][0] - self.crop_size else: x_position = self.step_size * hcrop_index - assert x_position < self.img_names_sizes[img_index][1][0] - self.crop_size + assert x_position < self.img_names_sizes[img_index][2][0] - self.crop_size # Y coordinate if vcrop_index == self.num_vert_crops[img_index] - 1: # We are at the bottom end - y_position = self.img_names_sizes[img_index][1][1] - self.crop_size + y_position = self.img_names_sizes[img_index][2][1] - self.crop_size else: y_position = self.step_size * vcrop_index - assert y_position < self.img_names_sizes[img_index][1][1] - self.crop_size + assert y_position < self.img_names_sizes[img_index][2][1] - self.crop_size return img_index, x_position, y_position From 78a1883145117551458dcc5cdac06cf298376174 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 27 Oct 2021 16:01:58 +0200 Subject: [PATCH 023/108] :construction: started development of RGB dataset/datamodule --- .../development_baby_unet_rgb_data.yaml | 68 +++++++++++++++++++ src/datamodules/RGB/datamodule_cropped.py | 25 ++++--- .../RGB/datasets/cropped_dataset.py | 4 +- src/datamodules/RGB/utils/image_analytics.py | 4 +- 4 files changed, 88 insertions(+), 13 deletions(-) create mode 100644 configs/experiment/development_baby_unet_rgb_data.yaml diff --git a/configs/experiment/development_baby_unet_rgb_data.yaml b/configs/experiment/development_baby_unet_rgb_data.yaml new file mode 100644 index 00000000..50bf5ad6 --- /dev/null +++ b/configs/experiment/development_baby_unet_rgb_data.yaml @@ -0,0 +1,68 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_task.yaml + - /loss: 
crossentropyloss.yaml + - /metric: iou.yaml + - /model/backbone: baby_unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +seed: 42 + +train: True +test: True + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 3 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 1 + confusion_matrix_val: True + confusion_matrix_test: True + +datamodule: + _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB + + data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + gt_folder_name: gtA + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'dev-baby-unet-rgb-data' + tags: [ "best_model", "USL" ] + group: 'dev-runs' + notes: "Testing" diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 50ad0c67..96699630 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -5,17 +5,19 @@ from torchvision import transforms from src.datamodules.base_datamodule import AbstractDatamodule -from src.datamodules.RGB.datasets.cropped_dataset import CroppedHisDBDataset +from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils +from functools import partial + log = utils.get_logger(__name__) -class DivaHisDBDataModuleCropped(AbstractDatamodule): +class DataModuleCroppedRGB(AbstractDatamodule): def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_folder_name: str = 'gt', selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, @@ -24,8 +26,13 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_fold shuffle: bool = True, drop_last: bool = True): super().__init__() + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + analytics = get_analytics(input_path=Path(data_dir), - get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) + get_gt_data_paths_func=partial(CroppedDatasetRGB.get_gt_data_paths, + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name)) self.mean = analytics['mean'] self.std = analytics['std'] @@ -48,8 +55,6 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_fold self.shuffle = shuffle self.drop_last = drop_last - self.data_folder_name = data_folder_name - self.gt_folder_name = gt_folder_name self.data_dir = 
validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name) @@ -62,8 +67,8 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_fold def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: - self.train = CroppedHisDBDataset(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = CroppedHisDBDataset(**self._create_dataset_parameters('val'), selection=self.selection_val) + self.train = CroppedDatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.val = CroppedDatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) @@ -71,7 +76,7 @@ def setup(self, stage: Optional[str] = None): drop_last=self.drop_last) if stage == 'test' or stage is not None: - self.test = CroppedHisDBDataset(**self._create_dataset_parameters('test'), selection=self.selection_test) + self.test = CroppedDatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) @@ -120,6 +125,8 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] def _create_dataset_parameters(self, dataset_type: str = 'train'): is_test = dataset_type == 'test' return {'path': self.data_dir / dataset_type, + 'data_folder_name': self.data_folder_name, + 'gt_folder_name': self.gt_folder_name, 'image_transform': self.image_transform, 'target_transform': self.target_transform, 'twin_transform': self.twin_transform, @@ -137,5 +144,3 @@ def get_img_name_coordinates(self, index): raise Exception('This method can just be called during testing') return self.test.img_paths_per_page[index][2:] - - diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py index 55e277df..beb790aa 100644 --- a/src/datamodules/RGB/datasets/cropped_dataset.py +++ b/src/datamodules/RGB/datasets/cropped_dataset.py @@ -15,11 +15,11 @@ from src.datamodules.RGB.utils.misc import has_extension, pil_loader from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] log = utils.get_logger(__name__) -class CroppedHisDBDataset(data.Dataset): +class CroppedDatasetRGB(data.Dataset): """A generic data loader where the images are arranged in this way: :: root/gt/xxx.png diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index 5dcc90c7..c4317fde 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -14,6 +14,8 @@ import torchvision.transforms as transforms from PIL import Image +from src.datamodules.RGB.utils.misc import pil_loader + def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): """ @@ -317,7 +319,7 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): label_counter = {} for path in gt_images: - img = np.array(Image.open(path))[:, :, 2].flatten() + img = np.array(pil_loader(path))[:, :, 2].flatten() total_num_pixels += len(img) for i, j in zip(*np.unique(img, return_counts=True)): label_counter[i] = label_counter.get(i, 0) + j From f476e5a192e16168f6fbc6c3e8736693a8b91c88 Mon Sep 17 00:00:00 2001 From: Paul M 
Date: Thu, 28 Oct 2021 02:24:36 +0200 Subject: [PATCH 024/108] :construction: separated data and gt analytics --- src/datamodules/RGB/datamodule_cropped.py | 18 ++- src/datamodules/RGB/utils/image_analytics.py | 113 ++++++++++++------- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 96699630..38cafc0d 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -12,8 +12,6 @@ from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils -from functools import partial - log = utils.get_logger(__name__) @@ -29,16 +27,16 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_fold self.data_folder_name = data_folder_name self.gt_folder_name = gt_folder_name - analytics = get_analytics(input_path=Path(data_dir), - get_gt_data_paths_func=partial(CroppedDatasetRGB.get_gt_data_paths, - data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name)) + analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, + get_gt_data_paths_func=CroppedDatasetRGB.get_gt_data_paths) - self.mean = analytics['mean'] - self.std = analytics['std'] - self.class_encodings = analytics['class_encodings'] + self.mean = analytics_data['mean'] + self.std = analytics_data['std'] + self.class_encodings = analytics_gt['class_encodings'] self.num_classes = len(self.class_encodings) - self.class_weights = analytics['class_weights'] + self.class_weights = analytics_gt['class_weights'] self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std)])) diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index c4317fde..e213a385 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -17,7 +17,7 @@ from src.datamodules.RGB.utils.misc import pil_loader -def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): +def get_analytics(input_path: Path, data_folder_name, gt_folder_name, get_gt_data_paths_func, **kwargs): """ Parameters ---------- @@ -26,36 +26,72 @@ def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): Returns ------- """ - analytics_file_path = input_path / 'analytics.json' - if analytics_file_path.exists(): - with analytics_file_path.open(mode='r') as f: - analytics_dict = json.load(fp=f) - else: + expected_keys_data = ['mean', 'std'] + expected_keys_gt = ['class_weights', 'class_encodings'] + + analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json' + analytics_path_gt = input_path / f'analytics.gt.{gt_folder_name}.json' + + analytics_data = None + analytics_gt = None + + missing_analytics_data = True + missing_analytics_gt = True + + if analytics_path_data.exists(): + with analytics_path_data.open(mode='r') as f: + analytics_data = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_data for k in expected_keys_data): + missing_analytics_data = False + + if analytics_path_gt.exists(): + with analytics_path_gt.open(mode='r') as f: + analytics_gt = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_gt for k in expected_keys_gt): + missing_analytics_gt = False + + if missing_analytics_data or missing_analytics_gt: 
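+        # The two cache files live next to the dataset root and are recomputed here
+        # from the train split whenever either one is missing or incomplete:
+        #   analytics.data.<data_folder_name>.json -> {'mean': [...], 'std': [...]}
+        #   analytics.gt.<gt_folder_name>.json     -> {'class_weights': [...], 'class_encodings': [...]}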
train_path = input_path / 'train' - gt_data_path_list = get_gt_data_paths_func(train_path) + gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, + gt_folder_name=gt_folder_name) file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) - mean, std = compute_mean_std(file_names=file_names_data, **kwargs) - - # Measure weights for class balancing - logging.info(f'Measuring class weights') - # create a list with all gt file paths - class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, **kwargs) - analytics_dict = {'mean': mean.tolist(), - 'std': std.tolist(), - 'class_weights': class_weights.tolist(), - 'class_encodings': class_encodings.tolist()} - # save json - try: - with analytics_file_path.open(mode='w') as f: - json.dump(obj=analytics_dict, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_file_path})') - else: - raise - # returns the 'mean[RGB]', 'std[RGB]', 'class_frequencies_weights[num_classes]', 'class_encodings' - return analytics_dict + + if missing_analytics_data: + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} + # save json + try: + with analytics_path_data.open(mode='w') as f: + json.dump(obj=analytics_data, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') + else: + raise + + if missing_analytics_gt: + # Measure weights for class balancing + logging.info(f'Measuring class weights') + # create a list with all gt file paths + class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, + **kwargs) + analytics_gt = {'class_weights': class_weights, + 'class_encodings': class_encodings} + # save json + try: + with analytics_path_gt.open(mode='w') as f: + json.dump(obj=analytics_gt, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_gt})') + else: + raise + + return analytics_data, analytics_gt def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): @@ -319,18 +355,19 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): label_counter = {} for path in gt_images: - img = np.array(pil_loader(path))[:, :, 2].flatten() - total_num_pixels += len(img) - for i, j in zip(*np.unique(img, return_counts=True)): - label_counter[i] = label_counter.get(i, 0) + j - - classes = np.array(sorted(label_counter.keys())) - num_samples_per_class = np.array([label_counter[k] for k in classes]) - class_frequencies = (num_samples_per_class / total_num_pixels) + img_raw = pil_loader(path) + colors = img_raw.getcolors() + + for count, color in colors: + total_num_pixels += count + label_counter[color] = label_counter.get(color, 0) + count + + classes = sorted(label_counter.keys()) + num_samples_per_class = np.asarray([label_counter[k] for k in classes]) logging.info('Finished computing class frequencies weights') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - return (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()), classes + class_weights = (1 / num_samples_per_class) / 
((1 / num_samples_per_class).sum()) + return class_weights.tolist(), classes if __name__ == '__main__': From fd4804955182ca3695a42a9b0478ea3bda96bf5f Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Thu, 28 Oct 2021 11:18:39 +0200 Subject: [PATCH 025/108] :wrench: :sparkle: added configs for the new task as well as a new backbone and a new header --- .../dev_rotnet_cnn_basic_cb55_10.yaml | 67 +++++++++++++++++++ configs/model/backbone/cnn_basic.yaml | 1 + configs/model/header/single_layer.yaml | 5 ++ configs/task/classification.yaml | 1 + src/models/backbones/baby_cnn.py | 26 +------ src/models/headers/fully_connected.py | 17 +++++ .../models/utils}/__init__.py | 0 src/models/utils/utils.py | 14 ++++ .../classification/test_classification.py | 39 ----------- tests/test_data/__init__.py | 0 10 files changed, 107 insertions(+), 63 deletions(-) create mode 100644 configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml create mode 100644 configs/model/backbone/cnn_basic.yaml create mode 100644 configs/model/header/single_layer.yaml create mode 100644 configs/task/classification.yaml create mode 100644 src/models/headers/fully_connected.py rename {tests/tasks/classification => src/models/utils}/__init__.py (100%) create mode 100644 src/models/utils/utils.py delete mode 100644 tests/tasks/classification/test_classification.py create mode 100644 tests/test_data/__init__.py diff --git a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml new file mode 100644 index 00000000..3d2e9735 --- /dev/null +++ b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml @@ -0,0 +1,67 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: classification.yaml + - /loss: crossentropyloss.yaml + - /metric: accuracy.yaml + - /model/backbone: cnn_basic.yaml + - /model/header: single_layer.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +seed: 42 + +train: True +test: False + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 3 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 1 + confusion_matrix_val: False + confusion_matrix_test: False + +datamodule: + _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}dev-rotnet-basic-cnn-cb55-10 + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'dev-rotnet-basic-cnn-cb55-10' + tags: [ "best_model", "USL" ] + group: 'dev-runs' + notes: "Testing" diff --git a/configs/model/backbone/cnn_basic.yaml b/configs/model/backbone/cnn_basic.yaml new file mode 100644 index 00000000..20778555 --- /dev/null +++ b/configs/model/backbone/cnn_basic.yaml @@ -0,0 +1 @@ +_target_: src.models.backbones.baby_cnn.CNN_basic diff --git a/configs/model/header/single_layer.yaml b/configs/model/header/single_layer.yaml new file mode 100644 index 00000000..05b64b26 --- /dev/null +++ b/configs/model/header/single_layer.yaml @@ -0,0 +1,5 @@ +_target_: src.models.headers.fully_connected.SingleLinear + +num_classes: ${datamodule:num_classes} +# needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!) +input_size: 109512 \ No newline at end of file diff --git a/configs/task/classification.yaml b/configs/task/classification.yaml new file mode 100644 index 00000000..4e52a39c --- /dev/null +++ b/configs/task/classification.yaml @@ -0,0 +1 @@ +_target_: src.tasks.classification.classification.Classification \ No newline at end of file diff --git a/src/models/backbones/baby_cnn.py b/src/models/backbones/baby_cnn.py index 55ab25d3..d7aaae31 100644 --- a/src/models/backbones/baby_cnn.py +++ b/src/models/backbones/baby_cnn.py @@ -4,19 +4,6 @@ import torch.nn as nn -class Flatten(nn.Module): - """ - Flatten a convolution block into a simple vector. - - Replaces the flattening line (view) often found into forward() methods of networks. This makes it - easier to navigate the network with introspection - """ - - def forward(self, x): - x = x.view(x.size()[0], -1) - return x - - class CNN_basic(nn.Module): """ Simple feed forward convolutional neural network @@ -34,7 +21,7 @@ class CNN_basic(nn.Module): """ - def __init__(self, num_classes=10, input_channels=3, **kwargs): + def __init__(self, **kwargs): """ Creates an CNN_basic model from the scratch. 
@@ -47,11 +34,9 @@ def __init__(self, num_classes=10, input_channels=3, **kwargs): """ super(CNN_basic, self).__init__() - self.expected_input_size = (32, 32) - # First layer self.conv1 = nn.Sequential( - nn.Conv2d(input_channels, 24, kernel_size=5, stride=3), + nn.Conv2d(3, 24, kernel_size=5, stride=3), nn.LeakyReLU() ) # Second layer @@ -65,12 +50,6 @@ def __init__(self, num_classes=10, input_channels=3, **kwargs): nn.LeakyReLU() ) - # Classification layer - self.fc = nn.Sequential( - Flatten(), - nn.Linear(109512, num_classes) - ) - def forward(self, x): """ Computes forward pass on the network @@ -88,5 +67,4 @@ def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) - x = self.fc(x) return x diff --git a/src/models/headers/fully_connected.py b/src/models/headers/fully_connected.py new file mode 100644 index 00000000..10c7af17 --- /dev/null +++ b/src/models/headers/fully_connected.py @@ -0,0 +1,17 @@ +from torch import nn + +from src.models.utils.utils import Flatten + + +class SingleLinear(nn.Module): + def __init__(self, num_classes: int = 4, input_size: int = 109512): + super(SingleLinear, self).__init__() + + self.fc = nn.Sequential( + Flatten(), + nn.Linear(input_size, num_classes) + ) + + def forward(self, x): + x = self.fc(x) + return x diff --git a/tests/tasks/classification/__init__.py b/src/models/utils/__init__.py similarity index 100% rename from tests/tasks/classification/__init__.py rename to src/models/utils/__init__.py diff --git a/src/models/utils/utils.py b/src/models/utils/utils.py new file mode 100644 index 00000000..02df24be --- /dev/null +++ b/src/models/utils/utils.py @@ -0,0 +1,14 @@ +from torch import nn + + +class Flatten(nn.Module): + """ + Flatten a convolution block into a simple vector. + + Replaces the flattening line (view) often found into forward() methods of networks. 
This makes it + easier to navigate the network with introspection + """ + + def forward(self, x): + x = x.view(x.size()[0], -1) + return x diff --git a/tests/tasks/classification/test_classification.py b/tests/tasks/classification/test_classification.py deleted file mode 100644 index ac1f9879..00000000 --- a/tests/tasks/classification/test_classification.py +++ /dev/null @@ -1,39 +0,0 @@ -import os - -import numpy as np -import pytorch_lightning as pl -import torch.optim.optimizer -from omegaconf import OmegaConf -from pytorch_lightning import seed_everything - -from src.datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped - -from src.models.backbones.baby_cnn import CNN_basic -from src.tasks.classification.classification import Classification -from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped - - -def test_classification(data_dir_cropped): - OmegaConf.clear_resolvers() - seed_everything(42) - - # datamodule - data_module = RotNetDivaHisDBDataModuleCropped( - data_dir=str(data_dir_cropped), - batch_size=2, num_workers=2) - - model = CNN_basic(num_classes=data_module.num_classes) - task = Classification(model=model, - optimizer=torch.optim.Adam(params=model.parameters()), - loss_fn=torch.nn.CrossEntropyLoss()) - - os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' - trainer = pl.Trainer(max_epochs=2, precision=32, default_root_dir=task.test_output_path, accelerator='ddp_cpu', - num_processes=1) - - trainer.fit(task, datamodule=data_module) - - results = trainer.test() - print(results) - assert np.isclose(results[0]['test/crossentropyloss'], 1.5777363777160645, rtol=2e-03) - assert np.isclose(results[0]['test/crossentropyloss_epoch'], 1.5777363777160645, rtol=2e-03) diff --git a/tests/test_data/__init__.py b/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b From 7d6897ea9b58c7e4924249e417bd0cf522b7cacd Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Thu, 28 Oct 2021 11:53:54 +0200 Subject: [PATCH 026/108] :white_check_mark: stupid not copied the image_analytics --- src/datamodules/DivaHisDB/utils/image_analytics.py | 12 ++++++++---- src/datamodules/RotNet/datamodule_cropped.py | 4 ++-- .../RotNet/datasets/test_cropped_dataset.py | 10 +++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index c95899b7..5dcc90c7 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -15,7 +15,7 @@ from PIL import Image -def get_analytics(input_path: Path, get_data_paths_func, **kwargs): +def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): """ Parameters ---------- @@ -30,15 +30,19 @@ def get_analytics(input_path: Path, get_data_paths_func, **kwargs): analytics_dict = json.load(fp=f) else: train_path = input_path / 'train' - data_path_list = get_data_paths_func(train_path) - file_names_data = np.asarray([str(item) for item in data_path_list]) + gt_data_path_list = get_gt_data_paths_func(train_path) + file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) + file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) mean, std = compute_mean_std(file_names=file_names_data, **kwargs) # Measure weights for class balancing logging.info(f'Measuring class weights') # create a list with all gt file paths + class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, 
**kwargs) analytics_dict = {'mean': mean.tolist(), - 'std': std.tolist()} + 'std': std.tolist(), + 'class_weights': class_weights.tolist(), + 'class_encodings': class_encodings.tolist()} # save json try: with analytics_file_path.open(mode='w') as f: diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index 3c6b4169..7f2604e1 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics +from src.datamodules.RotNet.utils.image_analytics import get_analytics from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised from src.datamodules.RotNet.utils.wrapper_transforms import OnlyImage @@ -25,7 +25,7 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', super().__init__() analytics = get_analytics(input_path=Path(data_dir), - get_data_paths_func=CroppedRotNet.get_gt_data_paths) + get_gt_data_paths_func=CroppedRotNet.get_gt_data_paths) self.mean = analytics['mean'] self.std = analytics['std'] diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index 4c3b9217..8fb8db5a 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -31,26 +31,26 @@ def test__apply_transformation(dataset_train): img0, gt0 = dataset_train._apply_transformation(img0_o, 0) assert torch.equal(img0, ToTensor()(img0_o)) - assert torch.equal(gt0, torch.LongTensor([1, 0, 0, 0])) + assert gt0 == 0 img1, gt1 = dataset_train._apply_transformation(img1_o, 1) assert not torch.equal(ToTensor()(img1_o), img1) assert torch.equal(img1, rotate(img=ToTensor()(img1_o), angle=ROTATION_ANGLES[1])) - assert torch.equal(gt1, torch.LongTensor([0, 1, 0, 0])) + assert gt1 == 1 img2, gt2 = dataset_train._apply_transformation(img2_o, 2) assert not torch.equal(ToTensor()(img2_o), img2) assert torch.equal(img2, rotate(img=ToTensor()(img2_o), angle=ROTATION_ANGLES[2])) - assert torch.equal(gt2, torch.LongTensor([0, 0, 1, 0])) + assert gt2 == 2 img3, gt3 = dataset_train._apply_transformation(img3_o, 3) assert not torch.equal(ToTensor()(img3_o), img3) assert torch.equal(img3, rotate(img=ToTensor()(img3_o), angle=ROTATION_ANGLES[3])) - assert torch.equal(gt3, torch.LongTensor([0, 0, 0, 1])) + assert gt3 == 3 img4, gt4 = dataset_train._apply_transformation(img4_o, 0) assert torch.equal(img4, ToTensor()(img4_o)) - assert torch.equal(gt4, torch.LongTensor([1, 0, 0, 0])) + assert gt4 == 0 def test_get_gt_data_paths(data_dir_cropped): From 437c899c596bfbc104695f8189f274bcbea857b7 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Thu, 28 Oct 2021 11:57:08 +0200 Subject: [PATCH 027/108] :bug: forgot to add the file to git --- .../RotNet/utils/image_analytics.py | 332 ++++++++++++++++++ 1 file changed, 332 insertions(+) create mode 100644 src/datamodules/RotNet/utils/image_analytics.py diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py new file mode 100644 index 00000000..3bcc57f7 --- /dev/null +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -0,0 +1,332 @@ +# Utils +import errno +import json +import logging +import os +from multiprocessing 
import Pool
+from pathlib import Path
+from typing import List
+
+import numpy as np
+# Torch related stuff
+import torch
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs):
+    """
+    Parameters
+    ----------
+    input_path: Path to dataset
+
+    Returns
+    -------
+    """
+    analytics_file_path = input_path / 'analytics.json'
+    if analytics_file_path.exists():
+        with analytics_file_path.open(mode='r') as f:
+            analytics_dict = json.load(fp=f)
+    else:
+        train_path = input_path / 'train'
+        data_path_list = get_gt_data_paths_func(train_path)
+        file_names_data = np.asarray([str(item) for item in data_path_list])
+        mean, std = compute_mean_std(file_names=file_names_data, **kwargs)
+
+        # Measure weights for class balancing
+        logging.info(f'Measuring class weights')
+        # create a list with all gt file paths
+        analytics_dict = {'mean': mean.tolist(),
+                          'std': std.tolist()}
+        # save json
+        try:
+            with analytics_file_path.open(mode='w') as f:
+                json.dump(obj=analytics_dict, fp=f)
+        except IOError as e:
+            if e.errno == errno.EACCES:
+                print(f'WARNING: No permissions to write analytics file ({analytics_file_path})')
+            else:
+                raise
+    # returns the 'mean[RGB]' and 'std[RGB]'
+    return analytics_dict
+
+
+def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs):
+    """
+    Computes mean and std of all images present at target folder.
+
+    Parameters
+    ----------
+    input_folder : String (path)
+        Path to the dataset folder (see above for details)
+    inmem : Boolean
+        Specifies whether it should be computed in an online or offline fashion.
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : float
+        Mean value of all pixels of the images in the input folder
+    std : float
+        Standard deviation of all pixels of the images in the input folder
+    """
+    file_names_np = np.array(list(map(str, file_names)))
+    # Compute mean and std
+    mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers)
+    return mean, std
+
+
+def _cms_online(file_names, workers=4):
+    """
+    Computes mean and standard deviation in an online fashion.
+    This is useful when the dataset is too big to be allocated in memory.
+
+    Parameters
+    ----------
+    file_names : List of String
+        List of file names of the dataset
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    logging.info('Begin computing the mean')
+
+    # Set up a pool of workers
+    pool = Pool(workers + 1)
+
+    # Online mean
+    results = pool.map(_return_mean, file_names)
+    mean_sum = np.sum(np.array(results), axis=0)
+
+    # Divide by number of samples in train set
+    mean = mean_sum / file_names.size
+
+    logging.info('Finished computing the mean')
+    logging.info('Begin computing the std')
+
+    # Online standard deviation
+    results = pool.starmap(_return_std, [[item, mean] for item in file_names])
+    std_sum = np.sum(np.array([item[0] for item in results]), axis=0)
+    total_pixel_count = np.sum(np.array([item[1] for item in results]))
+    std = np.sqrt(std_sum / total_pixel_count)
+    logging.info('Finished computing the std')
+
+    # Shut down the pool
+    pool.close()
+
+    return mean, std
+
+
+# Loads an image with PIL and returns the channel-wise means of the image.
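+# The channel means are divided by 255 so they are on the same [0, 1] scale as torchvision's ToTensor output.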
+def _return_mean(image_path):
+    img = np.array(Image.open(image_path).convert('RGB'))
+    mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0
+    return mean
+
+
+# Loads an image with PIL and returns the channel-wise std of the image.
+def _return_std(image_path, mean):
+    img = np.array(Image.open(image_path).convert('RGB')) / 255.0
+    m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]]))
+    return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0
+
+
+def _cms_inmem(file_names):
+    """
+    Computes mean and standard deviation in an offline fashion. This is possible only when the dataset can
+    be allocated in memory.
+
+    Parameters
+    ----------
+    file_names: List of String
+        List of file names of the dataset
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape))
+
+    # Load all samples
+    for i, sample in enumerate(file_names):
+        img[i] = np.array(Image.open(sample).convert('RGB'))
+
+    mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0
+    std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0
+
+    return mean, std
+
+
+def get_class_weights(input_folder, workers=4, **kwargs):
+    """
+    Get the weights proportional to the inverse of their class frequencies.
+    The vector sums up to 1
+
+    Parameters
+    ----------
+    input_folder : String (path)
+        Path to the dataset folder (see above for details)
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    ndarray[double] of size (num_classes)
+        The weights vector as a 1D array normalized (sum up to 1)
+    """
+    # Sanity check on the folder
+    if not os.path.isdir(input_folder):
+        logging.error(f"Folder {input_folder} does not exist")
+        raise FileNotFoundError
+
+    # Load the dataset
+    ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()]))
+
+    logging.info('Begin computing class frequencies weights')
+
+    if hasattr(ds, 'targets'):
+        labels = ds.targets
+    elif hasattr(ds, 'labels'):
+        labels = ds.labels
+    else:
+        # This is a fail-safe net in case a custom dataset changed the name of the internal variables
+        data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers)
+        labels = []
+        for target, label in data_loader:
+            labels.append(label)
+        labels = np.concatenate(labels).reshape(len(ds))
+
+    class_support = np.unique(labels, return_counts=True)[1]
+    class_frequencies = class_support / len(labels)
+    # Class weights are the inverse of the class frequencies
+    class_weights = 1 / class_frequencies
+    # Normalize vector to sum up to 1.0 (in case the Loss function does not do it)
+    class_weights /= class_weights.sum()
+
+    logging.info('Finished computing class frequencies weights ')
+    logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}')
+    logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}')
+
+    return class_weights
+
+
+def compute_mean_std_graphs(dataset, **kwargs):
+    """
+    Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py).
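+    Mean and std are computed independently for every feature column (see _get_feature_mean_std below).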
+ + Parameters + ---------- + input_folder : ParsedGxlDataset + Dataset object (see above for details) + + # TODO implement online version + + Returns + ------- + node_features : {"mean": list, "std": list} + Mean and std value of all node features in the input dataset + edge_features : {"mean": list, "std": list} + Mean and std value of all edge features in the input dataset + """ + if dataset.data.x is not None: + logging.info('Begin computing the node feature mean and std') + nodes = _get_feature_mean_std(dataset.data.x) + logging.info('Finished computing the node feature mean and std') + else: + nodes = {} + logging.info('No node features present') + + if dataset.data.edge_attr is not None: + logging.info('Begin computing the edge feature mean and std') + edges = _get_feature_mean_std(dataset.data.edge_attr) + logging.info('Finished computing the edge feature mean and std') + else: + edges = {} + logging.info('No edge features present') + + return nodes, edges + + +def _get_feature_mean_std(torch_array): + array = np.array(torch_array) + return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} + + +def get_class_weights_graphs(dataset, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. + The vector sums up to 1 + + Parameters + ---------- + input_folder : ParsedGxlDataset + Dataset object (see above for details) + + # TODO implement online version + + Returns + ------- + ndarray[double] of size (num_classes) + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + class_frequencies = np.array(dataset.config['class_freq'][1]) + # Class weights are the inverse of the class frequencies + class_weights = 1 / class_frequencies + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + class_weights /= class_weights.sum() + + logging.info('Finished computing class frequencies weights ') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + logging.info(f'Class weights (rounded): {np.around(class_weights)}') + + return class_weights + + +def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. 
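+    (i.e. w_c = (1 / f_c) / sum_k (1 / f_k), where f_c is the relative frequency of class c).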
+ The vector sums up to 1 + + Parameters + ---------- + gt_images: list of strings + Path to all ground truth images, which contain the pixel-wise label + workers: int + Number of workers to use for the mean/std computation + + Returns + ------- + ndarray[double] of size (num_classes) and ints the classes are represented as + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + total_num_pixels = 0 + label_counter = {} + + for path in gt_images: + img = np.array(Image.open(path))[:, :, 2].flatten() + total_num_pixels += len(img) + for i, j in zip(*np.unique(img, return_counts=True)): + label_counter[i] = label_counter.get(i, 0) + j + + classes = np.array(sorted(label_counter.keys())) + num_samples_per_class = np.array([label_counter[k] for k in classes]) + class_frequencies = (num_samples_per_class / total_num_pixels) + logging.info('Finished computing class frequencies weights') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + return (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()), classes + + +if __name__ == '__main__': + # print(get_analytics(input_path=Path('/netscratch/datasets/semantic_segmentation/datasets/CB55/'), inmem=True, workers=16)) + print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) From c43541bd3e40f810f8aa53b8b3bee29270799bab Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 28 Oct 2021 16:16:54 +0200 Subject: [PATCH 028/108] :construction: working on RGB encoding --- .../DivaHisDB/datamodule_cropped.py | 2 +- .../DivaHisDB/datasets/cropped_dataset.py | 4 +- src/datamodules/DivaHisDB/utils/functional.py | 13 ++---- src/datamodules/DivaHisDB/utils/misc.py | 2 +- src/datamodules/RGB/datamodule_cropped.py | 10 ++-- .../RGB/datasets/cropped_dataset.py | 5 +- src/datamodules/RGB/utils/functional.py | 46 ++++++------------- src/datamodules/RGB/utils/misc.py | 2 +- 8 files changed, 33 insertions(+), 51 deletions(-) diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 8f172007..043fc490 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -16,7 +16,7 @@ class DivaHisDBDataModuleCropped(AbstractDatamodule): - def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, diff --git a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py index 16707640..59ea9c37 100644 --- a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py +++ b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py @@ -31,7 +31,7 @@ class CroppedHisDBDataset(data.Dataset): root/data/xxz.png """ - def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None, is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): @@ -148,7 +148,7 @@ def _apply_transformation(self, img, gt): return img, gt, border_mask 
@staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: """ diff --git a/src/datamodules/DivaHisDB/utils/functional.py b/src/datamodules/DivaHisDB/utils/functional.py index bf9e24de..69b0ffae 100644 --- a/src/datamodules/DivaHisDB/utils/functional.py +++ b/src/datamodules/DivaHisDB/utils/functional.py @@ -23,15 +23,10 @@ def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): """ num_classes = len(class_encodings) - if type(matrix).__module__ == np.__name__: - im_np = matrix[:, :, 2].astype(np.uint8) - border_mask = matrix[:, :, 0].astype(np.uint8) != 0 - else: - # TODO: ugly fix -> better to not normalize in the first place - np_array = (matrix * 255).numpy().astype(np.uint8) - im_np = np_array[2, :, :].astype(np.uint8) - border_mask = np_array[0, :, :].astype(np.uint8) != 0 - im_np[border_mask] = 1 + np_array = (matrix * 255).numpy().astype(np.uint8) + im_np = np_array[2, :, :].astype(np.uint8) + border_mask = np_array[0, :, :].astype(np.uint8) != 0 + im_np[border_mask] = 1 integer_encoded = np.array([i for i in range(num_classes)]) onehot_encoder = OneHotEncoder(sparse=False, categories='auto') diff --git a/src/datamodules/DivaHisDB/utils/misc.py b/src/datamodules/DivaHisDB/utils/misc.py index 3520992d..c0de22d4 100644 --- a/src/datamodules/DivaHisDB/utils/misc.py +++ b/src/datamodules/DivaHisDB/utils/misc.py @@ -48,7 +48,7 @@ def convert_to_rgb(pic): return pic -def validate_path_for_segmentation(data_dir, data_folder_name: str = 'data', gt_folder_name: str = 'gt'): +def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val/test folder)") diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 38cafc0d..7d348dbb 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -1,22 +1,23 @@ from pathlib import Path from typing import Union, List, Optional +import torch from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.base_datamodule import AbstractDatamodule from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.base_datamodule import AbstractDatamodule from src.utils import utils log = utils.get_logger(__name__) class DataModuleCroppedRGB(AbstractDatamodule): - def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, @@ -35,17 +36,18 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', gt_fold self.mean = analytics_data['mean'] self.std = analytics_data['std'] 
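        # class encodings are 0-255 RGB values; the tensor below is divided by 255 so it
        # matches the [0, 1] range of the ToTensor'd ground-truth images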
self.class_encodings = analytics_gt['class_encodings'] + self.class_encodings_np = torch.tensor(self.class_encodings) / 255 self.num_classes = len(self.class_encodings) self.class_weights = analytics_gt['class_weights'] + self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std)])) self.target_transform = OnlyTarget(transforms.Compose([ # transforms the gt image into a one-hot encoded matrix - OneHotEncoding(class_encodings=self.class_encodings), + OneHotEncoding(class_encodings=self.class_encodings_np), # transforms the one hot encoding to argmax labels -> for the cross-entropy criterion OneHotToPixelLabelling()])) - self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.num_workers = num_workers self.batch_size = batch_size diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py index beb790aa..89b622dc 100644 --- a/src/datamodules/RGB/datasets/cropped_dataset.py +++ b/src/datamodules/RGB/datasets/cropped_dataset.py @@ -19,6 +19,7 @@ log = utils.get_logger(__name__) + class CroppedDatasetRGB(data.Dataset): """A generic data loader where the images are arranged in this way: :: @@ -31,7 +32,7 @@ class CroppedDatasetRGB(data.Dataset): root/data/xxz.png """ - def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None, is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): @@ -148,7 +149,7 @@ def _apply_transformation(self, img, gt): return img, gt, border_mask @staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: """ diff --git a/src/datamodules/RGB/utils/functional.py b/src/datamodules/RGB/utils/functional.py index bf9e24de..8bdd0f06 100644 --- a/src/datamodules/RGB/utils/functional.py +++ b/src/datamodules/RGB/utils/functional.py @@ -1,12 +1,10 @@ from typing import List -import numpy as np import torch +from torch.nn.functional import one_hot -from sklearn.preprocessing import OneHotEncoder - -def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): +def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor): """ Convert ground truth tensor or numpy matrix to one-hot encoded matrix @@ -21,33 +19,19 @@ def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): torch.LongTensor of size [#C x H x W] sparse one-hot encoded multi-class matrix, where #C is the number of classes """ - num_classes = len(class_encodings) - - if type(matrix).__module__ == np.__name__: - im_np = matrix[:, :, 2].astype(np.uint8) - border_mask = matrix[:, :, 0].astype(np.uint8) != 0 - else: - # TODO: ugly fix -> better to not normalize in the first place - np_array = (matrix * 255).numpy().astype(np.uint8) - im_np = np_array[2, :, :].astype(np.uint8) - border_mask = np_array[0, :, :].astype(np.uint8) != 0 - im_np[border_mask] = 1 - - integer_encoded = np.array([i for i in range(num_classes)]) - onehot_encoder = OneHotEncoder(sparse=False, categories='auto') - integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) - onehot_encoded = 
onehot_encoder.fit_transform(integer_encoded).astype(np.int8) - - np.place(im_np, im_np == 0, - 1) # needed to deal with 0 fillers at the borders during testing (replace with background) - replace_dict = {k: v for k, v in zip(class_encodings, onehot_encoded)} - - # create the one hot matrix - one_hot_matrix = np.asanyarray( - [[replace_dict[im_np[i, j]] for j in range(im_np.shape[1])] for i in range(im_np.shape[0])]).astype( - np.uint8) - - return torch.LongTensor(one_hot_matrix.transpose((2, 0, 1))) + num_classes = class_encodings.shape[0] + + integer_encoded = torch.full(size=matrix[0].shape, fill_value=-1, dtype=torch.long) + for index, encoding in enumerate(class_encodings): + mask = torch.logical_and(torch.logical_and( + torch.where(matrix[0] == encoding[0], True, False), + torch.where(matrix[1] == encoding[1], True, False)), + torch.where(matrix[2] == encoding[2], True, False)) + integer_encoded[mask] = index + + onehot_encoded = one_hot(input=integer_encoded, num_classes=num_classes) + + return onehot_encoded def argmax_onehot(tensor: torch.Tensor): diff --git a/src/datamodules/RGB/utils/misc.py b/src/datamodules/RGB/utils/misc.py index 3520992d..c0de22d4 100644 --- a/src/datamodules/RGB/utils/misc.py +++ b/src/datamodules/RGB/utils/misc.py @@ -48,7 +48,7 @@ def convert_to_rgb(pic): return pic -def validate_path_for_segmentation(data_dir, data_folder_name: str = 'data', gt_folder_name: str = 'gt'): +def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val/test folder)") From 0404ab532849bf813bd19a9f43f9813fc6ba9d93 Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 28 Oct 2021 19:40:18 +0200 Subject: [PATCH 029/108] :sparkle: :wrench: :recycle: RGB encoding is working :tada: --- configs/experiment/cb55_full_run_unet.yaml | 2 +- .../experiment/cb55_select_train15_unet.yaml | 2 +- .../cb55_select_train1_val1_unet.yaml | 2 +- .../development_baby_unet_cb55_10.yaml | 2 +- .../development_baby_unet_rgb_data.yaml | 4 +- .../synthetic_baby_unet_rolf_gtD.yaml | 69 ++++ configs/metric/iou.yaml | 2 +- ....yaml => semantic_segmentation_HisDB.yaml} | 2 +- configs/task/semantic_segmentation_RGB.yaml | 2 + .../DivaHisDB/datamodule_cropped.py | 7 + .../DivaHisDB/utils/image_analytics.py | 5 +- src/datamodules/RGB/datamodule_cropped.py | 11 +- .../RGB/datasets/cropped_dataset.py | 14 +- src/datamodules/RGB/utils/functional.py | 36 +- src/datamodules/RGB/utils/image_analytics.py | 2 +- src/datamodules/RGB/utils/output_tools.py | 23 +- src/datamodules/RGB/utils/twin_transforms.py | 16 +- src/tasks/DivaHisDB/semantic_segmentation.py | 2 +- src/tasks/RGB/__init__.py | 0 src/tasks/RGB/semantic_segmentation.py | 121 +++++++ .../datasets/test_cropped_hisdb_dataset.py | 21 +- .../DivaHisDB/test_hisDBDataModule.py | 3 +- .../DivaHisDB/test_image_analytics.py | 3 +- tests/datamodules/DivaHisDB/test_misc.py | 8 +- .../sem_seg/test_semantic_segmentation.py | 13 +- tools/generate_cropped_dataset.py | 16 +- tools/merge_cropped_output_RGB.py | 314 ++++++++++++++++++ 27 files changed, 636 insertions(+), 66 deletions(-) create mode 100644 configs/experiment/synthetic_baby_unet_rolf_gtD.yaml rename configs/task/{semantic_segmentation_task.yaml => semantic_segmentation_HisDB.yaml} (82%) create mode 100644 configs/task/semantic_segmentation_RGB.yaml create mode 100644 src/tasks/RGB/__init__.py create mode 100644 src/tasks/RGB/semantic_segmentation.py create mode 
100644 tools/merge_cropped_output_RGB.py diff --git a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index cc6e29de..e6cc1894 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ b/configs/experiment/cb55_full_run_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_task.yaml + - /task: semantic_segmentation_HisDB.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/cb55_select_train15_unet.yaml b/configs/experiment/cb55_select_train15_unet.yaml index 5dc9bd32..400b4b97 100644 --- a/configs/experiment/cb55_select_train15_unet.yaml +++ b/configs/experiment/cb55_select_train15_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_task.yaml + - /task: semantic_segmentation_HisDB.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/cb55_select_train1_val1_unet.yaml b/configs/experiment/cb55_select_train1_val1_unet.yaml index 66b97e4a..fbc81849 100644 --- a/configs/experiment/cb55_select_train1_val1_unet.yaml +++ b/configs/experiment/cb55_select_train1_val1_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_task.yaml + - /task: semantic_segmentation_HisDB.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index ff8b5c1c..5ac3bcdf 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_task.yaml + - /task: semantic_segmentation_HisDB.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/experiment/development_baby_unet_rgb_data.yaml b/configs/experiment/development_baby_unet_rgb_data.yaml index 50bf5ad6..64c21f4d 100644 --- a/configs/experiment/development_baby_unet_rgb_data.yaml +++ b/configs/experiment/development_baby_unet_rgb_data.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_task.yaml + - /task: semantic_segmentation_RGB.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml - /model/backbone: baby_unet_model.yaml @@ -56,6 +56,8 @@ datamodule: callbacks: model_checkpoint: + monitor: "val/iou" + mode: "max" filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data watch_model: log_freq: 1 diff --git a/configs/experiment/synthetic_baby_unet_rolf_gtD.yaml b/configs/experiment/synthetic_baby_unet_rolf_gtD.yaml new file mode 100644 index 00000000..a419468b --- /dev/null +++ b/configs/experiment/synthetic_baby_unet_rolf_gtD.yaml @@ -0,0 +1,69 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_RGB.yaml + - /loss: crossentropyloss.yaml + - /metric: iou.yaml + - /model/backbone: baby_unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +seed: 42 + +train: True +test: True + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 2000 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 100 + confusion_matrix_val: True + confusion_matrix_test: True + +datamodule: + _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB + + data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + gt_folder_name: gtD + +callbacks: + model_checkpoint: + monitor: "val/iou" + mode: "max" + filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data + watch_model: + log_freq: 100 + +logger: + wandb: + name: 'synthetic-baby-unet-rolf-gtD' + tags: [ "best_model", "synthetic", "gtD", "Rolf" ] + group: 'synthetic' diff --git a/configs/metric/iou.yaml b/configs/metric/iou.yaml index 090a362f..ec242dda 100644 --- a/configs/metric/iou.yaml +++ b/configs/metric/iou.yaml @@ -1,4 +1,4 @@ # more infos about parameters: https://torchmetrics.readthedocs.io/en/latest/references/modules.html#iou _target_: torchmetrics.IoU -num_classes: 8 \ No newline at end of file +num_classes: ${datamodule:num_classes} \ No newline at end of file diff --git a/configs/task/semantic_segmentation_task.yaml b/configs/task/semantic_segmentation_HisDB.yaml similarity index 82% rename from configs/task/semantic_segmentation_task.yaml rename to configs/task/semantic_segmentation_HisDB.yaml index 20f5560e..89dd207f 100644 --- a/configs/task/semantic_segmentation_task.yaml +++ b/configs/task/semantic_segmentation_HisDB.yaml @@ -1,2 +1,2 @@ -_target_: src.tasks.DivaHisDB.semantic_segmentation.SemanticSegmentation +_target_: src.tasks.DivaHisDB.semantic_segmentation.SemanticSegmentationHisDB diff --git a/configs/task/semantic_segmentation_RGB.yaml b/configs/task/semantic_segmentation_RGB.yaml new file mode 100644 index 00000000..eb686785 --- /dev/null +++ b/configs/task/semantic_segmentation_RGB.yaml @@ -0,0 +1,2 @@ +_target_: src.tasks.RGB.semantic_segmentation.SemanticSegmentationRGB + diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 043fc490..88818e70 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -24,7 +24,12 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, shuffle: bool = True, drop_last: bool = True): super().__init__() + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + analytics = get_analytics(input_path=Path(data_dir), + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) self.mean = analytics['mean'] @@ -120,6 +125,8 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] def _create_dataset_parameters(self, dataset_type: str = 'train'): is_test = dataset_type == 'test' return {'path': self.data_dir / dataset_type, + 'data_folder_name': self.data_folder_name, + 'gt_folder_name': self.gt_folder_name, 'image_transform': self.image_transform, 
'target_transform': self.target_transform, 'twin_transform': self.twin_transform, diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index 5dcc90c7..46a7b97a 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -15,7 +15,7 @@ from PIL import Image -def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): +def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): """ Parameters ---------- @@ -30,7 +30,8 @@ def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): analytics_dict = json.load(fp=f) else: train_path = input_path / 'train' - gt_data_path_list = get_gt_data_paths_func(train_path) + gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, + gt_folder_name=gt_folder_name) file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) mean, std = compute_mean_std(file_names=file_names_data, **kwargs) diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 7d348dbb..e85a14a5 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -8,7 +8,8 @@ from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling +from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ + IntegerEncoding from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule from src.utils import utils @@ -36,18 +37,14 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.mean = analytics_data['mean'] self.std = analytics_data['std'] self.class_encodings = analytics_gt['class_encodings'] - self.class_encodings_np = torch.tensor(self.class_encodings) / 255 + self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 self.num_classes = len(self.class_encodings) self.class_weights = analytics_gt['class_weights'] self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std)])) - self.target_transform = OnlyTarget(transforms.Compose([ - # transforms the gt image into a one-hot encoded matrix - OneHotEncoding(class_encodings=self.class_encodings_np), - # transforms the one hot encoding to argmax labels -> for the cross-entropy criterion - OneHotToPixelLabelling()])) + self.target_transform = OnlyTarget(IntegerEncoding(class_encodings=self.class_encodings_tensor)) self.num_workers = num_workers self.batch_size = batch_size diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py index 89b622dc..b31a2aae 100644 --- a/src/datamodules/RGB/datasets/cropped_dataset.py +++ b/src/datamodules/RGB/datasets/cropped_dataset.py @@ -98,13 +98,13 @@ def __getitem__(self, index): def _get_train_val_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) - img, gt, boundary_mask = 
self._apply_transformation(data_img, gt_img)
-        return img, gt, boundary_mask
+        img, gt = self._apply_transformation(data_img, gt_img)
+        return img, gt
 
     def _get_test_items(self, index):
         data_img, gt_img = self._load_data_and_gt(index=index)
-        img, gt, boundary_mask = self._apply_transformation(data_img, gt_img)
-        return img, gt, boundary_mask, index
+        img, gt = self._apply_transformation(data_img, gt_img)
+        return img, gt, index
 
     def _load_data_and_gt(self, index):
         data_img = pil_loader(self.img_paths_per_page[index][0])
@@ -142,11 +142,10 @@ def _apply_transformation(self, img, gt):
         if not is_tensor(gt):
             gt = ToTensor()(gt)
 
-        border_mask = gt[0, :, :] != 0
         if self.target_transform is not None:
             img, gt = self.target_transform(img, gt)
 
-        return img, gt, border_mask
+        return img, gt
 
     @staticmethod
     def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str,
@@ -158,8 +157,9 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st
         directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png
         directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png
 
-        :param directory:
+        :param data_folder_name:
+        :param gt_folder_name:
         :param selection:
         :return: tuple (path_data_file, path_gt_file, original_image_name, (x, y))
 
diff --git a/src/datamodules/RGB/utils/functional.py b/src/datamodules/RGB/utils/functional.py
index 8bdd0f06..b052d201 100644
--- a/src/datamodules/RGB/utils/functional.py
+++ b/src/datamodules/RGB/utils/functional.py
@@ -4,7 +4,7 @@
 from torch.nn.functional import one_hot
 
 
-def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor):
+def gt_to_int_encoding(matrix: torch.Tensor, class_encodings: torch.Tensor):
     """
     Convert ground truth tensor or numpy matrix to an integer (class index) encoded matrix
 
@@ -19,17 +19,38 @@
     torch.LongTensor of size [H x W]
         integer encoded multi-class matrix, where each pixel holds its class index
     """
-    num_classes = class_encodings.shape[0]
-
     integer_encoded = torch.full(size=matrix[0].shape, fill_value=-1, dtype=torch.long)
     for index, encoding in enumerate(class_encodings):
         mask = torch.logical_and(torch.logical_and(
-            torch.where(matrix[0] == encoding[0], True, False),
-            torch.where(matrix[1] == encoding[1], True, False)),
-            torch.where(matrix[2] == encoding[2], True, False))
+                torch.where(matrix[0] == encoding[0], True, False),
+                torch.where(matrix[1] == encoding[1], True, False)),
+                torch.where(matrix[2] == encoding[2], True, False))
         integer_encoded[mask] = index
 
+    return integer_encoded
+
+
+def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor):
+    """
+    Convert ground truth tensor or numpy matrix to one-hot encoded matrix
+
+    Parameters
+    -------
+    matrix: float tensor from to_tensor() or numpy array
+        shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR
+    class_encodings: torch.Tensor
+        RGB values that encode the different classes
+    Returns
+    -------
+    torch.LongTensor of size [#C x H x W]
+        sparse one-hot encoded multi-class matrix, where #C is the number of classes
+    """
+    integer_encoded = gt_to_int_encoding(matrix=matrix, class_encodings=class_encodings)
+
+    num_classes = class_encodings.shape[0]
+
+    onehot_encoded = one_hot(input=integer_encoded, num_classes=num_classes)
+    onehot_encoded = onehot_encoded.swapaxes(1, 2).swapaxes(0, 1)  # changes axis from (0, 1, 2) to (2, 0, 1)
 
     return onehot_encoded
 
@@ -38,4 +59,5 @@ def argmax_onehot(tensor: torch.Tensor):
     """
     # TODO
     """
-    return
torch.LongTensor(torch.argmax(tensor, dim=0)) + output = torch.LongTensor(torch.argmax(tensor, dim=0)) + return output diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index e213a385..bcf50254 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -17,7 +17,7 @@ from src.datamodules.RGB.utils.misc import pil_loader -def get_analytics(input_path: Path, data_folder_name, gt_folder_name, get_gt_data_paths_func, **kwargs): +def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): """ Parameters ---------- diff --git a/src/datamodules/RGB/utils/output_tools.py b/src/datamodules/RGB/utils/output_tools.py index 06fe4441..6a472955 100644 --- a/src/datamodules/RGB/utils/output_tools.py +++ b/src/datamodules/RGB/utils/output_tools.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Union, Tuple, List import numpy as np import torch @@ -56,7 +56,7 @@ def merge_patches(patch, coordinates, full_output): return full_output -def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding): +def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding: List[Tuple[int]]): """ Helper function to save the output during testing in the DIVAHisDB format @@ -68,7 +68,7 @@ def save_output_page_image(image_name, output_image, output_folder: Path, class_ output image at full size output_folder: Path path to the output folder for the test data - class_encoding: list(int) + class_encoding: list(tuple(int)) list with the class encodings Returns @@ -87,7 +87,7 @@ def save_output_page_image(image_name, output_image, output_folder: Path, class_ Image.fromarray(output_encoded.astype(np.uint8)).save(str(dest_filename)) -def output_to_class_encodings(output, class_encodings, perform_argmax=True): +def output_to_class_encodings(output, class_encodings): """ This function converts the output prediction matrix to an image like it was provided in the ground truth @@ -104,15 +104,16 @@ def output_to_class_encodings(output, class_encodings, perform_argmax=True): numpy array of size [C x H x W] (BGR) """ - B = np.argmax(output, axis=0) if perform_argmax else output + integer_encoded = np.argmax(output, axis=0) - class_to_B = {i: j for i, j in enumerate(class_encodings)} + num_classes = len(class_encodings) - masks = [B == old for old in class_to_B.keys()] + masks = [integer_encoded == class_index for class_index in range(num_classes)] - for mask, (old, new) in zip(masks, class_to_B.items()): - B = np.where(mask, new, B) - - rgb = np.dstack((np.zeros(shape=(B.shape[0], B.shape[1], 2), dtype=np.int8), B)) + rgb = np.full((*integer_encoded.shape, 3), -1) + for mask, color in zip(masks, class_encodings): + rgb[:, :, 0] = np.where(mask, color[0], rgb[:, :, 0]) + rgb[:, :, 1] = np.where(mask, color[1], rgb[:, :, 1]) + rgb[:, :, 2] = np.where(mask, color[2], rgb[:, :, 2]) return rgb diff --git a/src/datamodules/RGB/utils/twin_transforms.py b/src/datamodules/RGB/utils/twin_transforms.py index b0382c5c..605509a9 100644 --- a/src/datamodules/RGB/utils/twin_transforms.py +++ b/src/datamodules/RGB/utils/twin_transforms.py @@ -98,4 +98,18 @@ def __call__(self, gt): Returns: """ - return F_custom.gt_to_one_hot(gt, self.class_encodings) \ No newline at end of file + return F_custom.gt_to_one_hot(gt, self.class_encodings) + + +class IntegerEncoding(object): + def __init__(self, 
class_encodings):
+        self.class_encodings = class_encodings
+
+    def __call__(self, gt):
+        """
+        Args:
+            gt: ground truth image as a tensor
+
+        Returns:
+            the integer (class index) encoded ground truth
+        """
+        return F_custom.gt_to_int_encoding(gt, self.class_encodings)
\ No newline at end of file
diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py
index 3fd21b3a..68d0635b 100644
--- a/src/tasks/DivaHisDB/semantic_segmentation.py
+++ b/src/tasks/DivaHisDB/semantic_segmentation.py
@@ -14,7 +14,7 @@
 log = utils.get_logger(__name__)
 
 
-class SemanticSegmentation(AbstractTask):
+class SemanticSegmentationHisDB(AbstractTask):
 
     def __init__(self,
                  model: nn.Module,
diff --git a/src/tasks/RGB/__init__.py b/src/tasks/RGB/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py
new file mode 100644
index 00000000..3e978751
--- /dev/null
+++ b/src/tasks/RGB/semantic_segmentation.py
@@ -0,0 +1,121 @@
+from pathlib import Path
+from typing import Optional, Callable, Union
+
+import numpy as np
+import torch.nn as nn
+import torch.optim
+import torchmetrics
+
+from src.tasks.base_task import AbstractTask
+from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax
+from src.utils import utils
+from src.tasks.utils.outputs import OutputKeys, reduce_dict
+
+log = utils.get_logger(__name__)
+
+
+class SemanticSegmentationRGB(AbstractTask):
+
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss_fn: Optional[Callable] = None,
+                 metric_train: Optional[torchmetrics.Metric] = None,
+                 metric_val: Optional[torchmetrics.Metric] = None,
+                 metric_test: Optional[torchmetrics.Metric] = None,
+                 test_output_path: Optional[Union[str, Path]] = 'predictions',
+                 confusion_matrix_val: Optional[bool] = False,
+                 confusion_matrix_test: Optional[bool] = False,
+                 confusion_matrix_log_every_n_epoch: Optional[int] = 1,
+                 lr: float = 1e-3
+                 ) -> None:
+        """
+        Pixelwise semantic segmentation. The output of the network during test is an RGB encoded image
+
+        :param model: torch.nn.Module
+            The encoder for the segmentation e.g.
unet
+        :param test_output_path: str
+            Path to the output folder where the test predictions are written
+        """
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            loss_fn=loss_fn,
+            metric_train=metric_train,
+            metric_val=metric_val,
+            metric_test=metric_test,
+            test_output_path=test_output_path,
+            lr=lr,
+            confusion_matrix_val=confusion_matrix_val,
+            confusion_matrix_test=confusion_matrix_test,
+            confusion_matrix_log_every_n_epoch=confusion_matrix_log_every_n_epoch,
+        )
+        self.save_hyperparameters()
+
+    def setup(self, stage: str) -> None:
+        super().setup(stage)
+
+        if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'):
+            raise NotImplementedError('DataModule needs to implement get_img_name_coordinates function')
+
+        log.info("Setup done!")
+
+    def forward(self, x):
+        return self.model(x)
+
+    @staticmethod
+    def to_metrics_format(x: torch.Tensor, **kwargs) -> torch.Tensor:
+        return _get_argmax(x, **kwargs)
+
+    #############################################################################################
+    ########################################### TRAIN ###########################################
+    #############################################################################################
+    def training_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch = batch
+        output = super().training_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+        return reduce_dict(input_dict=output, key_list=[OutputKeys.LOSS])
+
+    #############################################################################################
+    ############################################ VAL ############################################
+    #############################################################################################
+
+    def validation_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch = batch
+        output = super().validation_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+        return reduce_dict(input_dict=output, key_list=[])
+
+    #############################################################################################
+    ########################################### TEST ############################################
+    #############################################################################################
+
+    def test_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch, input_idx = batch
+        output = super().test_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+
+        if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'):
+            raise NotImplementedError('Datamodule does not provide detailed information of the crop')
+
+        for patch, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(),
+                              input_idx.detach().cpu().numpy()):
+            patch_info = self.trainer.datamodule.get_img_name_coordinates(idx)
+            img_name = patch_info[0]
+            patch_name = patch_info[1]
+            dest_folder = self.test_output_path / 'patches' / img_name
+            dest_folder.mkdir(parents=True, exist_ok=True)
+            dest_filename = dest_folder / f'{patch_name}.npy'
+
+            np.save(file=str(dest_filename), arr=patch)
+
+        return reduce_dict(input_dict=output, key_list=[])
+
+    def on_test_end(self) -> None:
+        datamodule_path = self.trainer.datamodule.data_dir
+        prediction_path = (self.test_output_path / 'patches').absolute()
+        output_path = (self.test_output_path / 'result').absolute()
+
+        data_folder_name = self.trainer.datamodule.data_folder_name
+        gt_folder_name = self.trainer.datamodule.gt_folder_name
+
+        log.info(f'To run the merging of patches:')
+        log.info(f'python
tools/merge_cropped_output_RGB.py -d {datamodule_path} -p {prediction_path} -o {output_path} ' + f'-df {data_folder_name} -gf {gt_folder_name}') diff --git a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py index 03d95c9b..17c44464 100644 --- a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py +++ b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py @@ -10,17 +10,17 @@ @fixture def dataset_train(data_dir_cropped): - return CroppedHisDBDataset(path=data_dir_cropped / 'train') + return CroppedHisDBDataset(path=data_dir_cropped / 'train', data_folder_name='data', gt_folder_name='gt') @fixture def dataset_val(data_dir_cropped): - return CroppedHisDBDataset(path=data_dir_cropped / 'val') + return CroppedHisDBDataset(path=data_dir_cropped / 'val', data_folder_name='data', gt_folder_name='gt') @fixture def dataset_test(data_dir_cropped): - return CroppedHisDBDataset(path=data_dir_cropped / 'test') + return CroppedHisDBDataset(path=data_dir_cropped / 'test', data_folder_name='data', gt_folder_name='gt') def test__load_data_and_gt(dataset_train): @@ -53,23 +53,27 @@ def test__get_train_val_items_test(dataset_test): def test_dataset_train_selection_int_error(data_dir_cropped): with pytest.raises(ValueError): - CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', selection=2) + CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + data_folder_name='data', gt_folder_name='gt', selection=2) def test_dataset_train_selection_int(data_dir_cropped, get_train_file_names): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', selection=1) + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + data_folder_name='data', gt_folder_name='gt', selection=1) assert len(files_from_method) == 12 assert files_from_method == get_train_file_names def test_get_gt_data_paths_train(data_dir_cropped, get_train_file_names): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train') + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + data_folder_name='data', gt_folder_name='gt') assert len(files_from_method) == 12 assert files_from_method == get_train_file_names def test_get_gt_data_paths_val(data_dir_cropped): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'val') + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'val', + data_folder_name='data', gt_folder_name='gt') expected_result = [(PosixPath( data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), PosixPath( @@ -135,7 +139,8 @@ def test_get_gt_data_paths_val(data_dir_cropped): def test_get_gt_data_paths_test(data_dir_cropped): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'test') + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'test', + data_folder_name='data', gt_folder_name='gt') expected_result = [(PosixPath( data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), PosixPath( diff --git a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py index d4b6dc63..7c1695fa 100644 --- 
a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py +++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py @@ -14,7 +14,8 @@ @fixture def data_module_cropped(data_dir_cropped): OmegaConf.clear_resolvers() - datamodules = DivaHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) + datamodules = DivaHisDBDataModuleCropped(data_dir_cropped, data_folder_name='data', gt_folder_name='gt', + num_workers=NUM_WORKERS) return datamodules diff --git a/tests/datamodules/DivaHisDB/test_image_analytics.py b/tests/datamodules/DivaHisDB/test_image_analytics.py index d00c3017..c8b7bc83 100644 --- a/tests/datamodules/DivaHisDB/test_image_analytics.py +++ b/tests/datamodules/DivaHisDB/test_image_analytics.py @@ -13,7 +13,8 @@ def test_get_analytics_no_file(data_dir_cropped): - output = get_analytics(input_path=data_dir_cropped, get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) + output = get_analytics(input_path=data_dir_cropped, data_folder_name='data', gt_folder_name='gt', + get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) assert np.array_equal(np.round(TEST_JSON['mean'], 8), np.round(output['mean'], 8)) assert np.array_equal(np.round(TEST_JSON['std'], 8), np.round(output['std'], 8)) diff --git a/tests/datamodules/DivaHisDB/test_misc.py b/tests/datamodules/DivaHisDB/test_misc.py index f18382b8..8a93655d 100644 --- a/tests/datamodules/DivaHisDB/test_misc.py +++ b/tests/datamodules/DivaHisDB/test_misc.py @@ -42,21 +42,21 @@ def path_missing_subfolder(tmp_path): def test_validate_path_none(): with pytest.raises(PathNone): - validate_path_for_segmentation(data_dir=None) + validate_path_for_segmentation(data_dir=None, data_folder_name='data', gt_folder_name='gt') def test_validate_path_not_dir(tmp_path): tmp_file = tmp_path / "newfile" tmp_file.touch() with pytest.raises(PathNotDir): - validate_path_for_segmentation(data_dir=tmp_file) + validate_path_for_segmentation(data_dir=tmp_file, data_folder_name='data', gt_folder_name='gt') def test_validate_path_missing_split(path_missing_split): with pytest.raises(PathMissingSplitDir): - validate_path_for_segmentation(data_dir=path_missing_split) + validate_path_for_segmentation(data_dir=path_missing_split, data_folder_name='data', gt_folder_name='gt') def test_validate_path_missing_subfolder(path_missing_subfolder): with pytest.raises(PathMissingDirinSplitDir): - validate_path_for_segmentation(data_dir=path_missing_subfolder) + validate_path_for_segmentation(data_dir=path_missing_subfolder, data_folder_name='data', gt_folder_name='gt') diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/sem_seg/test_semantic_segmentation.py index af225008..243d580c 100644 --- a/tests/tasks/sem_seg/test_semantic_segmentation.py +++ b/tests/tasks/sem_seg/test_semantic_segmentation.py @@ -8,7 +8,7 @@ from pytorch_lightning import seed_everything from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped -from src.tasks.DivaHisDB.semantic_segmentation import SemanticSegmentation +from src.tasks.DivaHisDB.semantic_segmentation import SemanticSegmentationHisDB from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped @@ -19,17 +19,18 @@ def test_semantic_segmentation(data_dir_cropped, tmp_path): # datamodule data_module = DivaHisDBDataModuleCropped( data_dir=str(data_dir_cropped), + data_folder_name='data', gt_folder_name='gt', batch_size=2, num_workers=2) def baby_unet(): return UNet(num_classes=len(data_module.class_encodings), num_layers=2, features_start=32) model = baby_unet() - 
segmentation = SemanticSegmentation(model=model, - optimizer=torch.optim.Adam(params=model.parameters()), - loss_fn=torch.nn.CrossEntropyLoss(), - test_output_path=tmp_path - ) + segmentation = SemanticSegmentationHisDB(model=model, + optimizer=torch.optim.Adam(params=model.parameters()), + loss_fn=torch.nn.CrossEntropyLoss(), + test_output_path=tmp_path + ) # different paths needed later patches_path = segmentation.test_output_path / 'patches' diff --git a/tools/generate_cropped_dataset.py b/tools/generate_cropped_dataset.py index ae41c2dd..8d5791df 100644 --- a/tools/generate_cropped_dataset.py +++ b/tools/generate_cropped_dataset.py @@ -17,7 +17,7 @@ from tqdm import tqdm IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] - +JPG_EXTENSIONS = ['.jpg', '.jpeg'] def has_extension(filename, extensions): """Checks if a file is an allowed extension. @@ -161,6 +161,11 @@ def __init__(self, input_path: Path, output_path, crop_size_train, crop_size_val def write_crops(self): info_list = ['Running CroppedDatasetGenerator.write_crops():', + f'- full_command:', + f'python tools/generate_cropped_dataset.py -i {self.input_path} -o {self.output_path} ' + f'-tr {self.crop_size_train} -v {self.crop_size_val} -te {self.crop_size_test} -ov {self.overlap} ' + f'-l {self.leading_zeros_length}', + f'', f'- start_time: \t{datetime.now():%Y-%m-%d_%H-%M-%S}', f'- input_path: \t{self.input_path}', f'- output_path: \t{self.output_path}', @@ -242,7 +247,14 @@ def write_crops(self): img = self.get_crop(self.current_img, coordinates=coordinates, crop_function=crop_function) - save_image(img, dest_filename) + pil_img = F.to_pil_image(img, mode='RGB') + + if extension in JPG_EXTENSIONS: + pil_img.save(dest_filename, quality=95) + else: + # save_image(img, dest_filename) + pil_img.save(dest_filename) + def _load_image(self, img_index): """ diff --git a/tools/merge_cropped_output_RGB.py b/tools/merge_cropped_output_RGB.py new file mode 100644 index 00000000..04462db7 --- /dev/null +++ b/tools/merge_cropped_output_RGB.py @@ -0,0 +1,314 @@ +import argparse +import math +import re +import threading +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +from PIL import Image +from tqdm import tqdm + +from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB +from src.datamodules.RGB.utils.output_tools import merge_patches, save_output_page_image +from src.datamodules.RGB.datamodule_cropped import DataModuleCroppedRGB +from tools.generate_cropped_dataset import pil_loader +from tools.viz import visualize + + +@dataclass +class CropData: + name: Path + offset_x: int + offset_y: int + height: int + width: int + pred_path: Path + img_path: Path + gt_path: Path + + +class CroppedOutputMerger: + def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Path, + data_folder_name: str, gt_folder_name: str, num_threads: int = 10): + # Defaults + self.load_only_first_crop_for_size = True # All crops have to be the same size in the current implementation + + self.datamodule_path = datamodule_path + self.prediction_path = prediction_path + self.output_path = output_path + + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + + data_module = DataModuleCroppedRGB(data_dir=str(datamodule_path), data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) + self.num_classes = 
data_module.num_classes
+        self.class_encodings = data_module.class_encodings
+
+        img_paths_per_page = CroppedDatasetRGB.get_gt_data_paths(directory=datamodule_path / 'test',
+                                                                 data_folder_name=self.data_folder_name,
+                                                                 gt_folder_name=self.gt_folder_name)
+
+        dataset_img_name_list = []
+        self.dataset_dict = defaultdict(list)
+        for img_path, gt_path, img_name, pred_path, (x, y) in img_paths_per_page:
+            if img_name not in dataset_img_name_list:
+                dataset_img_name_list.append(img_name)
+            self.dataset_dict[img_name].append((img_path, gt_path, pred_path, x, y))
+
+        dataset_img_name_list = sorted(dataset_img_name_list)
+
+        # sort dataset_dict lists
+        for img_name in self.dataset_dict.keys():
+            self.dataset_dict[img_name] = sorted(self.dataset_dict[img_name], key=lambda v: (v[4], v[3]))
+
+        self.img_name_list = sorted([str(n.name) for n in prediction_path.iterdir() if n.is_dir()])
+
+        # check if all images from the dataset are found in the prediction output
+        assert sorted(dataset_img_name_list) == sorted(self.img_name_list)
+
+        self.num_pages = len(self.img_name_list)
+        if self.num_pages >= num_threads:
+            self.num_threads = num_threads
+        else:
+            self.num_threads = self.num_pages
+
+        assert self.num_pages > 0
+
+    def merge_all(self):
+        start_time = datetime.now()
+        info_list = ['Running merge_cropped_output_RGB.py:',
+                     f'- start_time: \t{start_time:%Y-%m-%d_%H-%M-%S}',
+                     f'- datamodule_path: \t{self.datamodule_path}',
+                     f'- prediction_path: \t{self.prediction_path}',
+                     f'- output_path: \t{self.output_path}',
+                     f'- data_folder_name: \t{self.data_folder_name}',
+                     f'- gt_folder_name: \t{self.gt_folder_name}',
+                     f'- num_pages: \t{self.num_pages}',
+                     f'- num_threads: \t{self.num_threads}',
+                     '']  # empty string to get linebreak at the end when using join
+        info_str = '\n'.join(info_list)
+        print(info_str, flush=True)
+
+        # Write info_merge_cropped_output.txt
+        self.output_path.mkdir(parents=True, exist_ok=True)
+        info_file = self.output_path / 'info_merge_cropped_output.txt'
+        with info_file.open('a') as f:
+            f.write(info_str)
+
+        pool = ThreadPool(self.num_threads)
+        lock = threading.Lock()
+        results = []
+        for position, img_name in enumerate(self.img_name_list):
+            results.append(pool.apply_async(self.merge_page, args=(img_name, lock, position)))
+        pool.close()
+        pool.join()
+
+        results = [r.get() for r in results]
+
+        # Closing the progress bars in order for a beautiful output
+        for i in range(3):
+            for pbars in results:
+                pbars[i].close()
+
+        end_time = datetime.now()
+        duration = end_time - start_time
+
+        # Write final info
+        info_list = [f'- end_time: \t{datetime.now():%Y-%m-%d_%H-%M-%S}',
+                     f'- duration: \t{duration}',
+                     '']  # empty string to get linebreak at the end when using join
+        info_str = '\n'.join(info_list)
+
+        print('\n' + info_str)
+        # print(f'- log_file: \t{info_file}\n')
+
+        with info_file.open('a') as f:
+            f.write(info_str)
+            f.write('\n')
+
+        print('DONE!')
+
+    def merge_page(self, img_name: str, lock, position):
+        page_info_str = f'[{str(position + 1).rjust(int(math.log10(self.num_pages)) + 1)}/{self.num_pages}] {img_name}'
+
+        preds_folder = self.prediction_path / img_name
+        coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.npy$')
+
+        if not preds_folder.is_dir():
+            print(f'Skipping {preds_folder}.
Not a directory!') + return + + preds_list = [] + for pred_path in preds_folder.glob(f'{img_name}*.npy'): + m = coordinates.match(pred_path.name) + if m is None: + continue + x = int(m.group(1)) + y = int(m.group(2)) + preds_list.append((x, y, pred_path)) + preds_list = sorted(preds_list, key=lambda v: (v[1], v[0])) + + img_gt_list = self.dataset_dict[img_name] + + # The number of patches in the prediction should be equal to number of patches in dataset + assert len(preds_list) == len(img_gt_list) + + crop_data_list = [] + + # merge into one list + with lock: + pbar1 = tqdm(total=len(preds_list), + position=position, + # file=sys.stdout, + leave=True, + desc=f'{page_info_str}: Merging path lists') + + crop_width = -1 + crop_height = -1 + + if self.load_only_first_crop_for_size: + pred_path = preds_list[0][2] + pred = np.load(str(pred_path)) + crop_width = pred.shape[1] + crop_height = pred.shape[2] + + for (x, y, pred_path), (img_path, gt_path, crop_name, x_data, y_data) in zip(preds_list, img_gt_list): + assert (x, y) == (x_data, y_data) + assert pred_path.name.startswith(crop_name) + assert img_path.name.startswith(crop_name) + assert gt_path.name.startswith(crop_name) + + if not self.load_only_first_crop_for_size: + pred = np.load(str(pred_path)) + crop_width = pred.shape[1] + crop_height = pred.shape[2] + + crop_data_list.append( + CropData(name=crop_name, offset_x=x, offset_y=y, width=crop_width, height=crop_height, + img_path=img_path, gt_path=gt_path, pred_path=pred_path)) # , pred=pred)) + + pbar1.update() + + with lock: + pbar1.refresh() + + # Create new canvas + canvas_width = crop_data_list[-1].width + crop_data_list[-1].offset_x + canvas_height = crop_data_list[-1].height + crop_data_list[-1].offset_y + + pred_canvas_size = (self.num_classes, canvas_height, canvas_width) + pred_canvas = np.empty(pred_canvas_size) + pred_canvas.fill(np.nan) + + img_canvas = Image.new(mode='RGB', size=(canvas_width, canvas_height)) + gt_canvas = Image.new(mode='RGB', size=(canvas_width, canvas_height)) + + with lock: + pbar2 = tqdm(total=len(crop_data_list), + position=position + (1 * self.num_pages), + # file=sys.stdout, + leave=True, + desc=f'{page_info_str}: Merging crops') + + for crop_data in crop_data_list: + # Add the pred to the pred_canvas + pred = np.load(str(crop_data.pred_path)) + + # make sure all crops have same size + assert crop_width == pred.shape[1] + assert crop_height == pred.shape[2] + + pred_canvas = merge_patches(pred, (crop_data.offset_x, crop_data.offset_y), pred_canvas) + + img_crop = pil_loader(crop_data.img_path) + img_canvas.paste(img_crop, (crop_data.offset_x, crop_data.offset_y)) + + gt_crop = pil_loader(crop_data.gt_path) + gt_canvas.paste(gt_crop, (crop_data.offset_x, crop_data.offset_y)) + + pbar2.update() + + with lock: + pbar2.refresh() + + # Save the image when done + outdir_img = self.output_path / 'img' + outdir_gt = self.output_path / 'gt' + outdir_pred = self.output_path / 'pred' + + outdir_img.mkdir(parents=True, exist_ok=True) + outdir_gt.mkdir(parents=True, exist_ok=True) + outdir_pred.mkdir(parents=True, exist_ok=True) + + # Loop to allow progress bar + with lock: + pbar3 = tqdm(total=3, + position=position + (2 * self.num_pages), + # file=sys.stdout, + leave=True, + desc=f'{page_info_str}: Saving merged image files') + + for i in range(3): + if i == 0: + pbar3.set_description(f'{page_info_str}: Saving merged image files ' + '(img)'.ljust(10)) + img_canvas.save(fp=outdir_img / f'{img_name}.png') + + elif i == 1: + 
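+                    # i == 1: write the stitched ground-truth canvas; like the
+                    # final prediction below, it is saved as a page-sized GIF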
pbar3.set_description(f'{page_info_str}: Saving merged image files ' + '(gt)'.ljust(10))
+                gt_canvas.save(fp=outdir_gt / f'{img_name}.gif')
+
+            elif i == 2:
+                pbar3.set_description(f'{page_info_str}: Saving merged image files ' + '(pred)'.ljust(10))
+                # Save prediction only when complete
+                if not np.isnan(np.sum(pred_canvas)):
+                    # Save the final image (image_name, output_image, output_folder, class_encoding)
+                    save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_canvas,
+                                           output_folder=outdir_pred, class_encoding=self.class_encodings)
+                else:
+                    print(f'WARNING: Test image {img_name} was not written! It still contains NaN values.')
+                    break  # so the last step is not executed
+
+            pbar3.update()
+
+        with lock:
+            pbar3.refresh()
+
+        # The progress bars will be closed in order in the main thread
+        return pbar1, pbar2, pbar3
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-d', '--datamodule_path',
+                        help='Path to the root folder of the dataset (contains train/val/test)',
+                        type=Path,
+                        required=True)
+    parser.add_argument('-p', '--prediction_path',
+                        help='Path to the prediction patches folder',
+                        type=Path,
+                        required=True)
+    parser.add_argument('-o', '--output_path',
+                        help='Path to the output folder',
+                        type=Path,
+                        required=True)
+    parser.add_argument('-df', '--data_folder_name',
+                        help='Name of data folder',
+                        type=str,
+                        required=True)
+    parser.add_argument('-gf', '--gt_folder_name',
+                        help='Name of gt folder',
+                        type=str,
+                        required=True)
+    parser.add_argument('-n', '--num_threads',
+                        help='Number of threads for parallel processing',
+                        type=int,
+                        default=10)
+
+    args = parser.parse_args()
+    merger = CroppedOutputMerger(**args.__dict__)
+    merger.merge_all()

From 5fe9510e4c049483402092d492567bc6f16d06e7 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Thu, 28 Oct 2021 23:51:10 +0200
Subject: [PATCH 030/108] :wrench: additional config

---
 .../synthetic_baby_unet_rolf_layoutR_gtD.yaml | 69 +++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100644 configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml

diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml
new file mode 100644
index 00000000..81ecb71b
--- /dev/null
+++ b/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml
@@ -0,0 +1,69 @@
+# @package _global_
+
+# to execute this experiment run:
+# python run.py +experiment=exp_example_full
+
+defaults:
+  - /plugins: default.yaml
+  - /task: semantic_segmentation_RGB.yaml
+  - /loss: crossentropyloss.yaml
+  - /metric: iou.yaml
+  - /model/backbone: baby_unet_model.yaml
+  - /model/header: identity.yaml
+  - /optimizer: adam.yaml
+  - /callbacks:
+      - check_compatibility.yaml
+      - model_checkpoint.yaml
+      - watch_model_wandb.yaml
+  - /logger:
+      - wandb.yaml # set logger here or use command line (e.g.
`python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 42
+
+train: True
+test: True
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 2000
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 100
+  confusion_matrix_val: True
+  confusion_matrix_test: True
+
+datamodule:
+  _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB
+
+  data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutR/split
+  crop_size: 256
+  num_workers: 4
+  batch_size: 16
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+  gt_folder_name: gtD
+
+callbacks:
+  model_checkpoint:
+    monitor: "val/iou"
+    mode: "max"
+    filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
+  watch_model:
+    log_freq: 100
+
+logger:
+  wandb:
+    name: 'synthetic-baby-unet-rolf-layoutR-gtD'
+    tags: [ "best_model", "synthetic", "layoutR", "gtD", "Rolf" ]
+    group: 'synthetic'

From 3ad7afd0fd3ee5c884e0566aa522f5e90c1646b9 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Fri, 29 Oct 2021 10:21:32 +0200
Subject: [PATCH 031/108] :horse: only create integer encoding since one-hot
 encoding is not used atm

---
 .../DivaHisDB/datamodule_cropped.py           | 13 ++++--------
 src/datamodules/DivaHisDB/utils/functional.py | 20 +++++++++++++++++++
 .../DivaHisDB/utils/twin_transforms.py        | 16 ++++++++++++++-
 src/datamodules/RGB/utils/twin_transforms.py  |  2 +-
 src/tasks/DivaHisDB/semantic_segmentation.py  |  2 +-
 ...utput.py => merge_cropped_output_HisDB.py} |  2 +-
 6 files changed, 42 insertions(+), 13 deletions(-)
 rename tools/{merge_cropped_output.py => merge_cropped_output_HisDB.py} (99%)

diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py
index 88818e70..48a8819d 100644
--- a/src/datamodules/DivaHisDB/datamodule_cropped.py
+++ b/src/datamodules/DivaHisDB/datamodule_cropped.py
@@ -8,7 +8,8 @@
 from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset
 from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics
 from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation
-from src.datamodules.DivaHisDB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling
+from src.datamodules.DivaHisDB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \
+    IntegerEncoding
 from src.datamodules.DivaHisDB.utils.wrapper_transforms import OnlyImage, OnlyTarget
 from src.utils import utils
 
@@ -38,14 +39,10 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str,
         self.num_classes = len(self.class_encodings)
         self.class_weights = analytics['class_weights']
 
+        self.twin_transform = TwinRandomCrop(crop_size=crop_size)
         self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(),
                                                              transforms.Normalize(mean=self.mean, std=self.std)]))
-        self.target_transform = OnlyTarget(transforms.Compose([
-            # transforms the gt image into a one-hot encoded matrix
-            OneHotEncoding(class_encodings=self.class_encodings),
-            # transforms the one hot encoding to argmax labels -> for the cross-entropy criterion
-            OneHotToPixelLabelling()]))
-        self.twin_transform = TwinRandomCrop(crop_size=crop_size)
+        self.target_transform =
OnlyTarget(IntegerEncoding(class_encodings=self.class_encodings))
 
         self.num_workers = num_workers
         self.batch_size = batch_size
@@ -144,5 +141,3 @@ def get_img_name_coordinates(self, index):
             raise Exception('This method can just be called during testing')
 
         return self.test.img_paths_per_page[index][2:]
-
-
diff --git a/src/datamodules/DivaHisDB/utils/functional.py b/src/datamodules/DivaHisDB/utils/functional.py
index 69b0ffae..33b31346 100644
--- a/src/datamodules/DivaHisDB/utils/functional.py
+++ b/src/datamodules/DivaHisDB/utils/functional.py
@@ -6,6 +6,26 @@
 from sklearn.preprocessing import OneHotEncoder
 
 
+def gt_to_int_encoding(matrix: torch.Tensor, class_encodings: List[int]):
+    np_array = (matrix * 255).numpy().astype(np.uint8)
+
+    # take only blue channel
+    im_np = np_array[2, :, :].astype(np.uint8)
+
+    # change border pixels to background
+    border_mask = np_array[0, :, :].astype(np.uint8) != 0
+    im_np[border_mask] = 1
+
+    im_tensor = torch.tensor(im_np)
+
+    integer_encoded = torch.full(size=im_tensor.shape, fill_value=-1, dtype=torch.long)
+    for index, encoding in enumerate(class_encodings):
+        mask = torch.where(im_tensor == encoding, True, False)
+        integer_encoded[mask] = index
+
+    return integer_encoded
+
+
 def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]):
     """
     Convert ground truth tensor or numpy matrix to one-hot encoded matrix
diff --git a/src/datamodules/DivaHisDB/utils/twin_transforms.py b/src/datamodules/DivaHisDB/utils/twin_transforms.py
index 5b193825..9b984508 100644
--- a/src/datamodules/DivaHisDB/utils/twin_transforms.py
+++ b/src/datamodules/DivaHisDB/utils/twin_transforms.py
@@ -98,4 +98,18 @@ def __call__(self, gt):
         Returns:
 
         """
-        return F_custom.gt_to_one_hot(gt, self.class_encodings)
\ No newline at end of file
+        return F_custom.gt_to_one_hot(gt, self.class_encodings)
+
+
+class IntegerEncoding(object):
+    def __init__(self, class_encodings):
+        self.class_encodings = class_encodings
+
+    def __call__(self, gt):
+        """
+        Args:
+            gt: ground truth image as a tensor
+
+        Returns:
+            the integer (class index) encoded ground truth
+        """
+        return F_custom.gt_to_int_encoding(gt, self.class_encodings)
diff --git a/src/datamodules/RGB/utils/twin_transforms.py b/src/datamodules/RGB/utils/twin_transforms.py
index 605509a9..5ba68cad 100644
--- a/src/datamodules/RGB/utils/twin_transforms.py
+++ b/src/datamodules/RGB/utils/twin_transforms.py
@@ -112,4 +112,4 @@ def __call__(self, gt):
         Returns:
 
         """
-        return F_custom.gt_to_int_encoding(gt, self.class_encodings)
\ No newline at end of file
+        return F_custom.gt_to_int_encoding(gt, self.class_encodings)
diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py
index 68d0635b..1d9989a9 100644
--- a/src/tasks/DivaHisDB/semantic_segmentation.py
+++ b/src/tasks/DivaHisDB/semantic_segmentation.py
@@ -119,4 +119,4 @@ def on_test_end(self) -> None:
         output_path = (self.test_output_path / 'result').absolute()
 
         log.info(f'To run the merging of patches:')
-        log.info(f'python tools/merge_cropped_output.py -d {datamodule_path} -p {prediction_path} -o {output_path}')
+        log.info(f'python tools/merge_cropped_output_HisDB.py -d {datamodule_path} -p {prediction_path} -o {output_path}')
diff --git a/tools/merge_cropped_output.py b/tools/merge_cropped_output_HisDB.py
similarity index 99%
rename from tools/merge_cropped_output.py
rename to tools/merge_cropped_output_HisDB.py
index b5ac8943..8c1f7457 100644
--- a/tools/merge_cropped_output.py
+++ b/tools/merge_cropped_output_HisDB.py
@@ -74,7 +74,7 @@ def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Pa
def merge_all(self): start_time = datetime.now() - info_list = ['Running merge_cropped_output.py:', + info_list = ['Running merge_cropped_output_HisDB.py:', f'- start_time: \t{start_time:%Y-%m-%d_%H-%M-%S}', f'- datamodule_path: \t{self.datamodule_path}', f'- prediction_path: \t{self.prediction_path}', From 425d1eb5aa4f597f4d9c779500fe853989f988db Mon Sep 17 00:00:00 2001 From: Paul M Date: Fri, 29 Oct 2021 10:27:30 +0200 Subject: [PATCH 032/108] :bug: fixed folder_name problem in merge_cropped_output_HisDB.py --- src/tasks/DivaHisDB/semantic_segmentation.py | 6 +++++- tools/merge_cropped_output_HisDB.py | 21 +++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py index 1d9989a9..8860c9a1 100644 --- a/src/tasks/DivaHisDB/semantic_segmentation.py +++ b/src/tasks/DivaHisDB/semantic_segmentation.py @@ -118,5 +118,9 @@ def on_test_end(self) -> None: prediction_path = (self.test_output_path / 'patches').absolute() output_path = (self.test_output_path / 'result').absolute() + data_folder_name = self.trainer.datamodule.data_folder_name + gt_folder_name = self.trainer.datamodule.gt_folder_name + log.info(f'To run the merging of patches:') - log.info(f'python tools/merge_cropped_output_HisDB.py -d {datamodule_path} -p {prediction_path} -o {output_path}') + log.info(f'python tools/merge_cropped_output_HisDB.py -d {datamodule_path} -p {prediction_path} ' + f'-o {output_path} -df {data_folder_name} -gf {gt_folder_name}') diff --git a/tools/merge_cropped_output_HisDB.py b/tools/merge_cropped_output_HisDB.py index 8c1f7457..5ffead16 100644 --- a/tools/merge_cropped_output_HisDB.py +++ b/tools/merge_cropped_output_HisDB.py @@ -32,7 +32,8 @@ class CropData: class CroppedOutputMerger: - def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Path, num_threads: int = 10): + def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Path, + data_folder_name: str, gt_folder_name: str, num_threads: int = 10): # Defaults self.load_only_first_crop_for_size = True # All crops have to be the same size in the current implementation @@ -40,11 +41,17 @@ def __init__(self, datamodule_path: Path, prediction_path: Path, output_path: Pa self.prediction_path = prediction_path self.output_path = output_path - data_module = DivaHisDBDataModuleCropped(data_dir=str(datamodule_path)) + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + + data_module = DivaHisDBDataModuleCropped(data_dir=str(datamodule_path), data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) self.num_classes = data_module.num_classes self.class_encodings = data_module.class_encodings - img_paths_per_page = CroppedHisDBDataset.get_gt_data_paths(directory=datamodule_path / 'test') + img_paths_per_page = CroppedHisDBDataset.get_gt_data_paths(directory=datamodule_path / 'test', + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) dataset_img_name_list = [] self.dataset_dict = defaultdict(list) @@ -309,6 +316,14 @@ def merge_page(self, img_name: str, lock, position): help='Path to the output folder', type=Path, required=True) + parser.add_argument('-df', '--data_folder_name', + help='Name of data folder', + type=str, + required=True) + parser.add_argument('-gf', '--gt_folder_name', + help='Name of gt folder', + type=str, + required=True) parser.add_argument('-n', '--num_threads', help='Number of threads for parallel 
processing', type=int, From babd5338edeea5a809e60140abe0e8ef290af824 Mon Sep 17 00:00:00 2001 From: Paul M Date: Fri, 29 Oct 2021 10:32:44 +0200 Subject: [PATCH 033/108] :princess: changes for Lars --- ...f_gtD.yaml => synthetic_baby_unet_rolf_layoutD_gtD.yaml} | 6 ++---- .../experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml | 2 -- 2 files changed, 2 insertions(+), 6 deletions(-) rename configs/experiment/{synthetic_baby_unet_rolf_gtD.yaml => synthetic_baby_unet_rolf_layoutD_gtD.yaml} (93%) diff --git a/configs/experiment/synthetic_baby_unet_rolf_gtD.yaml b/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml similarity index 93% rename from configs/experiment/synthetic_baby_unet_rolf_gtD.yaml rename to configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml index a419468b..78b65fe4 100644 --- a/configs/experiment/synthetic_baby_unet_rolf_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml @@ -23,8 +23,6 @@ defaults: # instead we define all modules and their paths directly in this config, # so everything is stored in one place for more readibility -seed: 42 - train: True test: True @@ -64,6 +62,6 @@ callbacks: logger: wandb: - name: 'synthetic-baby-unet-rolf-gtD' - tags: [ "best_model", "synthetic", "gtD", "Rolf" ] + name: 'synthetic-baby-unet-rolf-layoutD-gtD' + tags: [ "best_model", "synthetic", "layoutD", "gtD", "Rolf" ] group: 'synthetic' diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml index 81ecb71b..1ab54d6a 100644 --- a/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml @@ -23,8 +23,6 @@ defaults: # instead we define all modules and their paths directly in this config, # so everything is stored in one place for more readibility -seed: 42 - train: True test: True From 988cd6485442c02aa31bb0fef97292a2d6bd6ccf Mon Sep 17 00:00:00 2001 From: Paul M Date: Fri, 29 Oct 2021 13:05:06 +0200 Subject: [PATCH 034/108] :white_check_mark: fixed rotation test --- .../RotNet/datasets/test_cropped_dataset.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index 8fb8db5a..3e56ed85 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -23,33 +23,43 @@ def test__load_data_and_gt(dataset_train): def test__apply_transformation(dataset_train): - img0_o = dataset_train._load_data_and_gt(0) - img1_o = dataset_train._load_data_and_gt(1) - img2_o = dataset_train._load_data_and_gt(2) - img3_o = dataset_train._load_data_and_gt(3) - img4_o = dataset_train._load_data_and_gt(4) + org0 = dataset_train._load_data_and_gt(0) + org1 = dataset_train._load_data_and_gt(1) - img0, gt0 = dataset_train._apply_transformation(img0_o, 0) - assert torch.equal(img0, ToTensor()(img0_o)) - assert gt0 == 0 + img0, gt0 = dataset_train._apply_transformation(org0, 0) + img_index0, gt_index0 = dataset_train[0] + assert torch.equal(img0, img_index0) + assert gt0 == gt_index0 - img1, gt1 = dataset_train._apply_transformation(img1_o, 1) - assert not torch.equal(ToTensor()(img1_o), img1) - assert torch.equal(img1, rotate(img=ToTensor()(img1_o), angle=ROTATION_ANGLES[1])) + img1, gt1 = dataset_train._apply_transformation(org0, 1) + img_index1, gt_index1 = dataset_train[1] + assert not 
torch.equal(ToTensor()(org0), img1)
+    assert torch.equal(img1, rotate(img=ToTensor()(org0), angle=ROTATION_ANGLES[1]))
+    assert torch.equal(img1, img_index1)
+    assert gt1 == gt_index1
     assert gt1 == 1
 
-    img2, gt2 = dataset_train._apply_transformation(img2_o, 2)
-    assert not torch.equal(ToTensor()(img2_o), img2)
-    assert torch.equal(img2, rotate(img=ToTensor()(img2_o), angle=ROTATION_ANGLES[2]))
+    img2, gt2 = dataset_train._apply_transformation(org0, 2)
+    img_index2, gt_index2 = dataset_train[2]
+    assert not torch.equal(ToTensor()(org0), img2)
+    assert torch.equal(img2, rotate(img=ToTensor()(org0), angle=ROTATION_ANGLES[2]))
+    assert torch.equal(img2, img_index2)
+    assert gt2 == gt_index2
     assert gt2 == 2
 
-    img3, gt3 = dataset_train._apply_transformation(img3_o, 3)
-    assert not torch.equal(ToTensor()(img3_o), img3)
-    assert torch.equal(img3, rotate(img=ToTensor()(img3_o), angle=ROTATION_ANGLES[3]))
+    img3, gt3 = dataset_train._apply_transformation(org0, 3)
+    img_index3, gt_index3 = dataset_train[3]
+    assert not torch.equal(ToTensor()(org0), img3)
+    assert torch.equal(img3, rotate(img=ToTensor()(org0), angle=ROTATION_ANGLES[3]))
+    assert torch.equal(img3, img_index3)
+    assert gt3 == gt_index3
     assert gt3 == 3
 
-    img4, gt4 = dataset_train._apply_transformation(img4_o, 0)
-    assert torch.equal(img4, ToTensor()(img4_o))
+    img4, gt4 = dataset_train._apply_transformation(org1, 0)
+    img_index4, gt_index4 = dataset_train[4]
+    assert torch.equal(img4, ToTensor()(org1))
+    assert torch.equal(img4, img_index4)
+    assert gt4 == gt_index4
     assert gt4 == 0

From fed5f5f774eb996b430fbfbb88b677f4275f7c22 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Fri, 29 Oct 2021 15:27:18 +0200
Subject: [PATCH 035/108] :loud_sound: now print out the backbone flattened if
 the header and backbone are not matching, as well as logging the test
 results

---
 src/callbacks/model_callbacks.py      |  2 ++
 src/execute.py                        |  3 ++-
 src/models/headers/fully_connected.py |  5 ++---
 src/models/utils/utils.py             | 14 --------------
 4 files changed, 6 insertions(+), 18 deletions(-)
 delete mode 100644 src/models/utils/utils.py

diff --git a/src/callbacks/model_callbacks.py b/src/callbacks/model_callbacks.py
index a2fa0ce6..1300285a 100644
--- a/src/callbacks/model_callbacks.py
+++ b/src/callbacks/model_callbacks.py
@@ -85,6 +85,8 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         except RuntimeError as e:
             log.error(f'Backbone and Header are not fitting together! Backbone output dimensions {b_output.shape}.'
f'Perhaps flatten header input first.') + log.error(f'Output size (first dimension = batch size) of the backbone flattened:' + f' {torch.nn.Flatten()(b_output).shape}') log.error(e) log.error(traceback.format_exc()) sys.exit(1) diff --git a/src/execute.py b/src/execute.py index e0adbe08..17661fb4 100644 --- a/src/execute.py +++ b/src/execute.py @@ -124,7 +124,8 @@ def execute(config: DictConfig) -> Optional[float]: # Evaluate model on test set after training if config.test: log.info("Starting testing!") - trainer.test(model=task, datamodule=datamodule) + results = trainer.test(model=task, datamodule=datamodule) + log.info(f'Test output: {results}') # Make sure everything closed properly log.info("Finalizing!") diff --git a/src/models/headers/fully_connected.py b/src/models/headers/fully_connected.py index 10c7af17..29d3482d 100644 --- a/src/models/headers/fully_connected.py +++ b/src/models/headers/fully_connected.py @@ -1,14 +1,13 @@ +import torch from torch import nn -from src.models.utils.utils import Flatten - class SingleLinear(nn.Module): def __init__(self, num_classes: int = 4, input_size: int = 109512): super(SingleLinear, self).__init__() self.fc = nn.Sequential( - Flatten(), + torch.nn.Flatten(), nn.Linear(input_size, num_classes) ) diff --git a/src/models/utils/utils.py b/src/models/utils/utils.py deleted file mode 100644 index 02df24be..00000000 --- a/src/models/utils/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from torch import nn - - -class Flatten(nn.Module): - """ - Flatten a convolution block into a simple vector. - - Replaces the flattening line (view) often found into forward() methods of networks. This makes it - easier to navigate the network with introspection - """ - - def forward(self, x): - x = x.view(x.size()[0], -1) - return x From f5feec0cdab599622d4811635b732438ea381c04 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 29 Oct 2021 16:27:23 +0200 Subject: [PATCH 036/108] :sparkles: :wrench: added default resnet architectures without header --- .../dev_rotnet_resnet18_cb55_10.yaml | 75 ++++++++ configs/model/backbone/resnet101.yaml | 1 + configs/model/backbone/resnet152.yaml | 1 + configs/model/backbone/resnet18.yaml | 1 + configs/model/backbone/resnet34.yaml | 1 + configs/model/backbone/resnet50.yaml | 1 + src/models/backbones/resnet.py | 177 ++++++++++++++++++ 7 files changed, 257 insertions(+) create mode 100644 configs/experiment/dev_rotnet_resnet18_cb55_10.yaml create mode 100644 configs/model/backbone/resnet101.yaml create mode 100644 configs/model/backbone/resnet152.yaml create mode 100644 configs/model/backbone/resnet18.yaml create mode 100644 configs/model/backbone/resnet34.yaml create mode 100644 configs/model/backbone/resnet50.yaml create mode 100644 src/models/backbones/resnet.py diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml new file mode 100644 index 00000000..e2150025 --- /dev/null +++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml @@ -0,0 +1,75 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: classification.yaml + - /loss: crossentropyloss.yaml + - /metric: accuracy.yaml + - /model/backbone: resnet18.yaml + - /model/header: null + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 42
+
+train: True
+test: False
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 3
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 1
+  confusion_matrix_val: False
+  confusion_matrix_test: False
+
+datamodule:
+  _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped
+
+  data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation
+  crop_size: 256
+  num_workers: 4
+  batch_size: 16
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+
+model:
+  header:
+    _target_: src.models.headers.fully_connected.SingleLinear
+
+    num_classes: ${datamodule:num_classes}
+    # needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!)
+    input_size: 2048
+
+callbacks:
+  model_checkpoint:
+    filename: ${checkpoint_folder_name}dev-rotnet-resnet18-cb55-10
+  watch_model:
+    log_freq: 1
+
+logger:
+  wandb:
+    name: 'dev-rotnet-resnet18-cb55-10'
+    tags: [ "best_model", "USL" ]
+    group: 'dev-runs'
+    notes: "Testing"
diff --git a/configs/model/backbone/resnet101.yaml b/configs/model/backbone/resnet101.yaml
new file mode 100644
index 00000000..b84b3e3a
--- /dev/null
+++ b/configs/model/backbone/resnet101.yaml
@@ -0,0 +1 @@
+_target_: src.models.backbones.resnet.ResNet101
\ No newline at end of file
diff --git a/configs/model/backbone/resnet152.yaml b/configs/model/backbone/resnet152.yaml
new file mode 100644
index 00000000..476a9435
--- /dev/null
+++ b/configs/model/backbone/resnet152.yaml
@@ -0,0 +1 @@
+_target_: src.models.backbones.resnet.ResNet152
\ No newline at end of file
diff --git a/configs/model/backbone/resnet18.yaml b/configs/model/backbone/resnet18.yaml
new file mode 100644
index 00000000..9cedcf6e
--- /dev/null
+++ b/configs/model/backbone/resnet18.yaml
@@ -0,0 +1 @@
+_target_: src.models.backbones.resnet.ResNet18
\ No newline at end of file
diff --git a/configs/model/backbone/resnet34.yaml b/configs/model/backbone/resnet34.yaml
new file mode 100644
index 00000000..bc80debd
--- /dev/null
+++ b/configs/model/backbone/resnet34.yaml
@@ -0,0 +1 @@
+_target_: src.models.backbones.resnet.ResNet34
\ No newline at end of file
diff --git a/configs/model/backbone/resnet50.yaml b/configs/model/backbone/resnet50.yaml
new file mode 100644
index 00000000..de337ed4
--- /dev/null
+++ b/configs/model/backbone/resnet50.yaml
@@ -0,0 +1 @@
+_target_: src.models.backbones.resnet.ResNet50
\ No newline at end of file
diff --git a/src/models/backbones/resnet.py b/src/models/backbones/resnet.py
new file mode 100644
index 00000000..fb90522e
--- /dev/null
+++ b/src/models/backbones/resnet.py
@@ -0,0 +1,177 @@
+"""
+Model definition adapted from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+"""
+import logging
+import math
+
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152':
'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class _BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(_BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class _Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(_Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, **kwargs): + self.inplanes = 64 + super(ResNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + + return x + + +class ResNet18(ResNet): + def __init__(self, **kwargs): + super(ResNet18, self).__init__(_BasicBlock, [2, 2, 2, 2], **kwargs) + + +class ResNet34(ResNet): + def __init__(self,**kwargs): + super(ResNet34, self).__init__(_BasicBlock, [3, 4, 6, 3], **kwargs) + + +class ResNet50(ResNet): + def __init__(self, **kwargs): + super(ResNet50, self).__init__(_Bottleneck, [3, 4, 6, 3], **kwargs) + + +class ResNet101(ResNet): + def __init__(self, **kwargs): + super(ResNet101, self).__init__(_Bottleneck, [3, 4, 23, 3], **kwargs) + + +class ResNet152(ResNet): + def __init__(self, **kwargs): + super(ResNet152, self).__init__(_Bottleneck, [3, 8, 36, 3], **kwargs) From ed7d20887efc33e9981e67cab475c5d9fb24c5e2 Mon Sep 17 00:00:00 2001 From: Paul M Date: Fri, 29 Oct 2021 16:59:12 +0200 Subject: [PATCH 037/108] :white_check_mark: fixed rotation test --- .../DivaHisDB/datamodule_cropped.py | 18 +- .../DivaHisDB/utils/image_analytics.py | 90 +++++-- src/datamodules/RGB/utils/image_analytics.py | 1 - src/datamodules/RotNet/datamodule_cropped.py | 18 +- .../RotNet/datasets/cropped_dataset.py | 7 +- .../RotNet/utils/image_analytics.py | 64 +++-- src/datamodules/RotNet/utils/misc.py | 2 +- .../datasets/test_cropped_hisdb_dataset.py | 249 +++++++++--------- .../DivaHisDB/test_image_analytics.py | 39 ++- .../RotNet/datasets/test_cropped_dataset.py | 35 ++- .../RotNet/test_datamodule_cropped.py | 8 +- tests/datamodules/RotNet/test_misc.py | 5 +- 12 files changed, 303 insertions(+), 233 deletions(-) diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 48a8819d..40c42ff5 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -28,16 +28,16 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.data_folder_name = data_folder_name self.gt_folder_name = gt_folder_name - analytics = get_analytics(input_path=Path(data_dir), - data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name, - get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) - - self.mean = analytics['mean'] - self.std = analytics['std'] - self.class_encodings = analytics['class_encodings'] + analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, + get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) + + self.mean = analytics_data['mean'] + self.std = analytics_data['std'] + self.class_encodings = analytics_gt['class_encodings'] self.num_classes = len(self.class_encodings) - 
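
The backbones defined above stop at the flattened average pool and deliberately omit torchvision's fully connected layer, so a separate header has to know the flattened feature size. For the 256x256 crops used throughout these configs, ResNet18's last stage yields a 512-channel 8x8 map; AvgPool2d(7, stride=1) reduces that to 2x2, giving 512 * 4 = 2048 features, which is where the `input_size: 2048` in the surrounding experiment configs comes from. A minimal sanity check, assuming the classes above are importable from src.models.backbones.resnet:

    # Sketch: confirm the flattened feature size a header must accept (illustrative only).
    import torch

    from src.models.backbones.resnet import ResNet18

    backbone = ResNet18()
    with torch.no_grad():
        features = backbone(torch.zeros(1, 3, 256, 256))  # one dummy 256x256 RGB crop
    print(features.shape)  # expected: torch.Size([1, 2048]) = 512 channels * 2 * 2
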
self.class_weights = analytics['class_weights'] + self.class_weights = analytics_gt['class_weights'] self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index 46a7b97a..87042a59 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -24,37 +24,72 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, Returns ------- """ - analytics_file_path = input_path / 'analytics.json' - if analytics_file_path.exists(): - with analytics_file_path.open(mode='r') as f: - analytics_dict = json.load(fp=f) - else: + expected_keys_data = ['mean', 'std'] + expected_keys_gt = ['class_weights', 'class_encodings'] + + analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json' + analytics_path_gt = input_path / f'analytics.gt.hisDB.{gt_folder_name}.json' + + analytics_data = None + analytics_gt = None + + missing_analytics_data = True + missing_analytics_gt = True + + if analytics_path_data.exists(): + with analytics_path_data.open(mode='r') as f: + analytics_data = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_data for k in expected_keys_data): + missing_analytics_data = False + + if analytics_path_gt.exists(): + with analytics_path_gt.open(mode='r') as f: + analytics_gt = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_gt for k in expected_keys_gt): + missing_analytics_gt = False + + if missing_analytics_data or missing_analytics_gt: train_path = input_path / 'train' gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, gt_folder_name=gt_folder_name) file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) - mean, std = compute_mean_std(file_names=file_names_data, **kwargs) - - # Measure weights for class balancing - logging.info(f'Measuring class weights') - # create a list with all gt file paths - class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, **kwargs) - analytics_dict = {'mean': mean.tolist(), - 'std': std.tolist(), - 'class_weights': class_weights.tolist(), - 'class_encodings': class_encodings.tolist()} - # save json - try: - with analytics_file_path.open(mode='w') as f: - json.dump(obj=analytics_dict, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_file_path})') - else: - raise - # returns the 'mean[RGB]', 'std[RGB]', 'class_frequencies_weights[num_classes]', 'class_encodings' - return analytics_dict + + if missing_analytics_data: + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} + # save json + try: + with analytics_path_data.open(mode='w') as f: + json.dump(obj=analytics_data, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') + else: + raise + + if missing_analytics_gt: + # Measure weights for class balancing + logging.info(f'Measuring class weights') + # create a list with all gt file paths + class_weights, class_encodings = 
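
The analytics refactor in this patch replaces the single analytics.json with two caches, analytics.data.<data_folder>.json and analytics.gt.hisDB.<gt_folder>.json, and treats a cache as valid only if it exists and contains every expected key, so stale or partial files are recomputed rather than trusted. The read-or-recompute pattern, reduced to a sketch (file name and helper are illustrative, not the patch code):

    # Sketch of the cache-validation pattern used above.
    import json
    from pathlib import Path
    from typing import Optional

    def load_if_complete(path: Path, expected_keys) -> Optional[dict]:
        """Return the cached dict only if the file exists and has all expected keys."""
        if not path.exists():
            return None
        with path.open(mode='r') as f:
            cached = json.load(fp=f)
        return cached if all(k in cached for k in expected_keys) else None

    analytics_data = load_if_complete(Path('analytics.data.data.json'), ['mean', 'std'])
    if analytics_data is None:
        ...  # recompute mean/std over the train split, then rewrite the JSON
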
_get_class_frequencies_weights_segmentation_hisdb(gt_images=file_names_gt, + **kwargs) + analytics_gt = {'class_weights': class_weights.tolist(), + 'class_encodings': class_encodings.tolist()} + # save json + try: + with analytics_path_gt.open(mode='w') as f: + json.dump(obj=analytics_gt, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_gt})') + else: + raise + + return analytics_data, analytics_gt def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): @@ -295,7 +330,7 @@ def get_class_weights_graphs(dataset, **kwargs): return class_weights -def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): +def _get_class_frequencies_weights_segmentation_hisdb(gt_images, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. The vector sums up to 1 @@ -333,5 +368,4 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): if __name__ == '__main__': - # print(get_analytics(input_path=Path('/netscratch/datasets/semantic_segmentation/datasets/CB55/'), inmem=True, workers=16)) print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index bcf50254..839636a0 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -371,5 +371,4 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): if __name__ == '__main__': - # print(get_analytics(input_path=Path('/netscratch/datasets/semantic_segmentation/datasets/CB55/'), inmem=True, workers=16)) print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index 7f2604e1..a1dfb2bd 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RotNet.utils.image_analytics import get_analytics +from src.datamodules.RotNet.utils.image_analytics import get_analytics_data from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised from src.datamodules.RotNet.utils.wrapper_transforms import OnlyImage @@ -16,7 +16,7 @@ class RotNetDivaHisDBDataModuleCropped(AbstractDatamodule): - def __init__(self, data_dir: str = None, data_folder_name: str = 'data', + def __init__(self, data_dir: str, data_folder_name: str, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, @@ -24,14 +24,15 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', shuffle: bool = True, drop_last: bool = True): super().__init__() - analytics = get_analytics(input_path=Path(data_dir), - get_gt_data_paths_func=CroppedRotNet.get_gt_data_paths) + self.data_folder_name = data_folder_name + analytics_data = get_analytics_data(input_path=Path(data_dir), data_folder_name=self.data_folder_name, + get_gt_data_paths_func=CroppedRotNet.get_gt_data_paths) - self.mean = analytics['mean'] - self.std = analytics['std'] + self.mean = analytics_data['mean'] + self.std = analytics_data['std'] self.class_encodings = np.array(ROTATION_ANGLES) self.num_classes 
= len(self.class_encodings) - self.class_weights = np.ones(self.num_classes) + self.class_weights = np.array([1 / self.num_classes for _ in range(self.num_classes)]) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), @@ -43,7 +44,6 @@ def __init__(self, data_dir: str = None, data_folder_name: str = 'data', self.shuffle = shuffle self.drop_last = drop_last - self.data_folder_name = data_folder_name self.data_dir = validate_path_for_self_supervised(data_dir=data_dir, data_folder_name=self.data_folder_name) self.selection_train = selection_train @@ -113,7 +113,7 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] def _create_dataset_parameters(self, dataset_type: str = 'train'): is_test = dataset_type == 'test' return {'path': self.data_dir / dataset_type, + 'data_folder_name': self.data_folder_name, 'image_transform': self.image_transform, 'classes': self.class_encodings, 'is_test': is_test} - diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index b53a3fb4..55602e00 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -35,7 +35,7 @@ class CroppedRotNet(CroppedHisDBDataset): root/data/xxz.png """ - def __init__(self, path: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str = None, selection: Optional[Union[int, List[str]]] = None, is_test=False, image_transform=None, **kwargs): """ @@ -115,7 +115,7 @@ def _apply_transformation(self, img, index): return img, target_class @staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder_name: str = 'gt', + def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str = None, selection: Optional[Union[int, List[str]]] = None) \ -> List[Path]: """ @@ -134,9 +134,8 @@ def get_gt_data_paths(directory: Path, data_folder_name: str = 'data', gt_folder directory = directory.expanduser() path_data_root = directory / data_folder_name - path_gt_root = directory / gt_folder_name - if not (path_data_root.is_dir() or path_gt_root.is_dir()): + if not (path_data_root.is_dir()): log.error("folder data or gt not found in " + str(directory)) # get all subitems (and files) sorted diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py index 3bcc57f7..750733ac 100644 --- a/src/datamodules/RotNet/utils/image_analytics.py +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -15,7 +15,7 @@ from PIL import Image -def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): +def get_analytics_data(input_path: Path, data_folder_name: str, get_gt_data_paths_func, **kwargs): """ Parameters ---------- @@ -24,32 +24,41 @@ def get_analytics(input_path: Path, get_gt_data_paths_func, **kwargs): Returns ------- """ - analytics_file_path = input_path / 'analytics.json' - if analytics_file_path.exists(): - with analytics_file_path.open(mode='r') as f: - analytics_dict = json.load(fp=f) - else: + expected_keys_data = ['mean', 'std'] + + analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json' + + analytics_data = None + + missing_analytics_data = True + + if analytics_path_data.exists(): + with analytics_path_data.open(mode='r') as f: + analytics_data = json.load(fp=f) + # check if analytics file is complete + 
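
Note that CroppedRotNet now takes gt_folder_name=None and only checks for the data folder: RotNet needs no ground truth on disk because the supervision signal is manufactured by rotating each crop and predicting the rotation. A hedged sketch of that labelling step, assuming torchvision's functional API (the actual logic lives in _apply_transformation):

    # Sketch of self-supervised RotNet labelling; illustrative, not the patch code.
    import random

    import torchvision.transforms.functional as F

    ROTATION_ANGLES = [0, 90, 180, 270]  # matches the class encodings asserted in the tests

    def rotate_with_label(img):
        target_class = random.randrange(len(ROTATION_ANGLES))
        rotated = F.rotate(img, angle=ROTATION_ANGLES[target_class])
        return rotated, target_class  # the index of the angle is the training label
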
if all(k in analytics_data for k in expected_keys_data): + missing_analytics_data = False + + if missing_analytics_data: train_path = input_path / 'train' - data_path_list = get_gt_data_paths_func(train_path) - file_names_data = np.asarray([str(item) for item in data_path_list]) - mean, std = compute_mean_std(file_names=file_names_data, **kwargs) - - # Measure weights for class balancing - logging.info(f'Measuring class weights') - # create a list with all gt file paths - analytics_dict = {'mean': mean.tolist(), - 'std': std.tolist()} - # save json - try: - with analytics_file_path.open(mode='w') as f: - json.dump(obj=analytics_dict, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_file_path})') - else: - raise - # returns the 'mean[RGB]', 'std[RGB]', 'class_frequencies_weights[num_classes]', 'class_encodings' - return analytics_dict + gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, gt_folder_name=None) + file_names_data = np.asarray([str(item) for item in gt_data_path_list]) + + if missing_analytics_data: + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} + # save json + try: + with analytics_path_data.open(mode='w') as f: + json.dump(obj=analytics_data, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') + else: + raise + + return analytics_data def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): @@ -328,5 +337,4 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): if __name__ == '__main__': - # print(get_analytics(input_path=Path('/netscratch/datasets/semantic_segmentation/datasets/CB55/'), inmem=True, workers=16)) - print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) + print(get_analytics_data(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RotNet/utils/misc.py b/src/datamodules/RotNet/utils/misc.py index 78688046..51655d2d 100644 --- a/src/datamodules/RotNet/utils/misc.py +++ b/src/datamodules/RotNet/utils/misc.py @@ -48,7 +48,7 @@ def convert_to_rgb(pic): return pic -def validate_path_for_self_supervised(data_dir, data_folder_name: str = 'data'): +def validate_path_for_self_supervised(data_dir, data_folder_name: str): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val/test folder)") diff --git a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py index 014a1695..0dcd77ff 100644 --- a/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py +++ b/tests/datamodules/DivaHisDB/datasets/test_cropped_hisdb_dataset.py @@ -6,20 +6,24 @@ from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped +DATA_FOLDER_NAME = 'data' +GT_FOLDER_NAME = 'gt' +DATASET_PREFIX = 'e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max' + @pytest.fixture def dataset_train(data_dir_cropped): - return CroppedHisDBDataset(path=data_dir_cropped / 'train', data_folder_name='data', gt_folder_name='gt') + return CroppedHisDBDataset(path=data_dir_cropped / f'train', data_folder_name='data', gt_folder_name='gt') @pytest.fixture def dataset_val(data_dir_cropped): 
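
The fixture rewrite below (and in the sibling RotNet test) collapses the repeated page path into a DATASET_PREFIX f-string. Since every expected tuple differs only in its crop coordinates, the lists could equally be generated rather than spelled out; a sketch of that alternative (illustrative only, not what the patch does):

    # Illustrative helper: derive the expected (data, gt, page, crop_name, (x, y))
    # tuples from the crop grid instead of hand-written literals.
    PAGE = 'e-codices_fmb-cb-0055_0098v_max'
    DATASET_PREFIX = f'{PAGE}/{PAGE}'

    def expected_entries(data_dir, split, coords):
        entries = []
        for x, y in coords:
            crop = f'{PAGE}_x{x:04d}_y{y:04d}'
            entries.append((data_dir / f'{split}/data/{DATASET_PREFIX}_x{x:04d}_y{y:04d}.png',
                            data_dir / f'{split}/gt/{DATASET_PREFIX}_x{x:04d}_y{y:04d}.png',
                            PAGE, crop, (x, y)))
        return entries

    # e.g. the twelve val tuples: expected_entries(data_dir_cropped, 'val',
    #     [(x, y) for x in (0, 150, 300, 349) for y in (0, 150, 187)])
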
- return CroppedHisDBDataset(path=data_dir_cropped / 'val', data_folder_name='data', gt_folder_name='gt') + return CroppedHisDBDataset(path=data_dir_cropped / f'val', data_folder_name='data', gt_folder_name='gt') @pytest.fixture def dataset_test(data_dir_cropped): - return CroppedHisDBDataset(path=data_dir_cropped / 'test', data_folder_name='data', gt_folder_name='gt') + return CroppedHisDBDataset(path=data_dir_cropped / f'test', data_folder_name='data', gt_folder_name='gt') def test__load_data_and_gt(dataset_train): @@ -52,168 +56,169 @@ def test__get_train_val_items_test(dataset_test): def test_dataset_train_selection_int_error(data_dir_cropped): with pytest.raises(ValueError): - CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / f'train', data_folder_name='data', gt_folder_name='gt', selection=2) def test_dataset_train_selection_int(data_dir_cropped, get_train_file_names): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / f'train', data_folder_name='data', gt_folder_name='gt', selection=1) assert len(files_from_method) == 12 assert files_from_method == get_train_file_names def test_get_gt_data_paths_train(data_dir_cropped, get_train_file_names): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'train', + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / f'train', data_folder_name='data', gt_folder_name='gt') assert len(files_from_method) == 12 assert files_from_method == get_train_file_names def test_get_gt_data_paths_val(data_dir_cropped): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'val', + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / f'val', data_folder_name='data', gt_folder_name='gt') - expected_result = [(PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', (0, 0)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0150', (0, 150)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0187', (0, 187)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0000', (150, 0)), ( - PosixPath( - data_dir_cropped / 
'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0150', (150, 150)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0187', (150, 187)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0000', (300, 0)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0150', (300, 150)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0187', (300, 187)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0000', (349, 0)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0150', (349, 150)), ( - PosixPath( - data_dir_cropped / 'val/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - PosixPath( - data_dir_cropped / 'val/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), - 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0187', (349, 187))] + expected_result = [ + (PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0000_y0000.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0000_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', (0, 0)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0000_y0150.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0000_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0150', (0, 150)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0000_y0187.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0000_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0187', (0, 
187)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0150_y0000.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0150_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0000', (150, 0)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0150_y0150.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0150_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0150', (150, 150)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0150_y0187.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0150_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0187', (150, 187)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0300_y0000.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0300_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0000', (300, 0)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0300_y0150.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0300_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0150', (300, 150)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0300_y0187.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0300_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0187', (300, 187)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0349_y0000.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0349_y0000.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0000', (349, 0)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0349_y0150.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0349_y0150.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0150', (349, 150)), ( + PosixPath( + data_dir_cropped / f'val/data/{DATASET_PREFIX}_x0349_y0187.png'), + PosixPath( + data_dir_cropped / f'val/gt/{DATASET_PREFIX}_x0349_y0187.png'), + 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0187', (349, 187))] assert len(files_from_method) == 12 assert files_from_method == expected_result def test_get_gt_data_paths_test(data_dir_cropped): - files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / 'test', + files_from_method = CroppedHisDBDataset.get_gt_data_paths(directory=data_dir_cropped / f'test', data_folder_name='data', gt_folder_name='gt') expected_result = [(PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0000_y0000.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0000_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', (0, 0)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0128.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0000_y0128.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0128.png'), + data_dir_cropped / 
f'test/gt/{DATASET_PREFIX}_x0000_y0128.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0128', (0, 128)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0231.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0000_y0231.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0231.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0000_y0231.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0231', (0, 231)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0000.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0128_y0000.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0000.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0128_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0128_y0000', (128, 0)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0128.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0128_y0128.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0128.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0128_y0128.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0128_y0128', (128, 128)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0231.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0128_y0231.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0128_y0231.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0128_y0231.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0128_y0231', (128, 231)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0000.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0256_y0000.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0000.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0256_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0256_y0000', (256, 0)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0128.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0256_y0128.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0128.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0256_y0128.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0256_y0128', (256, 128)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0231.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0256_y0231.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0256_y0231.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0256_y0231.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0256_y0231', (256, 231)), ( PosixPath( - data_dir_cropped / 
'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0000.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0384_y0000.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0000.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0384_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0384_y0000', (384, 0)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0128.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0384_y0128.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0128.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0384_y0128.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0384_y0128', (384, 128)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0231.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0384_y0231.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0384_y0231.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0384_y0231.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0384_y0231', (384, 231)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0000.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0393_y0000.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0000.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0393_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0393_y0000', (393, 0)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0128.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0393_y0128.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0128.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0393_y0128.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0393_y0128', (393, 128)), ( PosixPath( - data_dir_cropped / 'test/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0231.png'), + data_dir_cropped / f'test/data/{DATASET_PREFIX}_x0393_y0231.png'), PosixPath( - data_dir_cropped / 'test/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0393_y0231.png'), + data_dir_cropped / f'test/gt/{DATASET_PREFIX}_x0393_y0231.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0393_y0231', (393, 231))] assert len(files_from_method) == 15 assert files_from_method == expected_result @@ -222,62 +227,62 @@ def test_get_gt_data_paths_test(data_dir_cropped): @pytest.fixture def get_train_file_names(data_dir_cropped): return [(PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0000.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0000_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0000', 
(0, 0)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0150.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0000_y0150.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0150', (0, 150)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0187.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0000_y0187.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0000_y0187', (0, 187)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0150_y0000.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0150_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0000', (150, 0)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0150_y0150.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0150_y0150.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0150', (150, 150)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0150_y0187.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0150_y0187.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0150_y0187', (150, 187)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0000.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0300_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0000', (300, 0)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0150.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0300_y0150.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0150', (300, 150)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + 
data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0187.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0300_y0187.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0300_y0187', (300, 187)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0000.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0349_y0000.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0000', (349, 0)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0150.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0349_y0150.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0150', (349, 150)), ( PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0187.png'), PosixPath( - data_dir_cropped / 'train/gt/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + data_dir_cropped / f'train/gt/{DATASET_PREFIX}_x0349_y0187.png'), 'e-codices_fmb-cb-0055_0098v_max', 'e-codices_fmb-cb-0055_0098v_max_x0349_y0187', (349, 187))] diff --git a/tests/datamodules/DivaHisDB/test_image_analytics.py b/tests/datamodules/DivaHisDB/test_image_analytics.py index c8b7bc83..653882b5 100644 --- a/tests/datamodules/DivaHisDB/test_image_analytics.py +++ b/tests/datamodules/DivaHisDB/test_image_analytics.py @@ -6,27 +6,40 @@ from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped -TEST_JSON = {'mean': [0.7050454974582426, 0.6503181590413943, 0.5567698583877997], - 'std': [0.3104060859619883, 0.3053311838884032, 0.28919611393432726], - 'class_weights': [0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002], - 'class_encodings': [1, 2, 4, 8]} +TEST_JSON_DATA = {'mean': [0.7050454974582426, 0.6503181590413943, 0.5567698583877997], + 'std': [0.3104060859619883, 0.3053311838884032, 0.28919611393432726]} + +TEST_JSON_GT = {'class_weights': [0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002], + 'class_encodings': [1, 2, 4, 8]} + +DATA_FOLDER_NAME = 'data' +GT_FOLDER_NAME = 'gt' +DATA_ANALYTICS_FILENAME = f'analytics.data.{DATA_FOLDER_NAME}.json' +GT_ANALYTICS_FILENAME = f'analytics.gt.hisDB.{GT_FOLDER_NAME}.json' def test_get_analytics_no_file(data_dir_cropped): - output = get_analytics(input_path=data_dir_cropped, data_folder_name='data', gt_folder_name='gt', - get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) + analytics_data, analytics_gt = get_analytics(input_path=data_dir_cropped, + data_folder_name=DATA_FOLDER_NAME, gt_folder_name=GT_FOLDER_NAME, + get_gt_data_paths_func=CroppedHisDBDataset.get_gt_data_paths) - assert np.array_equal(np.round(TEST_JSON['mean'], 8), np.round(output['mean'], 
8)) - assert np.array_equal(np.round(TEST_JSON['std'], 8), np.round(output['std'], 8)) - assert np.array_equal(np.round(TEST_JSON['class_weights'], 8), np.round(output['class_weights'], 8)) - assert np.array_equal(TEST_JSON['class_encodings'], output['class_encodings']) - assert (data_dir_cropped / 'analytics.json').exists() + assert np.array_equal(np.round(TEST_JSON_DATA['mean'], 8), np.round(analytics_data['mean'], 8)) + assert np.array_equal(np.round(TEST_JSON_DATA['std'], 8), np.round(analytics_data['std'], 8)) + assert np.array_equal(np.round(TEST_JSON_GT['class_weights'], 8), np.round(analytics_gt['class_weights'], 8)) + assert np.array_equal(TEST_JSON_GT['class_encodings'], analytics_gt['class_encodings']) + assert (data_dir_cropped / DATA_ANALYTICS_FILENAME).exists() + assert (data_dir_cropped / GT_ANALYTICS_FILENAME).exists() def test_get_analytics_load_from_file(data_dir_cropped): - analytics_path = data_dir_cropped / 'analytics.json' + analytics_path = data_dir_cropped / DATA_ANALYTICS_FILENAME + with analytics_path.open(mode='w') as f: + json.dump(obj=TEST_JSON_DATA, fp=f) + assert analytics_path.exists() + + analytics_path = data_dir_cropped / GT_ANALYTICS_FILENAME with analytics_path.open(mode='w') as f: - json.dump(obj=TEST_JSON, fp=f) + json.dump(obj=TEST_JSON_GT, fp=f) assert analytics_path.exists() test_get_analytics_no_file(data_dir_cropped=data_dir_cropped) diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index 3e56ed85..c27d3814 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -9,10 +9,15 @@ from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped +DATA_FOLDER_NAME = 'data' +GT_FOLDER_NAME = None +DATASET_PREFIX = 'e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max' + @pytest.fixture def dataset_train(data_dir_cropped): - return CroppedRotNet(path=data_dir_cropped / 'train') + return CroppedRotNet(path=data_dir_cropped / 'train', + data_folder_name=DATA_FOLDER_NAME) def test__load_data_and_gt(dataset_train): @@ -64,32 +69,34 @@ def test__apply_transformation(dataset_train): def test_get_gt_data_paths(data_dir_cropped): - file_paths = CroppedRotNet.get_gt_data_paths(directory=data_dir_cropped / 'train') + file_paths = CroppedRotNet.get_gt_data_paths(directory=data_dir_cropped / 'train', + data_folder_name=DATA_FOLDER_NAME) + expected_result = [ PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0000.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0150.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0000_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0000_y0187.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0150_y0000.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0150.png'), + data_dir_cropped / 
f'train/data/{DATASET_PREFIX}_x0150_y0150.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0150_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0150_y0187.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0000.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0150.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0300_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0300_y0187.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0000.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0000.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0150.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0150.png'), PosixPath( - data_dir_cropped / 'train/data/e-codices_fmb-cb-0055_0098v_max/e-codices_fmb-cb-0055_0098v_max_x0349_y0187.png'), + data_dir_cropped / f'train/data/{DATASET_PREFIX}_x0349_y0187.png'), ] assert len(file_paths) == len(expected_result) assert file_paths == expected_result diff --git a/tests/datamodules/RotNet/test_datamodule_cropped.py b/tests/datamodules/RotNet/test_datamodule_cropped.py index 4204503b..64306e81 100644 --- a/tests/datamodules/RotNet/test_datamodule_cropped.py +++ b/tests/datamodules/RotNet/test_datamodule_cropped.py @@ -6,12 +6,14 @@ from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped NUM_WORKERS = 4 - +DATA_FOLDER_NAME = 'data' @pytest.fixture def data_module_cropped(data_dir_cropped): OmegaConf.clear_resolvers() - return RotNetDivaHisDBDataModuleCropped(data_dir_cropped, num_workers=NUM_WORKERS) + return RotNetDivaHisDBDataModuleCropped(data_dir=data_dir_cropped, + data_folder_name=DATA_FOLDER_NAME, + num_workers=NUM_WORKERS) def test_init_datamodule(data_module_cropped): @@ -20,7 +22,7 @@ def test_init_datamodule(data_module_cropped): assert data_module_cropped.dims == (3, 256, 256) assert data_module_cropped.num_classes == 4 assert np.array_equal(data_module_cropped.class_encodings, [0, 90, 180, 270]) - assert np.array_equal(data_module_cropped.class_weights, [1., 1., 1., 1.]) + assert np.array_equal(data_module_cropped.class_weights, [.25, .25, .25, .25]) assert data_module_cropped.mean == [0.7050454974582426, 0.6503181590413943, 0.5567698583877997] assert data_module_cropped.std == [0.3104060859619883, 0.3053311838884032, 0.28919611393432726] with pytest.raises(AttributeError): diff --git a/tests/datamodules/RotNet/test_misc.py b/tests/datamodules/RotNet/test_misc.py index d9d163d4..21abd4d5 100644 --- a/tests/datamodules/RotNet/test_misc.py +++ b/tests/datamodules/RotNet/test_misc.py @@ -1,6 +1,9 @@ from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped +DATA_FOLDER_NAME = 'data' + def test_validate_path_for_self_supervised(data_dir_cropped): - assert data_dir_cropped == validate_path_for_self_supervised(data_dir_cropped) + assert data_dir_cropped == validate_path_for_self_supervised(data_dir=data_dir_cropped, + 
data_folder_name=DATA_FOLDER_NAME) From ac67da89c80652285eb621a7ec04718607e571c0 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 29 Oct 2021 17:37:58 +0200 Subject: [PATCH 038/108] :wrench: rotnet resnet 18 experiment --- configs/experiment/rotnet_resnet18_cb55.yaml | 69 ++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 configs/experiment/rotnet_resnet18_cb55.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55.yaml b/configs/experiment/rotnet_resnet18_cb55.yaml new file mode 100644 index 00000000..28897c46 --- /dev/null +++ b/configs/experiment/rotnet_resnet18_cb55.yaml @@ -0,0 +1,69 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: classification.yaml + - /loss: crossentropyloss.yaml + - /metric: accuracy.yaml + - /model/backbone: resnet18.yaml + - /model/header: single_layer.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: False + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 200 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 20 + confusion_matrix_val: False + confusion_matrix_test: False + +datamodule: + _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + +model: + header: + input_size: 2048 + + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}rotnet-resnet18-cb55 + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'rotnet-resnet18-cb55' + tags: [ "best_model", "USL"] + group: 'rotnet-baseline' From 54508992765fa2c6e3c772ddc9154c27dde6c103 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Sun, 31 Oct 2021 10:54:21 +0100 Subject: [PATCH 039/108] :wrench: rotnet resnet 18 experiment with last 10 train --- ...55.yaml => rotnet_resnet18_cb55_full.yaml} | 0 .../rotnet_resnet18_cb55_train10_last.yaml | 80 +++++++++++++++++++ 2 files changed, 80 insertions(+) rename configs/experiment/{rotnet_resnet18_cb55.yaml => rotnet_resnet18_cb55_full.yaml} (100%) create mode 100644 configs/experiment/rotnet_resnet18_cb55_train10_last.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml similarity index 100% rename from configs/experiment/rotnet_resnet18_cb55.yaml rename to configs/experiment/rotnet_resnet18_cb55_full.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml new file mode 100644 index 00000000..610a3341 --- /dev/null +++ b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: classification.yaml + - /loss: crossentropyloss.yaml + - /metric: accuracy.yaml + - 
/model/backbone: resnet18.yaml + - /model/header: single_layer.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: False + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 50 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 20 + confusion_matrix_val: False + confusion_matrix_test: False + +datamodule: + _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + crop_size: 256 + num_workers: 4 + batch_size: 16 + shuffle: True + drop_last: True + data_folder_name: data + selection_train: + - e-codices_fmb-cb-0055_0058v_max + - e-codices_fmb-cb-0055_0059r_max + - e-codices_fmb-cb-0055_0059v_max + - e-codices_fmb-cb-0055_0060r_max + - e-codices_fmb-cb-0055_0067v_max + - e-codices_fmb-cb-0055_0072v_max + - e-codices_fmb-cb-0055_0073v_max + - e-codices_fmb-cb-0055_0074v_max + - e-codices_fmb-cb-0055_0084r_max + - e-codices_fmb-cb-0055_0087r_max + +model: + header: + input_size: 2048 + + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}rotnet-resnet18-cb55 + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'rotnet-resnet18-cb55' + tags: [ "best_model", "USL"] + group: 'rotnet-baseline' From ad3b137820f39f6a856ad4364ff0d820fa2398a2 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Sun, 31 Oct 2021 11:16:59 +0100 Subject: [PATCH 040/108] :wrench: updated config --- configs/experiment/rotnet_resnet18_cb55_train10_last.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml index 610a3341..91b0e49b 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml @@ -31,9 +31,10 @@ trainer: gpus: -1 accelerator: 'ddp' min_epochs: 1 - max_epochs: 50 + max_epochs: 100 weights_summary: full precision: 16 + check_val_every_n_epoch: 5 task: confusion_matrix_log_every_n_epoch: 20 @@ -46,7 +47,7 @@ datamodule: data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 - batch_size: 16 + batch_size: 32 shuffle: True drop_last: True data_folder_name: data @@ -75,6 +76,6 @@ callbacks: logger: wandb: - name: 'rotnet-resnet18-cb55' + name: 'rotnet-resnet18-cb55-last10-train' tags: [ "best_model", "USL"] group: 'rotnet-baseline' From bcbaf10a91a23e0e64c430f7411f1a9ad0d673cf Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Sun, 31 Oct 2021 14:57:05 +0100 Subject: [PATCH 041/108] :wrench: config for the last 19 pages in train --- .../rotnet_resnet18_cb55_train19_last.yaml | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 configs/experiment/rotnet_resnet18_cb55_train19_last.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml new file mode 100644 index 00000000..0f43b765 --- /dev/null +++ 
b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: classification.yaml + - /loss: crossentropyloss.yaml + - /metric: accuracy.yaml + - /model/backbone: resnet18.yaml + - /model/header: single_layer.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: False + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 100 + weights_summary: full + precision: 16 + check_val_every_n_epoch: 5 + +task: + confusion_matrix_log_every_n_epoch: 20 + confusion_matrix_val: False + confusion_matrix_test: False + +datamodule: + _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped + + data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + crop_size: 256 + num_workers: 4 + batch_size: 32 + shuffle: True + drop_last: True + data_folder_name: data + selection_train: + - e-codices_fmb-cb-0055_0008v_max + - e-codices_fmb-cb-0055_0011r_max + - e-codices_fmb-cb-0055_0011v_max + - e-codices_fmb-cb-0055_0025r_max + - e-codices_fmb-cb-0055_0026v_max + - e-codices_fmb-cb-0055_0027v_max + - e-codices_fmb-cb-0055_0031v_max + - e-codices_fmb-cb-0055_0043v_max + - e-codices_fmb-cb-0055_0044r_max + - e-codices_fmb-cb-0055_0058v_max + - e-codices_fmb-cb-0055_0059r_max + - e-codices_fmb-cb-0055_0059v_max + - e-codices_fmb-cb-0055_0060r_max + - e-codices_fmb-cb-0055_0067v_max + - e-codices_fmb-cb-0055_0072v_max + - e-codices_fmb-cb-0055_0073v_max + - e-codices_fmb-cb-0055_0074v_max + - e-codices_fmb-cb-0055_0084r_max + - e-codices_fmb-cb-0055_0087r_max + +model: + header: + input_size: 2048 + + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}rotnet-resnet18-cb55 + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'rotnet-resnet18-cb55-last10-train' + tags: [ "best_model", "USL"] + group: 'rotnet-baseline' From 174cad0b82413aeb238a8a566ba09e6a01c41d94 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Sun, 31 Oct 2021 15:03:39 +0100 Subject: [PATCH 042/108] :wrench: raised batch size --- configs/experiment/rotnet_resnet18_cb55_train19_last.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml index 0f43b765..cdc1c7dc 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml @@ -47,7 +47,7 @@ datamodule: data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 - batch_size: 32 + batch_size: 256 shuffle: True drop_last: True data_folder_name: data @@ -85,6 +85,6 @@ callbacks: logger: wandb: - name: 'rotnet-resnet18-cb55-last10-train' + name: 'rotnet-resnet18-cb55-last19-train' tags: [ "best_model", "USL"] group: 'rotnet-baseline' From 7b169dee072bfd8bd5ef2f70bf98e210e14ef4e7 Mon Sep 17 00:00:00 2001 From: Paul M Date: Tue, 2 Nov 2021 
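
The rotnet_resnet18_cb55_train*_last.yaml experiments differ from the full run mainly in their selection_train whitelist: that list of page stems is threaded through the datamodule to get_gt_data_paths as its selection argument, restricting training to the named pages while val/test are left as-is. How such a config becomes live objects is standard Hydra wiring; a hedged sketch (the actual entry point is run.py, outside this series, and config_path/config_name here are assumptions):

    # Hedged sketch of the Hydra side of these experiment configs.
    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path='configs', config_name='config')
    def main(cfg: DictConfig) -> None:
        # instantiate() resolves _target_ and passes the sibling keys as kwargs,
        # including selection_train / selection_val / selection_test.
        datamodule = hydra.utils.instantiate(cfg.datamodule)
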
12:55:28 +0100 Subject: [PATCH 043/108] :construction: initial start for Rolf's data format --- src/datamodules/RolfFormat/__init__.py | 0 .../RolfFormat/datamodule_cropped.py | 143 +++++++ .../RolfFormat/datasets/__init__.py | 0 .../RolfFormat/datasets/cropped_dataset.py | 251 ++++++++++++ src/datamodules/RolfFormat/utils/__init__.py | 0 .../RolfFormat/utils/functional.py | 63 +++ .../RolfFormat/utils/image_analytics.py | 374 ++++++++++++++++++ src/datamodules/RolfFormat/utils/misc.py | 75 ++++ .../RolfFormat/utils/output_tools.py | 119 ++++++ .../RolfFormat/utils/single_transforms.py | 144 +++++++ .../RolfFormat/utils/twin_transforms.py | 115 ++++++ .../RolfFormat/utils/wrapper_transforms.py | 37 ++ 12 files changed, 1321 insertions(+) create mode 100644 src/datamodules/RolfFormat/__init__.py create mode 100644 src/datamodules/RolfFormat/datamodule_cropped.py create mode 100644 src/datamodules/RolfFormat/datasets/__init__.py create mode 100644 src/datamodules/RolfFormat/datasets/cropped_dataset.py create mode 100644 src/datamodules/RolfFormat/utils/__init__.py create mode 100644 src/datamodules/RolfFormat/utils/functional.py create mode 100644 src/datamodules/RolfFormat/utils/image_analytics.py create mode 100644 src/datamodules/RolfFormat/utils/misc.py create mode 100644 src/datamodules/RolfFormat/utils/output_tools.py create mode 100644 src/datamodules/RolfFormat/utils/single_transforms.py create mode 100644 src/datamodules/RolfFormat/utils/twin_transforms.py create mode 100644 src/datamodules/RolfFormat/utils/wrapper_transforms.py diff --git a/src/datamodules/RolfFormat/__init__.py b/src/datamodules/RolfFormat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RolfFormat/datamodule_cropped.py b/src/datamodules/RolfFormat/datamodule_cropped.py new file mode 100644 index 00000000..e85a14a5 --- /dev/null +++ b/src/datamodules/RolfFormat/datamodule_cropped.py @@ -0,0 +1,143 @@ +from pathlib import Path +from typing import Union, List, Optional + +import torch +from torch.utils.data import DataLoader +from torchvision import transforms + +from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB +from src.datamodules.RGB.utils.image_analytics import get_analytics +from src.datamodules.RGB.utils.misc import validate_path_for_segmentation +from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ + IntegerEncoding +from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.base_datamodule import AbstractDatamodule +from src.utils import utils + +log = utils.get_logger(__name__) + + +class DataModuleCroppedRGB(AbstractDatamodule): + def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, + selection_train: Optional[Union[int, List[str]]] = None, + selection_val: Optional[Union[int, List[str]]] = None, + selection_test: Optional[Union[int, List[str]]] = None, + crop_size: int = 256, num_workers: int = 4, batch_size: int = 8, + shuffle: bool = True, drop_last: bool = True): + super().__init__() + + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + + analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, + get_gt_data_paths_func=CroppedDatasetRGB.get_gt_data_paths) + + self.mean = analytics_data['mean'] + self.std = analytics_data['std'] + self.class_encodings = analytics_gt['class_encodings'] + 
self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 + self.num_classes = len(self.class_encodings) + self.class_weights = analytics_gt['class_weights'] + + self.twin_transform = TwinRandomCrop(crop_size=crop_size) + self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std)])) + self.target_transform = OnlyTarget(IntegerEncoding(class_encodings=self.class_encodings_tensor)) + + self.num_workers = num_workers + self.batch_size = batch_size + + self.shuffle = shuffle + self.drop_last = drop_last + + self.data_dir = validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) + + self.selection_train = selection_train + self.selection_val = selection_val + self.selection_test = selection_test + + self.dims = (3, crop_size, crop_size) + + def setup(self, stage: Optional[str] = None): + super().setup() + if stage == 'fit' or stage is None: + self.train = CroppedDatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.val = CroppedDatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + + self._check_min_num_samples(num_samples=len(self.train), data_split='train', + drop_last=self.drop_last) + self._check_min_num_samples(num_samples=len(self.val), data_split='val', + drop_last=self.drop_last) + + if stage == 'test' or stage is not None: + self.test = CroppedDatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + # self._check_min_num_samples(num_samples=len(self.test), data_split='test', + # drop_last=False) + + def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): + num_processes = self.trainer.num_processes + batch_size = self.batch_size + if drop_last: + if num_samples < (self.trainer.num_processes * self.batch_size): + log.error( + f'#samples ({num_samples}) in "{data_split}" smaller than ' + f'#processes({num_processes}) times batch size ({batch_size}). ' + f'This only works if drop_last is false!') + raise ValueError() + else: + if num_samples < (self.trainer.num_processes * self.batch_size): + log.warning( + f'#samples ({num_samples}) in "{data_split}" smaller than ' + f'#processes ({num_processes}) times batch size ({batch_size}). ' + f'This works due to drop_last=False, however samples will occur multiple times. 
'
+                    f'Check if this behavior is intended!')
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(self.train,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.val,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.test,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=False,
+                          drop_last=False,
+                          pin_memory=True)
+
+    def _create_dataset_parameters(self, dataset_type: str = 'train'):
+        is_test = dataset_type == 'test'
+        return {'path': self.data_dir / dataset_type,
+                'data_folder_name': self.data_folder_name,
+                'gt_folder_name': self.gt_folder_name,
+                'image_transform': self.image_transform,
+                'target_transform': self.target_transform,
+                'twin_transform': self.twin_transform,
+                'classes': self.class_encodings,
+                'is_test': is_test}
+
+    def get_img_name_coordinates(self, index):
+        """
+        Returns the original filename of the crop and its coordinates based on the index.
+        You can only use this during testing!
+        :param index:
+        :return:
+        """
+        if not hasattr(self, 'test'):
+            raise Exception('This method can only be called during testing')
+
+        return self.test.img_paths_per_page[index][2:]
diff --git a/src/datamodules/RolfFormat/datasets/__init__.py b/src/datamodules/RolfFormat/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/datamodules/RolfFormat/datasets/cropped_dataset.py b/src/datamodules/RolfFormat/datasets/cropped_dataset.py
new file mode 100644
index 00000000..b31a2aae
--- /dev/null
+++ b/src/datamodules/RolfFormat/datasets/cropped_dataset.py
@@ -0,0 +1,251 @@
+"""
+Load a dataset of historic documents by specifying the folder where it is located.
+"""
+
+# Utils
+import re
+from pathlib import Path
+from typing import List, Tuple, Union, Optional
+
+import torch.utils.data as data
+from omegaconf import ListConfig
+from torch import is_tensor
+from torchvision.transforms import ToTensor
+
+from src.datamodules.RGB.utils.misc import has_extension, pil_loader
+from src.utils import utils
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif']
+
+log = utils.get_logger(__name__)
+
+
+class CroppedDatasetRGB(data.Dataset):
+    """A generic data loader where the images are arranged in this way: ::
+
+        root/gt/xxx.png
+        root/gt/xxy.png
+        root/gt/xxz.png
+
+        root/data/xxx.png
+        root/data/xxy.png
+        root/data/xxz.png
+    """
+
+    def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str,
+                 selection: Optional[Union[int, List[str]]] = None,
+                 is_test=False, image_transform=None, target_transform=None, twin_transform=None,
+                 classes=None, **kwargs):
+        """
+        Parameters
+        ----------
+        path : Path
+            Path to the dataset split folder (train / val / test)
+        data_folder_name : str
+            Name of the folder containing the original image crops
+        gt_folder_name : str
+            Name of the folder containing the ground truth crops
+        selection : int or list(str), optional
+            Amount of original pages (or list of their folder names) to take into account
+        is_test : bool
+            If True, ``__getitem__`` additionally returns the sample index
+        image_transform : callable
+            Transformation applied to the image only
+        target_transform : callable
+            Transformation applied to the ground truth only
+        twin_transform : callable
+            Transformation applied jointly to image and ground truth
+        classes :
+            Class encodings of the dataset (stored for reference)
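+
+        Example (illustrative sketch only; the root path is a placeholder,
+        not part of the repository)::
+
+            from pathlib import Path
+
+            dataset = CroppedDatasetRGB(path=Path('/data/CB55/train'),
+                                        data_folder_name='data',
+                                        gt_folder_name='gt')
+            img, gt = dataset[0]  # crop and ground truth, converted to tensors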
+ """ + + self.path = path + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + self.selection = selection + + # Init list + self.classes = classes + # self.crops_per_image = crops_per_image + + # transformations + self.image_transform = image_transform + self.target_transform = target_transform + self.twin_transform = twin_transform + + self.is_test = is_test + + # List of tuples that contain the path to the gt and image that belong together + self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, selection=self.selection) + + # TODO: make more fanzy stuff here + # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + + self.num_samples = len(self.img_paths_per_page) + if self.num_samples == 0: + raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( + path, ",".join(IMG_EXTENSIONS))) + + def __len__(self): + """ + This function returns the length of an epoch so the data loader knows when to stop. + The length is different during train/val and test, because we process the whole image during testing, + and only sample from the images during train/val. + """ + return self.num_samples + + def __getitem__(self, index): + if self.is_test: + return self._get_test_items(index=index) + else: + return self._get_train_val_items(index=index) + + def _get_train_val_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt = self._apply_transformation(data_img, gt_img) + return img, gt + + def _get_test_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt = self._apply_transformation(data_img, gt_img) + return img, gt, index + + def _load_data_and_gt(self, index): + data_img = pil_loader(self.img_paths_per_page[index][0]) + gt_img = pil_loader(self.img_paths_per_page[index][1]) + + return data_img, gt_img + + def _apply_transformation(self, img, gt): + """ + Applies the transformations that have been defined in the setup (setup.py). If no transformations + have been defined, the PIL image is returned instead. 
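+
+        Order of application (sketch; each step below is skipped when the
+        respective transform is None)::
+
+            img, gt = self.twin_transform(img, gt)    # joint transform, train/val only
+            img, gt = self.image_transform(img, gt)   # e.g. OnlyImage(ToTensor + Normalize)
+            img, gt = self.target_transform(img, gt)  # e.g. OnlyTarget(IntegerEncoding)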
+ + Parameters + ---------- + img: PIL image + image data + gt: PIL image + ground truth image + coordinates: tuple (int, int) + coordinates where the sliding window should be cropped + Returns + ------- + tuple + img and gt after transformations + """ + if self.twin_transform is not None and not self.is_test: + img, gt = self.twin_transform(img, gt) + + if self.image_transform is not None: + # perform transformations + img, gt = self.image_transform(img, gt) + + if not is_tensor(img): + img = ToTensor()(img) + if not is_tensor(gt): + gt = ToTensor()(gt) + + if self.target_transform is not None: + img, gt = self.target_transform(img, gt) + + return img, gt + + @staticmethod + def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, + selection: Optional[Union[int, List[str]]] = None) \ + -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: + """ + Structure of the folder + + directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + + :param directory: + :param data_folder_name: + :param gt_folder_name: + :param selection: + :return: tuple + (path_data_file, path_gt_file, original_image_name, (x, y)) + """ + paths = [] + directory = directory.expanduser() + + path_data_root = directory / data_folder_name + path_gt_root = directory / gt_folder_name + + if not (path_data_root.is_dir() or path_gt_root.is_dir()): + log.error("folder data or gt not found in " + str(directory)) + + # get all subitems (and files) sorted + subitems = sorted(path_data_root.iterdir()) + + # check the selection parameter + if selection: + subdirectories = [x.name for x in subitems if x.is_dir()] + + if isinstance(selection, int): + if selection < 0: + msg = f'Parameter "selection" is a negative integer ({selection}). ' \ + f'Negative values are not supported!' + log.error(msg) + raise ValueError(msg) + + elif selection == 0: + selection = None + + elif selection > len(subdirectories): + msg = f'Parameter "selection" is larger ({selection}) than ' \ + f'number of subdirectories ({len(subdirectories)}).' 
+ log.error(msg) + raise ValueError(msg) + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if not all(x in subdirectories for x in selection): + msg = f'Parameter "selection" contains a non-existing subdirectory.)' + log.error(msg) + raise ValueError(msg) + + else: + msg = f'Parameter "selection" exists, but it is of unsupported type ({type(selection)})' + log.error(msg) + raise TypeError(msg) + + counter = 0 # Counter for subdirectories, needed for selection parameter + + for path_data_subdir in subitems: + if not path_data_subdir.is_dir(): + if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + log.warning("image file found in data root: " + str(path_data_subdir)) + continue + + counter += 1 + + if selection: + if isinstance(selection, int): + if counter > selection: + break + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if path_data_subdir.name not in selection: + continue + + path_gt_subdir = path_gt_root / path_data_subdir.stem + assert path_gt_subdir.is_dir() + + for path_data_file, path_gt_file in zip(sorted(path_data_subdir.iterdir()), + sorted(path_gt_subdir.iterdir())): + assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_extension(path_gt_file.name, IMG_EXTENSIONS), \ + 'get_gt_data_paths(): image file aligned with non-image file' + + if has_extension(path_data_file.name, IMG_EXTENSIONS) and has_extension(path_gt_file.name, + IMG_EXTENSIONS): + assert path_data_file.stem == path_gt_file.stem, \ + 'get_gt_data_paths(): mismatch between data filename and gt filename' + coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') + m = coordinates.match(path_data_file.name) + if m is None: + continue + x = int(m.group(1)) + y = int(m.group(2)) + # TODO check if we need x/y + paths.append((path_data_file, path_gt_file, path_data_subdir.stem, path_data_file.stem, (x, y))) + + return paths diff --git a/src/datamodules/RolfFormat/utils/__init__.py b/src/datamodules/RolfFormat/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datamodules/RolfFormat/utils/functional.py b/src/datamodules/RolfFormat/utils/functional.py new file mode 100644 index 00000000..b052d201 --- /dev/null +++ b/src/datamodules/RolfFormat/utils/functional.py @@ -0,0 +1,63 @@ +from typing import List + +import torch +from torch.nn.functional import one_hot + + +def gt_to_int_encoding(matrix: torch.Tensor, class_encodings: torch.Tensor): + """ + Convert ground truth tensor or numpy matrix to one-hot encoded matrix + + Parameters + ------- + matrix: float tensor from to_tensor() or numpy array + shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR + class_encodings: List of int + Blue channel values that encode the different classes + Returns + ------- + torch.LongTensor of size [#C x H x W] + sparse one-hot encoded multi-class matrix, where #C is the number of classes + """ + integer_encoded = torch.full(size=matrix[0].shape, fill_value=-1, dtype=torch.long) + for index, encoding in enumerate(class_encodings): + mask = torch.logical_and(torch.logical_and( + torch.where(matrix[0] == encoding[0], True, False), + torch.where(matrix[1] == encoding[1], True, False)), + torch.where(matrix[2] == encoding[2], True, False)) + integer_encoded[mask] = index + + return integer_encoded + + +def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor): + """ + Convert ground truth tensor or numpy matrix to one-hot encoded matrix + + Parameters + ------- + matrix: float tensor from to_tensor() or numpy array + 
shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR + class_encodings: List of int + Blue channel values that encode the different classes + Returns + ------- + torch.LongTensor of size [#C x H x W] + sparse one-hot encoded multi-class matrix, where #C is the number of classes + """ + integer_encoded = gt_to_int_encoding(matrix=matrix, class_encodings=class_encodings) + + num_classes = class_encodings.shape[0] + + onehot_encoded = one_hot(input=integer_encoded, num_classes=num_classes) + onehot_encoded = onehot_encoded.swapaxes(1, 2).swapaxes(0, 1) # changes axis from (0, 1, 2) to (2, 0, 1) + + return onehot_encoded + + +def argmax_onehot(tensor: torch.Tensor): + """ + # TODO + """ + output = torch.LongTensor(torch.argmax(tensor, dim=0)) + return output diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py new file mode 100644 index 00000000..839636a0 --- /dev/null +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -0,0 +1,374 @@ +# Utils +import errno +import json +import logging +import os +from multiprocessing import Pool +from pathlib import Path +from typing import List + +import numpy as np +# Torch related stuff +import torch +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from PIL import Image + +from src.datamodules.RGB.utils.misc import pil_loader + + +def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): + """ + Parameters + ---------- + input_path: Path to dataset + + Returns + ------- + """ + expected_keys_data = ['mean', 'std'] + expected_keys_gt = ['class_weights', 'class_encodings'] + + analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json' + analytics_path_gt = input_path / f'analytics.gt.{gt_folder_name}.json' + + analytics_data = None + analytics_gt = None + + missing_analytics_data = True + missing_analytics_gt = True + + if analytics_path_data.exists(): + with analytics_path_data.open(mode='r') as f: + analytics_data = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_data for k in expected_keys_data): + missing_analytics_data = False + + if analytics_path_gt.exists(): + with analytics_path_gt.open(mode='r') as f: + analytics_gt = json.load(fp=f) + # check if analytics file is complete + if all(k in analytics_gt for k in expected_keys_gt): + missing_analytics_gt = False + + if missing_analytics_data or missing_analytics_gt: + train_path = input_path / 'train' + gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, + gt_folder_name=gt_folder_name) + file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) + file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) + + if missing_analytics_data: + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} + # save json + try: + with analytics_path_data.open(mode='w') as f: + json.dump(obj=analytics_data, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') + else: + raise + + if missing_analytics_gt: + # Measure weights for class balancing + logging.info(f'Measuring class weights') + # create a list with all gt file paths + class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, + **kwargs) + analytics_gt = 
{'class_weights': class_weights,
+                            'class_encodings': class_encodings}
+            # save json
+            try:
+                with analytics_path_gt.open(mode='w') as f:
+                    json.dump(obj=analytics_gt, fp=f)
+            except IOError as e:
+                if e.errno == errno.EACCES:
+                    print(f'WARNING: No permissions to write analytics file ({analytics_path_gt})')
+                else:
+                    raise
+
+    return analytics_data, analytics_gt
+
+
+def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs):
+    """
+    Computes mean and std of all images present in the target folder.
+
+    Parameters
+    ----------
+    file_names : List of Path
+        List of image files to include in the computation
+    inmem : Boolean
+        Specifies whether it should be computed in memory (offline) or in an online fashion.
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : float
+        Mean value of all pixels of the images in the input folder
+    std : float
+        Standard deviation of all pixels of the images in the input folder
+    """
+    file_names_np = np.array(list(map(str, file_names)))
+    # Compute mean and std
+    mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers)
+    return mean, std
+
+
+def _cms_online(file_names, workers=4):
+    """
+    Computes mean and standard deviation in an online fashion.
+    This is useful when the dataset is too big to be allocated in memory.
+
+    Parameters
+    ----------
+    file_names : List of String
+        List of file names of the dataset
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    logging.info('Begin computing the mean')
+
+    # Set up a pool of workers
+    pool = Pool(workers + 1)
+
+    # Online mean
+    results = pool.map(_return_mean, file_names)
+    mean_sum = np.sum(np.array(results), axis=0)
+
+    # Divide by number of samples in train set
+    mean = mean_sum / file_names.size
+
+    logging.info('Finished computing the mean')
+    logging.info('Begin computing the std')
+
+    # Online standard deviation
+    results = pool.starmap(_return_std, [[item, mean] for item in file_names])
+    std_sum = np.sum(np.array([item[0] for item in results]), axis=0)
+    total_pixel_count = np.sum(np.array([item[1] for item in results]))
+    std = np.sqrt(std_sum / total_pixel_count)
+    logging.info('Finished computing the std')
+
+    # Shut down the pool
+    pool.close()
+
+    return mean, std
+
+
+# Loads an image with PIL and returns the channel-wise means of the image.
+def _return_mean(image_path):
+    img = np.array(Image.open(image_path).convert('RGB'))
+    mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0
+    return mean
+
+
+# Loads an image with PIL and returns the channel-wise sum of squared deviations and the per-channel pixel count.
+def _return_std(image_path, mean):
+    img = np.array(Image.open(image_path).convert('RGB')) / 255.0
+    m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]]))
+    return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0
+
+
+def _cms_inmem(file_names):
+    """
+    Computes mean and standard deviation in an offline fashion. This is possible only when the dataset can
+    be allocated in memory.
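+
+    Per RGB channel c, the statistics are the plain definitions over every pixel
+    of every image, rescaled from [0, 255] to [0, 1]; roughly::
+
+        mean[c] = np.mean(img[:, :, :, c]) / 255.0
+        std[c] = np.std(img[:, :, :, c]) / 255.0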
+ + Parameters + ---------- + file_names: List of String + List of file names of the dataset + Returns + ------- + mean : double + std : double + """ + img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) + + # Load all samples + for i, sample in enumerate(file_names): + img[i] = np.array(Image.open(sample).convert('RGB')) + + mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 + std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 + + return mean, std + + +def get_class_weights(input_folder, workers=4, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. + The vector sums up to 1 + + Parameters + ---------- + input_folder : String (path) + Path to the dataset folder (see above for details) + workers : int + Number of workers to use for the mean/std computation + + Returns + ------- + ndarray[double] of size (num_classes) + The weights vector as a 1D array normalized (sum up to 1) + """ + # Sanity check on the folder + if not os.path.isdir(input_folder): + logging.error(f"Folder {input_folder} does not exist") + raise FileNotFoundError + + # Load the dataset + ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()])) + + logging.info('Begin computing class frequencies weights') + + if hasattr(ds, 'targets'): + labels = ds.targets + elif hasattr(ds, 'labels'): + labels = ds.labels + else: + # This is a fail-safe net in case a custom dataset changed the name of the internal variables + data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers) + labels = [] + for target, label in data_loader: + labels.append(label) + labels = np.concatenate(labels).reshape(len(ds)) + + class_support = np.unique(labels, return_counts=True)[1] + class_frequencies = class_support / len(labels) + # Class weights are the inverse of the class frequencies + class_weights = 1 / class_frequencies + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + class_weights /= class_weights.sum() + + logging.info('Finished computing class frequencies weights ') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}') + + return class_weights + + +def compute_mean_std_graphs(dataset, **kwargs): + """ + Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py). 
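+
+    Usage sketch (assumes ``dataset`` is a ParsedGxlDataset whose node features
+    live in ``dataset.data.x`` and edge features in ``dataset.data.edge_attr``)::
+
+        node_stats, edge_stats = compute_mean_std_graphs(dataset)
+        # each is {'mean': [...], 'std': [...]}, one entry per feature column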
+ + Parameters + ---------- + input_folder : ParsedGxlDataset + Dataset object (see above for details) + + # TODO implement online version + + Returns + ------- + node_features : {"mean": list, "std": list} + Mean and std value of all node features in the input dataset + edge_features : {"mean": list, "std": list} + Mean and std value of all edge features in the input dataset + """ + if dataset.data.x is not None: + logging.info('Begin computing the node feature mean and std') + nodes = _get_feature_mean_std(dataset.data.x) + logging.info('Finished computing the node feature mean and std') + else: + nodes = {} + logging.info('No node features present') + + if dataset.data.edge_attr is not None: + logging.info('Begin computing the edge feature mean and std') + edges = _get_feature_mean_std(dataset.data.edge_attr) + logging.info('Finished computing the edge feature mean and std') + else: + edges = {} + logging.info('No edge features present') + + return nodes, edges + + +def _get_feature_mean_std(torch_array): + array = np.array(torch_array) + return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} + + +def get_class_weights_graphs(dataset, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. + The vector sums up to 1 + + Parameters + ---------- + input_folder : ParsedGxlDataset + Dataset object (see above for details) + + # TODO implement online version + + Returns + ------- + ndarray[double] of size (num_classes) + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + class_frequencies = np.array(dataset.config['class_freq'][1]) + # Class weights are the inverse of the class frequencies + class_weights = 1 / class_frequencies + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + class_weights /= class_weights.sum() + + logging.info('Finished computing class frequencies weights ') + logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') + logging.info(f'Class weights (rounded): {np.around(class_weights)}') + + return class_weights + + +def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): + """ + Get the weights proportional to the inverse of their class frequencies. 
+ The vector sums up to 1 + + Parameters + ---------- + gt_images: list of strings + Path to all ground truth images, which contain the pixel-wise label + workers: int + Number of workers to use for the mean/std computation + + Returns + ------- + ndarray[double] of size (num_classes) and ints the classes are represented as + The weights vector as a 1D array normalized (sum up to 1) + """ + logging.info('Begin computing class frequencies weights') + + total_num_pixels = 0 + label_counter = {} + + for path in gt_images: + img_raw = pil_loader(path) + colors = img_raw.getcolors() + + for count, color in colors: + total_num_pixels += count + label_counter[color] = label_counter.get(color, 0) + count + + classes = sorted(label_counter.keys()) + num_samples_per_class = np.asarray([label_counter[k] for k in classes]) + logging.info('Finished computing class frequencies weights') + # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) + class_weights = (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()) + return class_weights.tolist(), classes + + +if __name__ == '__main__': + print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RolfFormat/utils/misc.py b/src/datamodules/RolfFormat/utils/misc.py new file mode 100644 index 00000000..c0de22d4 --- /dev/null +++ b/src/datamodules/RolfFormat/utils/misc.py @@ -0,0 +1,75 @@ +""" +General purpose utility functions. + +""" + +from pathlib import Path + +# Utils +import numpy as np +from PIL import Image + +from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir + +try: + import accimage +except ImportError: + accimage = None + + +def has_extension(filename, extensions): + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def pil_loader(path, to_rgb=True): + pic = Image.open(path) + if to_rgb: + pic = convert_to_rgb(pic) + return pic + + +def convert_to_rgb(pic): + if pic.mode == "RGB": + pass + elif pic.mode in ("CMYK", "RGBA", "P"): + pic = pic.convert('RGB') + elif pic.mode == "I": + img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "I;16": + img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + elif pic.mode == "L": + img = np.array(pic).astype(np.uint8) + pic = Image.fromarray(np.stack((img, img, img), axis=2)) + else: + raise TypeError(f"unsupported image type {pic.mode}") + return pic + + +def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): + if data_dir is None: + raise PathNone("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + else: + split_names = ['train', 'val', 'test'] + type_names = [data_folder_name, gt_folder_name] + + data_folder = Path(data_dir) + if not data_folder.is_dir(): + raise PathNotDir("Please provide the path to root dir of the dataset " + "(folder containing the train/val/test folder)") + split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] + if len(split_folders) != 3: + raise PathMissingSplitDir(f'Your path needs to contain train/val/test and ' + f'each of them a folder {data_folder_name} and {gt_folder_name}') + + # check if we have train/test/val + for split in split_folders: + type_folders = [d for d in 
split.iterdir() if d.is_dir() and d.name in type_names] + # check if we have data/gt + if len(type_folders) != 2: + raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' + f'and {gt_folder_name} folder') + return Path(data_dir) diff --git a/src/datamodules/RolfFormat/utils/output_tools.py b/src/datamodules/RolfFormat/utils/output_tools.py new file mode 100644 index 00000000..6a472955 --- /dev/null +++ b/src/datamodules/RolfFormat/utils/output_tools.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Union, Tuple, List + +import numpy as np +import torch +from PIL import Image + + +def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): + """ + takes the biggest value from a pixel across all classes + :param output: (Batch_size x num_classes x W x H) + matrix with the given attributes + :return: (Batch_size x W x H) + matrix with the hisdb class number for each pixel + """ + if isinstance(output, torch.Tensor): + return torch.argmax(output, dim=dim) + if isinstance(output, np.ndarray): + return np.argmax(output, axis=dim) + return output + + +def merge_patches(patch, coordinates, full_output): + """ + This function merges the patch into the full output image + Overlapping values are resolved by taking the max. + + Parameters + ---------- + patch: numpy matrix of size [#classes x crop_size x crop_size] + a patch from the larger image + coordinates: tuple of ints + top left coordinates of the patch within the larger image for all patches in a batch + full_output: numpy matrix of size [#C x H x W] + output image at full size + Returns + ------- + full_output: numpy matrix [#C x Htot x Wtot] + """ + assert len(full_output.shape) == 3 + assert full_output.size != 0 + + # Resolve patch coordinates + x1, y1 = coordinates + x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1] + + # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen! 
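+    # e.g. (illustrative numbers) for full_output of shape [C x 960 x 1344] and a
+    # 256x256 patch, any top-left corner (x1, y1) must satisfy
+    # x1 + 256 <= 1344 and y1 + 256 <= 960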
+    assert x2 <= full_output.shape[2]
+    assert y2 <= full_output.shape[1]
+
+    mask = np.isnan(full_output[:, y1:y2, x1:x2])
+    # where full_output is still NaN, insert the patch value; where a value exists, take the element-wise max
+    full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2]))
+
+    return full_output
+
+
+def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding: List[Tuple[int]]):
+    """
+    Helper function to save the output during testing in the DIVAHisDB format
+
+    Parameters
+    ----------
+    image_name: str
+        name of the image that is saved
+    output_image: numpy matrix of size [#C x H x W]
+        output image at full size
+    output_folder: Path
+        path to the output folder for the test data
+    class_encoding: list(tuple(int))
+        list with the class encodings
+
+    Returns
+    -------
+    None
+        the class-encoded image is written to ``output_folder``
+    """
+
+    output_encoded = output_to_class_encodings(output_image, class_encoding)
+
+    dest_folder = output_folder
+    dest_folder.mkdir(parents=True, exist_ok=True)
+    dest_filename = dest_folder / image_name
+
+    # Save the output
+    Image.fromarray(output_encoded.astype(np.uint8)).save(str(dest_filename))
+
+
+def output_to_class_encodings(output, class_encodings):
+    """
+    This function converts the output prediction matrix to an image as it was provided in the ground truth
+
+    Parameters
+    -------
+    output : np.array of size [#C x H x W]
+        output prediction of the network for a full-size image, where #C is the number of classes
+    class_encodings : List
+        Contains the range of encoded classes
+    Returns
+    -------
+    numpy array of size [H x W x 3] (RGB)
+    """
+
+    integer_encoded = np.argmax(output, axis=0)
+
+    num_classes = len(class_encodings)
+
+    masks = [integer_encoded == class_index for class_index in range(num_classes)]
+
+    rgb = np.full((*integer_encoded.shape, 3), -1)
+    for mask, color in zip(masks, class_encodings):
+        rgb[:, :, 0] = np.where(mask, color[0], rgb[:, :, 0])
+        rgb[:, :, 1] = np.where(mask, color[1], rgb[:, :, 1])
+        rgb[:, :, 2] = np.where(mask, color[2], rgb[:, :, 2])
+
+    return rgb
diff --git a/src/datamodules/RolfFormat/utils/single_transforms.py b/src/datamodules/RolfFormat/utils/single_transforms.py
new file mode 100644
index 00000000..dc48230b
--- /dev/null
+++ b/src/datamodules/RolfFormat/utils/single_transforms.py
@@ -0,0 +1,144 @@
+import math
+import random
+
+import torch
+from PIL import Image
+from torchvision.transforms import Pad
+
+
+class ResizePad(object):
+    """
+    Perform resizing keeping the aspect ratio of the image --padding type: continuous (black).
+    Expects PIL image and int value as target_size
+    (It can be extended to perform other transforms on both PIL image and object boxes.)
+
+    Example:
+        target_size = 200
+        # im: numpy array
+        img = Image.fromarray(im.astype('uint8'), 'RGB')
+        img = ResizePad(target_size)(img)
+    """
+
+    def __init__(self, target_size):
+        self.target_size = target_size
+        self.boxes = torch.Tensor([[0, 0, 0, 0]])
+
+    def resize(self, img, boxes, size, max_size=1000):
+        '''Resize the input PIL image to the given size.
+        Args:
+          img: (PIL.Image) image to be resized.
+          boxes: (tensor) object boxes, sized [#obj,4].
+          size: (tuple or int)
+            - if is tuple, resize image to the size.
+            - if is int, resize the shorter side to the size while maintaining the aspect ratio.
+          max_size: (int) when size is int, limit the image longer size to max_size.
+                    This is essential to limit the usage of GPU memory.
+        Returns:
+          img: (PIL.Image) resized image.
+          boxes: (tensor) resized boxes.
+        '''
+        w, h = img.size
+        if isinstance(size, int):
+            size_min = min(w, h)
+            size_max = max(w, h)
+            sw = sh = float(size) / size_min
+            if sw * size_max > max_size:
+                sw = sh = float(max_size) / size_max
+            ow = int(w * sw + 0.5)
+            oh = int(h * sh + 0.5)
+        else:
+            ow, oh = size
+            sw = float(ow) / w
+            sh = float(oh) / h
+        return img.resize((ow, oh), Image.BILINEAR), \
+               boxes * torch.Tensor([sw, sh, sw, sh])
+
+    def random_crop(self, img, boxes):
+        '''Crop the given PIL image to a random size and aspect ratio.
+        A crop of random size of (0.56 to 1.0) of the original size and a random
+        aspect ratio of 3/4 to 4/3 of the original aspect ratio is made.
+        Args:
+          img: (PIL.Image) image to be cropped.
+          boxes: (tensor) object boxes, sized [#obj,4].
+        Returns:
+          img: (PIL.Image) randomly cropped image.
+          boxes: (tensor) randomly cropped boxes.
+        '''
+        success = False
+        for attempt in range(10):
+            area = img.size[0] * img.size[1]
+            target_area = random.uniform(0.56, 1.0) * area
+            aspect_ratio = random.uniform(3. / 4, 4. / 3)
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if random.random() < 0.5:
+                w, h = h, w
+
+            if w <= img.size[0] and h <= img.size[1]:
+                x = random.randint(0, img.size[0] - w)
+                y = random.randint(0, img.size[1] - h)
+                success = True
+                break
+
+        # Fallback
+        if not success:
+            w = h = min(img.size[0], img.size[1])
+            x = (img.size[0] - w) // 2
+            y = (img.size[1] - h) // 2
+
+        img = img.crop((x, y, x + w, y + h))
+        boxes -= torch.Tensor([x, y, x, y])
+        boxes[:, 0::2].clamp_(min=0, max=w - 1)
+        boxes[:, 1::2].clamp_(min=0, max=h - 1)
+        return img, boxes
+
+    def center_crop(self, img, boxes, size):
+        '''Crops the given PIL Image at the center.
+        Args:
+          img: (PIL.Image) image to be cropped.
+          boxes: (tensor) object boxes, sized [#obj,4].
+          size (tuple): desired output size of (w,h).
+        Returns:
+          img: (PIL.Image) center cropped image.
+          boxes: (tensor) center cropped boxes.
+        '''
+        w, h = img.size
+        ow, oh = size
+        i = int(round((h - oh) / 2.))
+        j = int(round((w - ow) / 2.))
+        img = img.crop((j, i, j + ow, i + oh))
+        boxes -= torch.Tensor([j, i, j, i])
+        boxes[:, 0::2].clamp_(min=0, max=ow - 1)
+        boxes[:, 1::2].clamp_(min=0, max=oh - 1)
+        return img, boxes
+
+    def random_flip(self, img, boxes):
+        '''Randomly flip the given PIL Image.
+        Args:
+          img: (PIL Image) image to be flipped.
+          boxes: (tensor) object boxes, sized [#obj,4].
+        Returns:
+          img: (PIL.Image) randomly flipped image.
+          boxes: (tensor) randomly flipped boxes.
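+        Example (sketch): in a 100-px-wide image, a box [10, 10, 50, 50]
+        becomes [50, 10, 90, 50] when the flip fires::
+            img, boxes = self.random_flip(img, torch.Tensor([[10., 10., 50., 50.]]))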
+ ''' + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + w = img.width + xmin = w - boxes[:, 2] + xmax = w - boxes[:, 0] + boxes[:, 0] = xmin + boxes[:, 2] = xmax + return img, boxes + + def resize_with_padding(self, img, target_size): + img, boxes = self.resize(img, self.boxes, target_size, max_size=target_size) + padding = (max(0, target_size - img.size[0]) // 2, max(0, target_size - img.size[1]) // 2) + img = Pad(padding)(img) + + return img + + def __call__(self, img): + img = self.resize_with_padding(img, self.target_size) + return img \ No newline at end of file diff --git a/src/datamodules/RolfFormat/utils/twin_transforms.py b/src/datamodules/RolfFormat/utils/twin_transforms.py new file mode 100644 index 00000000..5ba68cad --- /dev/null +++ b/src/datamodules/RolfFormat/utils/twin_transforms.py @@ -0,0 +1,115 @@ +import random + +from torchvision.transforms import functional as F + +from src.datamodules.RGB.utils import functional as F_custom + + +class TwinCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, gt): + for t in self.transforms: + img, gt = t(img, gt) + return img, gt + + +class TwinRandomCrop(object): + """Crop the given PIL Images at the same random location""" + + def __init__(self, crop_size): + self.crop_size = crop_size + + def get_params(self, img_size): + """Get parameters for ``crop`` for a random crop""" + w, h = img_size + th = self.crop_size + tw = self.crop_size + + assert w >= tw and h >= th + + if w == tw and h == th: + return 0, 0, h, w + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img, gt): + i, j, h, w = self.get_params(img.size) + return F.crop(img, i, j, h, w), F.crop(gt, i, j, h, w) + + +class TwinImageToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + Converts a PIL Image or numpy.ndarray (W x H x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + + def __call__(self, img, gt): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + return F.to_tensor(img), F.to_tensor(gt) + + +class ToTensorSlidingWindowCrop(object): + """ + Crop the data and ground truth image at the specified coordinates to the specified size and convert + them to a tensor. + """ + + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, img, gt, coordinates): + """ + Args: + img (PIL Image): Data image to be cropped and converted to tensor. + gt (PIL Image): Ground truth image to be cropped and converted to tensor. 
+ + Returns: + Data tensor, gt tensor (tuple of tensors): cropped and converted images + + """ + x_position = coordinates[0] + y_position = coordinates[1] + + return F.to_tensor(F.crop(img, x_position, y_position, self.crop_size, self.crop_size)), \ + F.to_tensor(F.crop(gt, x_position, y_position, self.crop_size, self.crop_size)) + + +class OneHotToPixelLabelling(object): + def __call__(self, tensor): + return F_custom.argmax_onehot(tensor) + + +class OneHotEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_one_hot(gt, self.class_encodings) + + +class IntegerEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_int_encoding(gt, self.class_encodings) diff --git a/src/datamodules/RolfFormat/utils/wrapper_transforms.py b/src/datamodules/RolfFormat/utils/wrapper_transforms.py new file mode 100644 index 00000000..eaa5e437 --- /dev/null +++ b/src/datamodules/RolfFormat/utils/wrapper_transforms.py @@ -0,0 +1,37 @@ +from typing import Callable + + +class OnlyImage(object): + """Wrapper function around a single parameter transform. It will be cast only on image""" + + def __init__(self, transform: Callable): + """Initialize the transformation with the transformation to be called. + Could be a compose. + + Parameters + ---------- + transform : torchvision.transforms.transforms + Transformation to wrap + """ + self.transform = transform + + def __call__(self, image, target): + return self.transform(image), target + + +class OnlyTarget(object): + """Wrapper function around a single parameter transform. It will be cast only on target""" + + def __init__(self, transform: Callable): + """Initialize the transformation with the transformation to be called. + Could be a compose. 
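+
+        Example (sketch; only the image is transformed, the target passes
+        through untouched)::
+
+            from torchvision import transforms
+
+            image_only = OnlyImage(transforms.Compose([transforms.ToTensor()]))
+            img_t, gt = image_only(img, gt)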
+ + Parameters + ---------- + transform : torchvision.transforms.transforms + Transformation to wrap + """ + self.transform = transform + + def __call__(self, image, target): + return image, self.transform(target) \ No newline at end of file From edef8d9ccdd29dda6ff7696924016b205e59c8ea Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 2 Nov 2021 14:51:35 +0100 Subject: [PATCH 044/108] :bug: fixed a problem with the resolvers in OmegaConfig when using multiruns --- src/datamodules/base_datamodule.py | 11 ++++++----- src/tasks/base_task.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/datamodules/base_datamodule.py b/src/datamodules/base_datamodule.py index fd9b2c50..ab30a1d8 100644 --- a/src/datamodules/base_datamodule.py +++ b/src/datamodules/base_datamodule.py @@ -9,11 +9,12 @@ def __init__(self): super().__init__() self.num_classes = -1 resolver_name = 'datamodule' - OmegaConf.register_new_resolver( - resolver_name, - lambda name: getattr(self, name), - use_cache=False - ) + if not OmegaConf.has_resolver(resolver_name): + OmegaConf.register_new_resolver( + resolver_name, + lambda name: getattr(self, name), + use_cache=False + ) def setup(self, stage: Optional[str] = None) -> None: if not self.dims: diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py index 73c2f658..976f20eb 100644 --- a/src/tasks/base_task.py +++ b/src/tasks/base_task.py @@ -65,11 +65,12 @@ def __init__( super().__init__() resolver_name = 'task' - OmegaConf.register_new_resolver( - resolver_name, - lambda name: getattr(self, name), - use_cache=False - ) + if not OmegaConf.has_resolver(resolver_name): + OmegaConf.register_new_resolver( + resolver_name, + lambda name: getattr(self, name), + use_cache=False + ) if model is not None: self.model = model From 742cb9f75534a1c9ccac3a6bd5954dc9101d3527 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 2 Nov 2021 15:12:54 +0100 Subject: [PATCH 045/108] :sound_loud: :art: fixed some logging problem and removed a condition --- .../RotNet/utils/image_analytics.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py index 750733ac..f62ebbec 100644 --- a/src/datamodules/RotNet/utils/image_analytics.py +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -44,19 +44,18 @@ def get_analytics_data(input_path: Path, data_folder_name: str, get_gt_data_path gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, gt_folder_name=None) file_names_data = np.asarray([str(item) for item in gt_data_path_list]) - if missing_analytics_data: - mean, std = compute_mean_std(file_names=file_names_data, **kwargs) - analytics_data = {'mean': mean.tolist(), - 'std': std.tolist()} - # save json - try: - with analytics_path_data.open(mode='w') as f: - json.dump(obj=analytics_data, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') - else: - raise + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} + # save json + try: + with analytics_path_data.open(mode='w') as f: + json.dump(obj=analytics_data, fp=f) + except IOError as e: + if e.errno == errno.EACCES: + logging.warning(f'WARNING: No permissions to write analytics file ({analytics_path_data})') + else: + raise return analytics_data From 
09ec24261abd9f4681efa34492def9a87eeedfd4 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 2 Nov 2021 15:52:25 +0100
Subject: [PATCH 046/108] :wrench: reduced the amount of epochs

---
 configs/experiment/rotnet_resnet18_cb55_full.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/experiment/rotnet_resnet18_cb55_full.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml
index 28897c46..d56e88fb 100644
--- a/configs/experiment/rotnet_resnet18_cb55_full.yaml
+++ b/configs/experiment/rotnet_resnet18_cb55_full.yaml
@@ -31,7 +31,7 @@ trainer:
   gpus: -1
   accelerator: 'ddp'
   min_epochs: 1
-  max_epochs: 200
+  max_epochs: 100
   weights_summary: full
   precision: 16
 

From fa0385b5d429a75dab9712c4ef0e1dd934b8c394 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 2 Nov 2021 15:53:12 +0100
Subject: [PATCH 047/108] :wrench: reduced the amount of val checks

---
 configs/experiment/rotnet_resnet18_cb55_full.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/configs/experiment/rotnet_resnet18_cb55_full.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml
index d56e88fb..5282a45d 100644
--- a/configs/experiment/rotnet_resnet18_cb55_full.yaml
+++ b/configs/experiment/rotnet_resnet18_cb55_full.yaml
@@ -34,6 +34,7 @@ trainer:
   max_epochs: 100
   weights_summary: full
   precision: 16
+  check_val_every_n_epoch: 5
 
 task:
   confusion_matrix_log_every_n_epoch: 20

From 4f90397dd6737d82a592f2d8b81f63a3c9abde6d Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 2 Nov 2021 16:59:50 +0100
Subject: [PATCH 048/108] :construction: started full page datamodule

---
 .../RGB/datasets/full_page_dataset.py         | 233 ++++++++++++++++++
 .../datamodules/RGB/test_full_page_dataset.py |  32 +++
 2 files changed, 265 insertions(+)
 create mode 100644 src/datamodules/RGB/datasets/full_page_dataset.py
 create mode 100644 tests/datamodules/RGB/test_full_page_dataset.py

diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py
new file mode 100644
index 00000000..0d55c522
--- /dev/null
+++ b/src/datamodules/RGB/datasets/full_page_dataset.py
@@ -0,0 +1,233 @@
+"""
+Load a dataset of historic documents by specifying the folder where it is located.
+"""
+
+# Utils
+import re
+from pathlib import Path
+from typing import List, Tuple, Union, Optional, Any
+
+import torch.utils.data as data
+from omegaconf import ListConfig
+from torch import is_tensor
+from torchvision.transforms import ToTensor
+
+from src.datamodules.RGB.utils.misc import has_extension, pil_loader
+from src.utils import utils
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif']
+
+log = utils.get_logger(__name__)
+
+
+class DatasetRGB(data.Dataset):
+    """A generic data loader where the images are arranged in this way: ::
+
+        root/gt/xxx.png
+        root/gt/xxy.png
+        root/gt/xxz.png
+
+        root/data/xxx.png
+        root/data/xxy.png
+        root/data/xxz.png
+    """
+
+    def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str,
+                 selection: Optional[Union[int, List[str]]] = None,
+                 is_test=False, image_transform=None, target_transform=None, twin_transform=None,
+                 classes=None, **kwargs):
+        """
+        Parameters
+        ----------
+        path : Path
+            Path to the dataset split folder (train / val / test)
+        data_folder_name : str
+            Name of the folder containing the original page images
+        gt_folder_name : str
+            Name of the folder containing the ground truth images
+        selection : int or list(str), optional
+            Amount of pages (or list of page file names) to take into account
+        is_test : bool
+            If True, ``__getitem__`` additionally returns the sample index
+        image_transform : callable
+            Transformation applied to the image only
+        target_transform : callable
+            Transformation applied to the ground truth only
+        twin_transform : callable
+            Transformation applied jointly to image and ground truth
+        classes :
+            Class encodings of the dataset (stored for reference)
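+
+        Example (illustrative sketch only; path and folder names are
+        placeholders)::
+
+            from pathlib import Path
+
+            dataset = DatasetRGB(path=Path('/data/CB55/train'),
+                                 data_folder_name='data',
+                                 gt_folder_name='gt')
+            img, gt = dataset[0]  # one full page and its ground truth as tensors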
+ """ + + self.path = path + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + self.selection = selection + + # Init list + self.classes = classes + # self.crops_per_image = crops_per_image + + # transformations + self.image_transform = image_transform + self.target_transform = target_transform + self.twin_transform = twin_transform + + self.is_test = is_test + + # List of tuples that contain the path to the gt and image that belong together + self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, selection=self.selection) + + # TODO: make more fanzy stuff here + # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + + self.num_samples = len(self.img_paths_per_page) + if self.num_samples == 0: + raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( + path, ",".join(IMG_EXTENSIONS))) + + def __len__(self): + """ + This function returns the length of an epoch so the data loader knows when to stop. + The length is different during train/val and test, because we process the whole image during testing, + and only sample from the images during train/val. + """ + return self.num_samples + + def __getitem__(self, index): + if self.is_test: + return self._get_test_items(index=index) + else: + return self._get_train_val_items(index=index) + + def _get_train_val_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt = self._apply_transformation(data_img, gt_img) + return img, gt + + def _get_test_items(self, index): + data_img, gt_img = self._load_data_and_gt(index=index) + img, gt = self._apply_transformation(data_img, gt_img) + return img, gt, index + + def _load_data_and_gt(self, index): + data_img = pil_loader(self.img_paths_per_page[index][0]) + gt_img = pil_loader(self.img_paths_per_page[index][1]) + + return data_img, gt_img + + def _apply_transformation(self, img, gt): + """ + Applies the transformations that have been defined in the setup (setup.py). If no transformations + have been defined, the PIL image is returned instead. 
+ + Parameters + ---------- + img: PIL image + image data + gt: PIL image + ground truth image + coordinates: tuple (int, int) + coordinates where the sliding window should be cropped + Returns + ------- + tuple + img and gt after transformations + """ + if self.twin_transform is not None and not self.is_test: + img, gt = self.twin_transform(img, gt) + + if self.image_transform is not None: + # perform transformations + img, gt = self.image_transform(img, gt) + + if not is_tensor(img): + img = ToTensor()(img) + if not is_tensor(gt): + gt = ToTensor()(gt) + + if self.target_transform is not None: + img, gt = self.target_transform(img, gt) + + return img, gt + + @staticmethod + def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, + selection: Optional[Union[int, List[str]]] = None) \ + -> List[Tuple[Union[Path, Any], Path]]: + """ + Structure of the folder + + directory/data/FILE_NAME.png + directory/gt/FILE_NAME.png + + :param directory: + :param data_folder_name: + :param gt_folder_name: + :param selection: + :return: tuple + (path_data_file, path_gt_file) + """ + paths = [] + directory = directory.expanduser() + + path_data_root = directory / data_folder_name + path_gt_root = directory / gt_folder_name + + if not (path_data_root.is_dir() or path_gt_root.is_dir()): + log.error("folder data or gt not found in " + str(directory)) + + # get all files sorted + files_in_data_root = sorted(path_data_root.iterdir()) + + # check the selection parameter + if selection: + subdirectories = [x.name for x in files_in_data_root if x.is_dir()] + + if isinstance(selection, int): + if selection < 0: + msg = f'Parameter "selection" is a negative integer ({selection}). ' \ + f'Negative values are not supported!' + log.error(msg) + raise ValueError(msg) + + elif selection == 0: + selection = None + + elif selection > len(subdirectories): + msg = f'Parameter "selection" is larger ({selection}) than ' \ + f'number of subdirectories ({len(subdirectories)}).' 
+ log.error(msg) + raise ValueError(msg) + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if not all(x in subdirectories for x in selection): + msg = f'Parameter "selection" contains a non-existing subdirectory.)' + log.error(msg) + raise ValueError(msg) + + else: + msg = f'Parameter "selection" exists, but it is of unsupported type ({type(selection)})' + log.error(msg) + raise TypeError(msg) + + counter = 0 # Counter for subdirectories, needed for selection parameter + + for path_data_file, path_gt_file in zip(sorted(files_in_data_root), sorted(path_gt_root.iterdir())): + counter += 1 + + if selection: + if isinstance(selection, int): + if counter > selection: + break + + elif isinstance(selection, ListConfig) or isinstance(selection, list): + if path_data_file.name not in selection: + continue + + assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_extension(path_gt_file.name, IMG_EXTENSIONS), \ + 'get_gt_data_paths(): image file aligned with non-image file' + + assert path_data_file.stem == path_gt_file.stem, \ + 'get_gt_data_paths(): mismatch between data filename and gt filename' + # TODO check if we need x/y + paths.append((path_data_file, path_gt_file)) + + return paths diff --git a/tests/datamodules/RGB/test_full_page_dataset.py b/tests/datamodules/RGB/test_full_page_dataset.py new file mode 100644 index 00000000..2b779895 --- /dev/null +++ b/tests/datamodules/RGB/test_full_page_dataset.py @@ -0,0 +1,32 @@ +import pytest + +from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB +from tests.test_data.dummy_data_hisdb.dummy_data import data_dir + + +@pytest.fixture +def dataset_train(data_dir): + return DatasetRGB(path=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') + + +def test_get_gt_data_paths(data_dir): + file_list = DatasetRGB.get_gt_data_paths(directory=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') + assert len(file_list) == 1 + assert file_list[0] == (data_dir / 'train' / 'data' / 'e-codices_fmb-cb-0055_0098v_max.jpg', + data_dir / 'train' / 'gt' / 'e-codices_fmb-cb-0055_0098v_max.png') + + +def test_dataset_rgb(dataset_train): + data_tensor, gt_tensor = dataset_train[0] + assert data_tensor.shape == gt_tensor.shape + assert data_tensor.ndim == 3 + assert gt_tensor.ndim == 3 + + +def test__load_data_and_gt(dataset_train): + data_img, gt_img = dataset_train._load_data_and_gt(index=0) + assert data_img.size == gt_img.size + assert data_img.format == 'JPEG' + assert data_img.mode == 'RGB' + assert gt_img.format == 'PNG' + assert gt_img.mode == 'RGB' From 9b34f39d710a7c0b6ee92b21f095882ef3b4e43d Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 2 Nov 2021 17:01:37 +0100 Subject: [PATCH 049/108] :construction: added assertions to check the shape of the data and the gt tensor --- src/datamodules/RGB/datasets/full_page_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 0d55c522..0a6a07ea 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -74,12 +74,9 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name, selection=self.selection) - # TODO: make more fanzy stuff here - # self.img_paths = [pair for page in 
self.img_paths_per_page for pair in page] - self.num_samples = len(self.img_paths_per_page) if self.num_samples == 0: - raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( + raise RuntimeError("Found 0 images in: {} \n Supported image extensions are: {}".format( path, ",".join(IMG_EXTENSIONS))) def __len__(self): @@ -99,11 +96,13 @@ def __getitem__(self, index): def _get_train_val_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) + assert img.shape == gt.shape return img, gt def _get_test_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) + assert img.shape == gt.shape return img, gt, index def _load_data_and_gt(self, index): From 3bfdec11e65b86e1f746aee3c0fda2610bb89f09 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 2 Nov 2021 17:07:56 +0100 Subject: [PATCH 050/108] :construction: just compare the W and H of the image/tensor --- src/datamodules/RGB/datasets/full_page_dataset.py | 4 ++-- tests/datamodules/RGB/__init__.py | 0 tests/datamodules/RGB/test_full_page_dataset.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 tests/datamodules/RGB/__init__.py diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 0a6a07ea..410d0e76 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -96,13 +96,13 @@ def __getitem__(self, index): def _get_train_val_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) - assert img.shape == gt.shape + assert img.shape[-2:] == gt.shape[-2:] return img, gt def _get_test_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) - assert img.shape == gt.shape + assert img.shape[-2:] == gt.shape[-2:] return img, gt, index def _load_data_and_gt(self, index): diff --git a/tests/datamodules/RGB/__init__.py b/tests/datamodules/RGB/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/RGB/test_full_page_dataset.py b/tests/datamodules/RGB/test_full_page_dataset.py index 2b779895..d2ffe681 100644 --- a/tests/datamodules/RGB/test_full_page_dataset.py +++ b/tests/datamodules/RGB/test_full_page_dataset.py @@ -18,7 +18,7 @@ def test_get_gt_data_paths(data_dir): def test_dataset_rgb(dataset_train): data_tensor, gt_tensor = dataset_train[0] - assert data_tensor.shape == gt_tensor.shape + assert data_tensor.shape[-2:] == gt_tensor.shape[-2:] assert data_tensor.ndim == 3 assert gt_tensor.ndim == 3 From 700cff6f98785df3276b1d01288322dd7bfb1e0b Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 15 Nov 2021 10:40:14 +0100 Subject: [PATCH 051/108] :recycle: :wrench: renamed the input size to in_channels --- .../dev_rotnet_resnet18_cb55_10.yaml | 11 ++++------ configs/model/header/single_layer.yaml | 2 +- src/models/backbones/resnet.py | 20 +++++++------------ src/models/headers/fully_connected.py | 4 ++-- 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml index e2150025..36c5c405 100644 --- a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml @@ -9,7 +9,7 
@@ defaults: - /loss: crossentropyloss.yaml - /metric: accuracy.yaml - /model/backbone: resnet18.yaml - - /model/header: null + - /model/header: single_layer.yaml - /optimizer: adam.yaml - /callbacks: - check_compatibility.yaml @@ -55,21 +55,18 @@ datamodule: model: header: - _target_: src.models.headers.fully_connected.SingleLinear - - num_classes: ${datamodule:num_classes} # needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!) - input_size: 2048 + in_channels: 32768 callbacks: model_checkpoint: - filename: ${checkpoint_folder_name}dev-rotnet-basic-cnn-cb55-10 + filename: ${checkpoint_folder_name}dev-rotnet-resnet18-cb55-10 watch_model: log_freq: 1 logger: wandb: - name: 'dev-rotnet-basic-cnn-cb55-10' + name: 'dev-rotnet-resnet18-cb55-10' tags: [ "best_model", "USL" ] group: 'dev-runs' notes: "Testing" diff --git a/configs/model/header/single_layer.yaml b/configs/model/header/single_layer.yaml index 05b64b26..9ec82edf 100644 --- a/configs/model/header/single_layer.yaml +++ b/configs/model/header/single_layer.yaml @@ -2,4 +2,4 @@ _target_: src.models.headers.fully_connected.SingleLinear num_classes: ${datamodule:num_classes} # needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!) -input_size: 109512 \ No newline at end of file +in_channels: 109512 \ No newline at end of file diff --git a/src/models/backbones/resnet.py b/src/models/backbones/resnet.py index fb90522e..205bee89 100644 --- a/src/models/backbones/resnet.py +++ b/src/models/backbones/resnet.py @@ -1,11 +1,9 @@ """ Model definition adapted from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py """ -import logging import math import torch.nn as nn -import torch.utils.model_zoo as model_zoo model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', @@ -18,7 +16,7 @@ def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + return nn.Conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=stride, padding=1, bias=False) @@ -59,12 +57,12 @@ class _Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None): super(_Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=(1, 1), bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample @@ -99,7 +97,7 @@ def __init__(self, block, layers, **kwargs): self.inplanes = 64 super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) @@ -108,7 +106,6 @@ def __init__(self, block, layers, **kwargs): self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) - self.avgpool = 
nn.AvgPool2d(7, stride=1)
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -123,7 +120,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
+                          kernel_size=(1, 1), stride=stride, bias=False),
                 nn.BatchNorm2d(planes * block.expansion),
             )
 
@@ -146,9 +143,6 @@ def forward(self, x):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-
         return x
 
 
@@ -158,7 +152,7 @@ def __init__(self, **kwargs):
 
 
 class ResNet34(ResNet):
-    def __init__(self,**kwargs):
+    def __init__(self, **kwargs):
         super(ResNet34, self).__init__(_BasicBlock, [3, 4, 6, 3], **kwargs)
 
diff --git a/src/models/headers/fully_connected.py b/src/models/headers/fully_connected.py
index 29d3482d..20a82de9 100644
--- a/src/models/headers/fully_connected.py
+++ b/src/models/headers/fully_connected.py
@@ -3,12 +3,12 @@
 
 
 class SingleLinear(nn.Module):
-    def __init__(self, num_classes: int = 4, input_size: int = 109512):
+    def __init__(self, num_classes: int = 4, in_channels: int = 109512):
         super(SingleLinear, self).__init__()
 
         self.fc = nn.Sequential(
             torch.nn.Flatten(),
-            nn.Linear(input_size, num_classes)
+            nn.Linear(in_channels, num_classes)
         )
 
     def forward(self, x):

From ae382008cf842ddecb35ab2de330fa4929171b48 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Mon, 15 Nov 2021 10:49:32 +0100
Subject: [PATCH 052/108] :books: added short how-to on loading models to the readme

---
 README.md | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/README.md b/README.md
index 991de23c..3771ecb7 100644
--- a/README.md
+++ b/README.md
@@ -94,3 +94,20 @@ python run.py trainer.max_epochs=20 datamodule.batch_size=64
 11. Go to the root folder of the framework and activate the environment (source .autoenv OR conda activate unsupervised_learning)
 12. Log into wandb. Execute `wandb login` and follow the instructions
 13. Now you should be able to run the basic experiment from PyCharm
+
+
+### Loading models
+You can load the individual model parts (`backbone` or `header`) as well as the whole task.
+To load the `backbone` or the `header`, add the field `path_to_weights` to your experiment config,
+e.g.
+```
+model:
+    header:
+        path_to_weights: /my/path/to/the/pth/file
+```
+To load the whole task, pass the path of a task checkpoint to the trainer via the field `resume_from_checkpoint`,
+e.g.
+```
+trainer:
+    resume_from_checkpoint: /path/to/.ckpt/file
+```
\ No newline at end of file
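Loading one of the exported weight files by hand boils down to a plain `state_dict` load. A minimal sketch (not part of the patch), assuming the `ResNet18` backbone class from this series and reusing the placeholder path from the README snippet above:

```python
import torch

from src.models.backbones.resnet import ResNet18

# Hypothetical path; inside the framework it is supplied via `path_to_weights`.
backbone = ResNet18()
state_dict = torch.load('/my/path/to/the/pth/file', map_location='cpu')
backbone.load_state_dict(state_dict)
backbone.eval()
```
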
From 7c791989b68d7bb420868067664b120072504a33 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Mon, 15 Nov 2021 12:00:18 +0100
Subject: [PATCH 053/108] :art: added dilation

---
 src/models/backbones/resnet.py | 41 ++++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/models/backbones/resnet.py b/src/models/backbones/resnet.py
index 205bee89..22fa9075 100644
--- a/src/models/backbones/resnet.py
+++ b/src/models/backbones/resnet.py
@@ -2,6 +2,7 @@
 Model definition adapted from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
 """
 import math
+from typing import Optional, List, Union, Type
 
 import torch.nn as nn
 
@@ -23,8 +24,10 @@ def conv3x3(in_planes, out_planes, stride=1):
 class _BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation: int = 1):
         super(_BasicBlock, self).__init__()
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not implemented in BasicBlock")
         self.conv1 = conv3x3(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm2d(planes)
         self.relu = nn.ReLU(inplace=True)
@@ -55,15 +58,15 @@ def forward(self, x):
 class _Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=False):
         super(_Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=stride,
-                               padding=1, bias=False)
+                               padding=1, bias=False, dilation=dilation)
         self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=(1, 1), bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=(1, 1), bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -93,9 +96,17 @@ def forward(self, x):
 
 
 class ResNet(nn.Module):
-    def __init__(self, block, layers, **kwargs):
-        self.inplanes = 64
+    def __init__(self, block: Type[Union[_BasicBlock, _Bottleneck]], layers: List[int],
+                 replace_stride_with_dilation: Optional[List[bool]] = None, **kwargs):
         super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.dilation = 1
+
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError(
+                f"replace_stride_with_dilation should be None or a 3-tuple, got {replace_stride_with_dilation}")
 
         self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=2, padding=3,
                                bias=False)
@@ -103,9 +114,9 @@ def __init__(self, block, layers, **kwargs):
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -115,8 +126,13 @@ def __init__(self, block, layers, **kwargs):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def _make_layer(self, block, planes, blocks, stride=1):
+    def _make_layer(self, block: Type[Union[_BasicBlock, _Bottleneck]], planes: int, blocks: int, stride: int = 1,
+                    dilate: bool = False):
         downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
@@ -125,7 +141,8 @@ def _make_layer(self, block, planes, blocks, stride=1):
             )
 
         layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
+        layers.append(block(inplanes=self.inplanes, planes=planes, stride=stride, downsample=downsample,
+                            dilation=previous_dilation))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
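A minimal usage sketch for the new `replace_stride_with_dilation` switch (not part of the patch; it assumes a `ResNet50` wrapper defined alongside the `ResNet18`/`ResNet34` classes above): replacing the stride of the last stage with dilation keeps that stage's spatial resolution, which is what dense-prediction heads want.

```python
import torch

from src.models.backbones.resnet import ResNet50

# Keep the strides of layer2/layer3, replace the stride of layer4 with
# dilation. The feature map then stays at 1/16 of the input instead of 1/32.
model = ResNet50(replace_stride_with_dilation=[False, False, True])
features = model(torch.zeros(1, 3, 256, 256))
print(features.shape)  # torch.Size([1, 2048, 16, 16]) instead of [1, 2048, 8, 8]
```
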
From 184bafc95685222f69e73df48619e012ac6ee885 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Mon, 15 Nov 2021 14:47:22 +0100
Subject: [PATCH 054/108] :recycle: renamed everything to resnet

---
 configs/experiment/dev_rotnet_resnet18_cb55_10.yaml        | 2 +-
 .../model/header/{single_layer.yaml => resnet_header.yaml} | 2 +-
 src/models/headers/fully_connected.py                      | 5 +++--
 3 files changed, 5 insertions(+), 4 deletions(-)
 rename configs/model/header/{single_layer.yaml => resnet_header.yaml} (73%)

diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml
index 36c5c405..de9718f6 100644
--- a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml
+++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml
@@ -9,7 +9,7 @@ defaults:
   - /loss: crossentropyloss.yaml
   - /metric: accuracy.yaml
   - /model/backbone: resnet18.yaml
-  - /model/header: single_layer.yaml
+  - /model/header: resnet_header.yaml
   - /optimizer: adam.yaml
   - /callbacks:
       - check_compatibility.yaml
diff --git a/configs/model/header/single_layer.yaml b/configs/model/header/resnet_header.yaml
similarity index 73%
rename from configs/model/header/single_layer.yaml
rename to configs/model/header/resnet_header.yaml
index 9ec82edf..d5a01b9f 100644
--- a/configs/model/header/single_layer.yaml
+++ b/configs/model/header/resnet_header.yaml
@@ -1,4 +1,4 @@
-_target_: src.models.headers.fully_connected.SingleLinear
+_target_: src.models.headers.fully_connected.ResNetHeader
 
 num_classes: ${datamodule:num_classes} # needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!)
 in_channels: 109512
\ No newline at end of file
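The comment above asks for `in_channels` to be calculated from the flattened backbone output. One way to get the number without computing it by hand (a sketch under the same assumptions as before; the values match the configs in this series, 32768 = 512 * 8 * 8 for ResNet18 on 256x256 crops and 131072 = 2048 * 8 * 8 for ResNet50):

```python
import torch

from src.models.backbones.resnet import ResNet18

# Run a dummy crop through the backbone and flatten, instead of computing
# the size by hand. 256 is the crop_size used in these experiments.
backbone = ResNet18()
with torch.no_grad():
    features = backbone(torch.zeros(1, 3, 256, 256))
in_channels = features.flatten(start_dim=1).shape[1]
print(in_channels)  # 32768 for ResNet18 (512 x 8 x 8)
```
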
diff --git a/src/models/headers/fully_connected.py b/src/models/headers/fully_connected.py index 20a82de9..c8922a18 100644 --- a/src/models/headers/fully_connected.py +++ b/src/models/headers/fully_connected.py @@ -2,11 +2,12 @@ from torch import nn -class SingleLinear(nn.Module): +class ResNetHeader(nn.Module): def __init__(self, num_classes: int = 4, in_channels: int = 109512): - super(SingleLinear, self).__init__() + super(ResNetHeader, self).__init__() self.fc = nn.Sequential( + nn.AdaptiveAvgPool2d(output_size=(None, None)), torch.nn.Flatten(), nn.Linear(in_channels, num_classes) ) From 184bafc95685222f69e73df48619e012ac6ee885 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 15 Nov 2021 14:59:50 +0100 Subject: [PATCH 055/108] :truck: renamed again --- configs/experiment/dev_rotnet_resnet18_cb55_10.yaml | 2 +- configs/experiment/rotnet_resnet18_cb55_full.yaml | 2 +- configs/experiment/rotnet_resnet18_cb55_train10_last.yaml | 2 +- configs/experiment/rotnet_resnet18_cb55_train19_last.yaml | 2 +- .../header/{resnet_header.yaml => resnet_classification.yaml} | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename configs/model/header/{resnet_header.yaml => resnet_classification.yaml} (100%) diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml index de9718f6..7cc285ca 100644 --- a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml @@ -9,7 +9,7 @@ defaults: - /loss: crossentropyloss.yaml - /metric: accuracy.yaml - /model/backbone: resnet18.yaml - - /model/header: resnet_header.yaml + - /model/header: resnet_classification.yaml - /optimizer: adam.yaml - /callbacks: - check_compatibility.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55_full.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml index 5282a45d..7990c297 100644 --- a/configs/experiment/rotnet_resnet18_cb55_full.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_full.yaml @@ -9,7 +9,7 @@ defaults: - /loss: crossentropyloss.yaml - /metric: accuracy.yaml - /model/backbone: resnet18.yaml - - /model/header: single_layer.yaml + - /model/header: resnet_classification.yaml - /optimizer: adam.yaml - /callbacks: - check_compatibility.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml index 91b0e49b..d997230b 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml @@ -9,7 +9,7 @@ defaults: - /loss: crossentropyloss.yaml - /metric: accuracy.yaml - /model/backbone: resnet18.yaml - - /model/header: single_layer.yaml + - /model/header: resnet_classification.yaml - /optimizer: adam.yaml - /callbacks: - check_compatibility.yaml diff --git a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml index cdc1c7dc..86c0d443 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml @@ -9,7 +9,7 @@ defaults: - /loss: crossentropyloss.yaml - /metric: accuracy.yaml - /model/backbone: resnet18.yaml - - /model/header: single_layer.yaml + - /model/header: resnet_classification.yaml - /optimizer: adam.yaml - /callbacks: - check_compatibility.yaml diff --git a/configs/model/header/resnet_header.yaml b/configs/model/header/resnet_classification.yaml similarity index 100% rename from 
configs/model/header/resnet_header.yaml rename to configs/model/header/resnet_classification.yaml From 23d6d4a31d6e65d935b5cd92b05d592be1e036c5 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 15 Nov 2021 15:39:48 +0100 Subject: [PATCH 056/108] :bug: little fix --- src/models/backbones/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/backbones/resnet.py b/src/models/backbones/resnet.py index 22fa9075..4f46ce97 100644 --- a/src/models/backbones/resnet.py +++ b/src/models/backbones/resnet.py @@ -58,12 +58,12 @@ def forward(self, x): class _Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=False): + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation: int = 1): super(_Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=stride, - padding=1, bias=False, dilation=dilation) + padding=(1, 1), bias=False, dilation=dilation) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=(1, 1), bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) From e895dd1eb94eb5f4eecc3c50ea707f8d0aae36a0 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 15 Nov 2021 17:17:10 +0100 Subject: [PATCH 057/108] :sparkles: :wrench: added resnet segmentation header and config --- configs/model/header/resnet_segmentation.yaml | 8 +++++ src/models/headers/fully_convolution.py | 32 +++++++++++++++++++ src/utils/utils.py | 2 ++ 3 files changed, 42 insertions(+) create mode 100644 configs/model/header/resnet_segmentation.yaml create mode 100644 src/models/headers/fully_convolution.py diff --git a/configs/model/header/resnet_segmentation.yaml b/configs/model/header/resnet_segmentation.yaml new file mode 100644 index 00000000..7a643403 --- /dev/null +++ b/configs/model/header/resnet_segmentation.yaml @@ -0,0 +1,8 @@ +_target_: src.models.headers.fully_convolution.ResNetFCNHead + +#FCN header for resnets. The in_channels are fixed for the different resnet architectures: +#resnet18, 34 = 512 +#resnet50, 101, 152 = 2048 +in_channels: 512 +num_classes: ${datamodule:num_classes} +output_dims: ${datamodule:dims} \ No newline at end of file diff --git a/src/models/headers/fully_convolution.py b/src/models/headers/fully_convolution.py new file mode 100644 index 00000000..583004a8 --- /dev/null +++ b/src/models/headers/fully_convolution.py @@ -0,0 +1,32 @@ +from typing import Tuple + +from torch import nn + + +class ResNetFCNHead(nn.Sequential): + """ + FCN header for resnets. 
The in_channels are fixed for the different resnet architectures:
+    resnet18, 34 = 512
+    resnet50, 101, 152 = 2048
+    """
+
+    def __init__(self, in_channels, num_classes, output_dims: Tuple[int] = (256, 256)):
+        self.output_dims = output_dims
+        if len(self.output_dims) > 2:
+            self.output_dims = output_dims[-2:]
+        inter_channels = in_channels // 4
+        layers = [
+            nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=(3, 3), padding=(1, 1),
+                      bias=False),
+            nn.BatchNorm2d(num_features=inter_channels),
+            nn.ReLU(),
+            nn.Dropout(p=0.1, inplace=False),
+            nn.Conv2d(in_channels=inter_channels, out_channels=num_classes, kernel_size=(1, 1), stride=(1, 1)),
+        ]
+
+        super(ResNetFCNHead, self).__init__(*layers)
+
+    def forward(self, input):
+        x = super(ResNetFCNHead, self).forward(input)
+        x = nn.functional.interpolate(x, size=self.output_dims, mode="bilinear", align_corners=False)
+        return x
diff --git a/src/utils/utils.py b/src/utils/utils.py
index bff0a453..8839d81d 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -208,3 +208,5 @@ def finish(
     for lg in logger:
         if isinstance(lg, WandbLogger):
             wandb.finish()
+    if isinstance(trainer.logger, WandbLogger):
+        trainer.logger.finish()
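A quick shape walk-through for the new FCN header (again a sketch, not part of the patch; it assumes the `ResNet18` backbone from this series, which now returns the un-pooled feature map, and `num_classes=8` is a placeholder, in the configs it comes from `${datamodule:num_classes}`):

```python
import torch

from src.models.backbones.resnet import ResNet18
from src.models.headers.fully_convolution import ResNetFCNHead

backbone = ResNet18()  # last feature map has 512 channels
head = ResNetFCNHead(in_channels=512, num_classes=8, output_dims=(256, 256))

x = torch.zeros(2, 3, 256, 256)
features = backbone(x)   # (2, 512, 8, 8)
logits = head(features)  # (2, 8, 256, 256), upsampled bilinearly to output_dims
```
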
From 53f9c2970e5842a215dfe9d98c9d113fa60b43f2 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Mon, 15 Nov 2021 17:20:40 +0100
Subject: [PATCH 058/108] :wrench: configs for resnet 50 and pretrained resnet 18

---
 ...tnet_pt_resnet18_cb55_10_segmentation.yaml | 74 +++++++++++++++++++
 .../dev_rotnet_resnet50_cb55_10.yaml          | 72 ++++++++++++++++++
 2 files changed, 146 insertions(+)
 create mode 100644 configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml
 create mode 100644 configs/experiment/dev_rotnet_resnet50_cb55_10.yaml

diff --git a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml
new file mode 100644
index 00000000..5d64686a
--- /dev/null
+++ b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml
@@ -0,0 +1,74 @@
+# @package _global_
+
+# to execute this experiment run:
+# python run.py +experiment=exp_example_full
+
+defaults:
+  - /plugins: default.yaml
+  - /task: semantic_segmentation_HisDB.yaml
+  - /loss: crossentropyloss.yaml
+  - /metric: hisdbiou.yaml
+  - /model/backbone: resnet18.yaml
+  - /model/header: resnet_segmentation.yaml
+  - /optimizer: adam.yaml
+  - /callbacks:
+      - check_compatibility.yaml
+      - model_checkpoint.yaml
+      - watch_model_wandb.yaml
+  - /logger:
+      - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 42
+
+train: True
+test: False
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 3
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 1
+  confusion_matrix_val: True
+  confusion_matrix_test: True
+
+datamodule:
+  _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped
+
+  data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation
+  crop_size: 256
+  num_workers: 4
+  batch_size: 16
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+  gt_folder_name: gt
+
+model:
+  backbone:
+    path_to_weights: /netscratch/experiments_lars_paul/lars/2021-11-15/16-08-51/checkpoints/epoch=1/backbone.pth
+  header:
+    in_channels: 512
+
+callbacks:
+  model_checkpoint:
+    filename: ${checkpoint_folder_name}dev-rotnet-pt-resnet18-cb55-10-segmentation
+  watch_model:
+    log_freq: 1
+
+logger:
+  wandb:
+    name: 'dev-rotnet-pt-resnet18-cb55-10-segmentation'
+    tags: [ "best_model", "USL" ]
+    group: 'dev-runs'
+    notes: "Testing"
diff --git a/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml
new file mode 100644
index 00000000..3c65fe6c
--- /dev/null
+++ b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml
@@ -0,0 +1,72 @@
+# @package _global_
+
+# to execute this experiment run:
+# python run.py +experiment=exp_example_full
+
+defaults:
+  - /plugins: default.yaml
+  - /task: classification.yaml
+  - /loss: crossentropyloss.yaml
+  - /metric: accuracy.yaml
+  - /model/backbone: resnet50.yaml
+  - /model/header: resnet_classification.yaml
+  - /optimizer: adam.yaml
+  - /callbacks:
+      - check_compatibility.yaml
+      - model_checkpoint.yaml
+      - watch_model_wandb.yaml
+  - /logger:
+      - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 42
+
+train: True
+test: False
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 3
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 1
+  confusion_matrix_val: False
+  confusion_matrix_test: False
+
+datamodule:
+  _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped
+
+  data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation
+  crop_size: 256
+  num_workers: 4
+  batch_size: 16
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+
+model:
+  header:
+    # needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!)
+ in_channels: 131072 + +callbacks: + model_checkpoint: + filename: ${checkpoint_folder_name}dev-rotnet-resnet50-cb55-10 + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'dev-rotnet-resnet50-cb55-10' + tags: [ "best_model", "USL" ] + group: 'dev-runs' + notes: "Testing" From 7ea2fef585b0985bfdd62a221b8844e67328c239 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 17 Nov 2021 17:47:19 +0100 Subject: [PATCH 059/108] :recycle: make everything a tuple --- src/models/backbones/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/backbones/resnet.py b/src/models/backbones/resnet.py index 4f46ce97..a8eda53e 100644 --- a/src/models/backbones/resnet.py +++ b/src/models/backbones/resnet.py @@ -108,11 +108,11 @@ def __init__(self, block: Type[Union[_BasicBlock, _Bottleneck]], layers: List[in raise ValueError( f"replace_stride_with_dilation should be None or a 3-tuple, got {replace_stride_with_dilation}") - self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=2, padding=3, + self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) @@ -136,7 +136,7 @@ def _make_layer(self, block: Type[Union[_BasicBlock, _Bottleneck]], planes: int, if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=(1, 1), stride=stride, bias=False), + kernel_size=(1, 1), stride=(stride, stride), bias=False), nn.BatchNorm2d(planes * block.expansion), ) From 4734eb154445b66f75652b434dc922ed5777372f Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 17 Nov 2021 17:54:17 +0100 Subject: [PATCH 060/108] :construction: Rolf's data format --- configs/datamodule/rolf_format_dev.yaml | 170 ++++++++++++++++++ .../synthetic_baby_unet_rolf_format.yaml | 56 ++++++ .../{datamodule_cropped.py => datamodule.py} | 62 +++---- .../{cropped_dataset.py => dataset.py} | 2 +- .../RolfFormat/utils/image_analytics.py | 95 ++-------- src/datamodules/RolfFormat/utils/misc.py | 26 --- src/tasks/RGB/semantic_segmentation.py | 2 +- 7 files changed, 276 insertions(+), 137 deletions(-) create mode 100644 configs/datamodule/rolf_format_dev.yaml create mode 100644 configs/experiment/synthetic_baby_unet_rolf_format.yaml rename src/datamodules/RolfFormat/{datamodule_cropped.py => datamodule.py} (71%) rename src/datamodules/RolfFormat/datasets/{cropped_dataset.py => dataset.py} (99%) diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml new file mode 100644 index 00000000..f6623ad7 --- /dev/null +++ b/configs/datamodule/rolf_format_dev.yaml @@ -0,0 +1,170 @@ +_target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat + +num_workers: 4 +batch_size: 16 +shuffle: True +drop_last: True + +data_root: /netscratch/datasets/semantic_segmentation/rolf_format + +train_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1000 + range_to: 1029 + 
append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1600 + range_to: 1609 + +val_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1000 + range_to: 1029 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1600 + range_to: 1609 + +test_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1800 + range_to: 1809 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1800 + range_to: 1809 + +image_analytics: + mean: + R: 0.8705590692084239 + G: 0.7458294488380534 + B: 0.6311273887799564 + std: + R: 0.18879972204374673 + G: 0.18806443944992593 + B: 0.17460372273808386 + +image_dims: + width: 640 + height: 896 + +classes: + background: + color: + R: 0 + G: 0 + B: 0 + weight: 0.0007976091978914681 + + class1: + color: + R: 0 + G: 102 + B: 0 + weight: 0.2913455129929105 + + class2: + color: + R: 0 + G: 102 + B: 102 + weight: 0.005224675677940002 + + class3: + color: + R: 0 + G: 153 + B: 153 + weight: 0.018517698236153693 + + class4: + color: + R: 0 + G: 255 + B: 0 + weight: 0.14592118584029914 + + class5: + color: + R: 0 + G: 255 + B: 255 + weight: 0.01313123697869497 + + class6: + color: + R: 102 + G: 0 + B: 0 + weight: 0.09552945122836413 + + class7: + color: + R: 102 + G: 0 + B: 102 + weight: 0.04887260934092041 + + class8: + color: + R: 102 + G: 102 + B: 0 + weight: 0.00172073445612366 + + class9: + color: + R: 153 + G: 0 + B: 153 + weight: 0.19027731073074847 + + class10: + color: + R: 153 + G: 153 + B: 0 + weight: 0.006524524841233285 + + class11: + color: + R: 255 + G: 0 + B: 0 + weight: 0.06603666735259975 + + class12: + color: + R: 255 + G: 0 + B: 255 + weight: 0.11209021851644496 + + class13: + color: + R: 255 + G: 255 + B: 0 + weight: 0.004010564609675552 + diff --git a/configs/experiment/synthetic_baby_unet_rolf_format.yaml b/configs/experiment/synthetic_baby_unet_rolf_format.yaml new file mode 100644 index 00000000..7391f8fe --- /dev/null +++ b/configs/experiment/synthetic_baby_unet_rolf_format.yaml @@ -0,0 +1,56 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_RGB.yaml + - /datamodule: rolf_format_dev.yaml + - /loss: crossentropyloss.yaml + - /metric: iou.yaml + - /model/backbone: baby_unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +train: True +test: True + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 2000 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 100 + confusion_matrix_val: True + confusion_matrix_test: True + +callbacks: + model_checkpoint: + monitor: "val/iou" + mode: "max" + filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data + watch_model: + log_freq: 100 + +logger: + wandb: + name: 'synthetic-baby-unet-rolf-format' + tags: [ "best_model", "synthetic", "RolfFormat" ] + group: 'synthetic' diff --git a/src/datamodules/RolfFormat/datamodule_cropped.py b/src/datamodules/RolfFormat/datamodule.py similarity index 71% rename from src/datamodules/RolfFormat/datamodule_cropped.py rename to src/datamodules/RolfFormat/datamodule.py index e85a14a5..8d1e2658 100644 --- a/src/datamodules/RolfFormat/datamodule_cropped.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -2,37 +2,44 @@ from typing import Union, List, Optional import torch +from dataclasses import dataclass from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB -from src.datamodules.RGB.utils.image_analytics import get_analytics -from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ - IntegerEncoding -from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat +from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt +from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding +from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule from src.utils import utils log = utils.get_logger(__name__) - -class DataModuleCroppedRGB(AbstractDatamodule): - def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, - selection_train: Optional[Union[int, List[str]]] = None, - selection_val: Optional[Union[int, List[str]]] = None, - selection_test: Optional[Union[int, List[str]]] = None, - crop_size: int = 256, num_workers: int = 4, batch_size: int = 8, +@dataclass +class DatasetSpecs: + data_root: str + doc_dir: str + doc_names: str + gt_dir: str + gt_names: str + range_from: int + range_to: int + + +class DataModuleRolfFormat(AbstractDatamodule): + def __init__(self, data_root: str, + train_specs=None, val_specs=None, test_specs=None, + image_analytics=None, classes=None, image_dims=None, + num_workers: int = 4, batch_size: int = 8, shuffle: bool = True, drop_last: bool = True): super().__init__() - self.data_folder_name = data_folder_name - self.gt_folder_name = gt_folder_name + train_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in train_specs.items()] + val_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in val_specs.items()] + test_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in test_specs.items()] - analytics_data, analytics_gt = 
get_analytics(input_path=Path(data_dir), - data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name, - get_gt_data_paths_func=CroppedDatasetRGB.get_gt_data_paths) + analytics_data = get_analytics_data() + analytics_gt = get_analytics_gt() self.mean = analytics_data['mean'] self.std = analytics_data['std'] @@ -41,7 +48,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.num_classes = len(self.class_encodings) self.class_weights = analytics_gt['class_weights'] - self.twin_transform = TwinRandomCrop(crop_size=crop_size) + self.twin_transform = None self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std)])) self.target_transform = OnlyTarget(IntegerEncoding(class_encodings=self.class_encodings_tensor)) @@ -52,20 +59,13 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.shuffle = shuffle self.drop_last = drop_last - self.data_dir = validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name) - - self.selection_train = selection_train - self.selection_val = selection_val - self.selection_test = selection_test - - self.dims = (3, crop_size, crop_size) + self.dims = (3, image_dims['width'], image_dims['height']) def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: - self.train = CroppedDatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = CroppedDatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + self.train = DatasetRolfFormat(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.val = DatasetRolfFormat(**self._create_dataset_parameters('val'), selection=self.selection_val) self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) @@ -73,7 +73,7 @@ def setup(self, stage: Optional[str] = None): drop_last=self.drop_last) if stage == 'test' or stage is not None: - self.test = CroppedDatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + self.test = DatasetRolfFormat(**self._create_dataset_parameters('test'), selection=self.selection_test) # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) diff --git a/src/datamodules/RolfFormat/datasets/cropped_dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py similarity index 99% rename from src/datamodules/RolfFormat/datasets/cropped_dataset.py rename to src/datamodules/RolfFormat/datasets/dataset.py index b31a2aae..6ed3776f 100644 --- a/src/datamodules/RolfFormat/datasets/cropped_dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -20,7 +20,7 @@ log = utils.get_logger(__name__) -class CroppedDatasetRGB(data.Dataset): +class DatasetRolfFormat(data.Dataset): """A generic data loader where the images are arranged in this way: :: root/gt/xxx.png diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index 839636a0..58d10ec8 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -17,81 +17,24 @@ from src.datamodules.RGB.utils.misc import pil_loader -def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): - """ - Parameters - ---------- - input_path: Path to 
dataset +def get_analytics_data(file_names_data, **kwargs): + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) + analytics_data = {'mean': mean.tolist(), + 'std': std.tolist()} - Returns - ------- - """ - expected_keys_data = ['mean', 'std'] - expected_keys_gt = ['class_weights', 'class_encodings'] - - analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json' - analytics_path_gt = input_path / f'analytics.gt.{gt_folder_name}.json' - - analytics_data = None - analytics_gt = None - - missing_analytics_data = True - missing_analytics_gt = True - - if analytics_path_data.exists(): - with analytics_path_data.open(mode='r') as f: - analytics_data = json.load(fp=f) - # check if analytics file is complete - if all(k in analytics_data for k in expected_keys_data): - missing_analytics_data = False - - if analytics_path_gt.exists(): - with analytics_path_gt.open(mode='r') as f: - analytics_gt = json.load(fp=f) - # check if analytics file is complete - if all(k in analytics_gt for k in expected_keys_gt): - missing_analytics_gt = False - - if missing_analytics_data or missing_analytics_gt: - train_path = input_path / 'train' - gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, - gt_folder_name=gt_folder_name) - file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) - file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) - - if missing_analytics_data: - mean, std = compute_mean_std(file_names=file_names_data, **kwargs) - analytics_data = {'mean': mean.tolist(), - 'std': std.tolist()} - # save json - try: - with analytics_path_data.open(mode='w') as f: - json.dump(obj=analytics_data, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_path_data})') - else: - raise - - if missing_analytics_gt: - # Measure weights for class balancing - logging.info(f'Measuring class weights') - # create a list with all gt file paths - class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, - **kwargs) - analytics_gt = {'class_weights': class_weights, - 'class_encodings': class_encodings} - # save json - try: - with analytics_path_gt.open(mode='w') as f: - json.dump(obj=analytics_gt, fp=f) - except IOError as e: - if e.errno == errno.EACCES: - print(f'WARNING: No permissions to write analytics file ({analytics_path_gt})') - else: - raise - - return analytics_data, analytics_gt + return analytics_data + + +def get_analytics_gt(file_names_gt, **kwargs): + # Measure weights for class balancing + logging.info(f'Measuring class weights') + # create a list with all gt file paths + class_weights, class_encodings = _get_class_frequencies_weights_segmentation(gt_images=file_names_gt, + **kwargs) + analytics_gt = {'class_weights': class_weights, + 'class_encodings': class_encodings} + + return analytics_gt def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): @@ -368,7 +311,3 @@ def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) class_weights = (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()) return class_weights.tolist(), classes - - -if __name__ == '__main__': - print(get_analytics(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RolfFormat/utils/misc.py b/src/datamodules/RolfFormat/utils/misc.py index c0de22d4..1ccdad41 100644 
--- a/src/datamodules/RolfFormat/utils/misc.py +++ b/src/datamodules/RolfFormat/utils/misc.py @@ -47,29 +47,3 @@ def convert_to_rgb(pic): raise TypeError(f"unsupported image type {pic.mode}") return pic - -def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): - if data_dir is None: - raise PathNone("Please provide the path to root dir of the dataset " - "(folder containing the train/val/test folder)") - else: - split_names = ['train', 'val', 'test'] - type_names = [data_folder_name, gt_folder_name] - - data_folder = Path(data_dir) - if not data_folder.is_dir(): - raise PathNotDir("Please provide the path to root dir of the dataset " - "(folder containing the train/val/test folder)") - split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] - if len(split_folders) != 3: - raise PathMissingSplitDir(f'Your path needs to contain train/val/test and ' - f'each of them a folder {data_folder_name} and {gt_folder_name}') - - # check if we have train/test/val - for split in split_folders: - type_folders = [d for d in split.iterdir() if d.is_dir() and d.name in type_names] - # check if we have data/gt - if len(type_folders) != 2: - raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' - f'and {gt_folder_name} folder') - return Path(data_dir) diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py index 3e978751..01106966 100644 --- a/src/tasks/RGB/semantic_segmentation.py +++ b/src/tasks/RGB/semantic_segmentation.py @@ -7,7 +7,7 @@ import torchmetrics from src.tasks.base_task import AbstractTask -from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax +from src.datamodules.RGB.utils.output_tools import _get_argmax from src.utils import utils from src.tasks.utils.outputs import OutputKeys, reduce_dict From 0a0c67266f3d3d8cdc9522c3d502e0e521080033 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 17 Nov 2021 18:05:06 +0100 Subject: [PATCH 061/108] :pushpin: fixed dependencies to specific versions --- requirements.txt | 12 ++++++------ tests/requirements.txt | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3985bec2..881808fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ # --------- pytorch --------- # torch==1.8.1 -torchvision>=0.9.1 -pytorch-lightning>=1.4.4 -lightning-bolts>=0.3.2 -torchmetrics>=0.5.0 +torchvision==0.9.1 +pytorch-lightning==1.4.4 +lightning-bolts==0.3.2 +torchmetrics==0.5.0 # --------- hydra --------- # hydra-core==1.1.0 @@ -11,13 +11,13 @@ hydra-colorlog==1.1.0 hydra-optuna-sweeper==1.1.0 # --------- loggers --------- # -wandb>=0.10.31 +wandb==0.12.6 # --------- others --------- # rich python-dotenv pre-commit -scikit-learn>=0.23.2 +scikit-learn==0.23.2 pandas matplotlib seaborn diff --git a/tests/requirements.txt b/tests/requirements.txt index 09f038f5..87f2d71a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ coverage -codecov>=2.1 -pytest>=3.0.5 +codecov==2.1 +pytest==3.0.5 pytest-cov pytest-flake8 flake8 From eb01d13795c9787ad7853a9d080791103c247d2b Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 17 Nov 2021 18:14:22 +0100 Subject: [PATCH 062/108] :pushpin: fixed test dependencies --- requirements.txt | 8 ++++---- tests/requirements.txt | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 881808fa..725fa2ad 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ # --------- pytorch --------- # torch==1.8.1 torchvision==0.9.1 -pytorch-lightning==1.4.4 -lightning-bolts==0.3.2 -torchmetrics==0.5.0 +pytorch-lightning==1.4.8 +lightning-bolts==0.4.0 +torchmetrics==0.5.1 # --------- hydra --------- # hydra-core==1.1.0 @@ -17,7 +17,7 @@ wandb==0.12.6 rich python-dotenv pre-commit -scikit-learn==0.23.2 +scikit-learn==0.24.1 pandas matplotlib seaborn diff --git a/tests/requirements.txt b/tests/requirements.txt index 87f2d71a..c6ec995c 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,8 +1,8 @@ -coverage -codecov==2.1 -pytest==3.0.5 -pytest-cov -pytest-flake8 -flake8 -check-manifest +coverage==6.0.1 +codecov==2.1.12 +pytest==6.2.3 +pytest-cov==3.0.0 +pytest-flake8==1.0.7 +flake8==3.9.2 +check-manifest==0.47 twine==1.13.0 \ No newline at end of file From 15eec2238f6d718f38cbb0a7c83594b08ea7f225 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 17 Nov 2021 18:17:46 +0100 Subject: [PATCH 063/108] :pushpin: and also the last part --- requirements.txt | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 725fa2ad..e173e5f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,11 +14,10 @@ hydra-optuna-sweeper==1.1.0 wandb==0.12.6 # --------- others --------- # -rich -python-dotenv -pre-commit +rich==10.1.0 +python-dotenv==0.17.0 scikit-learn==0.24.1 -pandas -matplotlib -seaborn -pytest +pandas==1.2.4 +matplotlib==3.4.1 +seaborn==0.11.1 +pytest==6.2.3 From 60929454acee97f06cf9a0ade9e3aec853a759b0 Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 18 Nov 2021 12:09:26 +0100 Subject: [PATCH 064/108] :construction: Rolf's data format --- configs/datamodule/rolf_format_dev.yaml | 170 ++++++++--------- src/datamodules/RolfFormat/datamodule.py | 141 ++++++++++---- .../RolfFormat/datasets/dataset.py | 178 +++++++----------- .../RolfFormat/utils/image_analytics.py | 17 +- 4 files changed, 261 insertions(+), 245 deletions(-) diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml index f6623ad7..9c7551a0 100644 --- a/configs/datamodule/rolf_format_dev.yaml +++ b/configs/datamodule/rolf_format_dev.yaml @@ -14,14 +14,14 @@ train_specs: gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" gt_names: "A1-MD-truthD-####.gif" range_from: 1000 - range_to: 1029 + range_to: 1059 append2: doc_dir: "SetA1_sizeM_Rolf/layoutR/data" doc_names: "A1-MR-page-####.jpg" gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" gt_names: "A1-MR-truthD-####.gif" - range_from: 1600 - range_to: 1609 + range_from: 1000 + range_to: 1059 val_specs: append1: @@ -29,15 +29,15 @@ val_specs: doc_names: "A1-MD-page-####.jpg" gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" gt_names: "A1-MD-truthD-####.gif" - range_from: 1000 - range_to: 1029 + range_from: 1060 + range_to: 1079 append2: doc_dir: "SetA1_sizeM_Rolf/layoutR/data" doc_names: "A1-MR-page-####.jpg" gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" gt_names: "A1-MR-truthD-####.gif" - range_from: 1600 - range_to: 1609 + range_from: 1060 + range_to: 1079 test_specs: append1: @@ -45,126 +45,114 @@ test_specs: doc_names: "A1-MD-page-####.jpg" gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" gt_names: "A1-MD-truthD-####.gif" - range_from: 1800 - range_to: 1809 + range_from: 1080 + range_to: 1099 append2: doc_dir: "SetA1_sizeM_Rolf/layoutR/data" doc_names: "A1-MR-page-####.jpg" gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" gt_names: "A1-MR-truthD-####.gif" - range_from: 1800 - range_to: 1809 - -image_analytics: 
- mean: - R: 0.8705590692084239 - G: 0.7458294488380534 - B: 0.6311273887799564 - std: - R: 0.18879972204374673 - G: 0.18806443944992593 - B: 0.17460372273808386 + range_from: 1080 + range_to: 1099 image_dims: width: 640 height: 896 +image_analytics: + mean: + R: 0.8664800196201524 + G: 0.7408864118075618 + B: 0.6299955083595935 + std: + R: 0.2156624188591712 + G: 0.20890185198454636 + B: 0.1870731300038113 + classes: - background: + class0: color: - R: 0 - G: 0 - B: 0 - weight: 0.0007976091978914681 - + R: 0 + G: 0 + B: 0 + weight: 0.0016602289391364547 class1: color: - R: 0 - G: 102 - B: 0 - weight: 0.2913455129929105 - + R: 0 + G: 102 + B: 0 + weight: 0.22360020547468618 class2: color: - R: 0 - G: 102 - B: 102 - weight: 0.005224675677940002 - + R: 0 + G: 102 + B: 102 + weight: 0.014794833923108578 class3: color: - R: 0 - G: 153 - B: 153 - weight: 0.018517698236153693 - + R: 0 + G: 153 + B: 153 + weight: 0.05384506923533185 class4: color: - R: 0 - G: 255 - B: 0 - weight: 0.14592118584029914 - + R: 0 + G: 255 + B: 0 + weight: 0.1115978481679602 class5: color: - R: 0 - G: 255 - B: 255 - weight: 0.01313123697869497 - + R: 0 + G: 255 + B: 255 + weight: 0.037436533973406926 class6: color: - R: 102 - G: 0 - B: 0 - weight: 0.09552945122836413 - + R: 102 + G: 0 + B: 0 + weight: 0.12569866772812885 class7: color: - R: 102 - G: 0 - B: 102 - weight: 0.04887260934092041 - + R: 102 + G: 0 + B: 102 + weight: 0.03591164457353043 class8: color: - R: 102 - G: 102 - B: 0 - weight: 0.00172073445612366 - + R: 102 + G: 102 + B: 0 + weight: 0.01062086078798502 class9: color: - R: 153 - G: 0 - B: 153 - weight: 0.19027731073074847 - + R: 153 + G: 0 + B: 153 + weight: 0.1491578366712268 class10: color: - R: 153 - G: 153 - B: 0 - weight: 0.006524524841233285 - + R: 153 + G: 153 + B: 0 + weight: 0.0414074692141804 class11: color: - R: 255 - G: 0 - B: 0 - weight: 0.06603666735259975 - + R: 255 + G: 0 + B: 0 + weight: 0.08600602291055298 class12: color: - R: 255 - G: 0 - B: 255 - weight: 0.11209021851644496 - + R: 255 + G: 0 + B: 255 + weight: 0.08349157426652898 class13: color: - R: 255 - G: 255 - B: 0 - weight: 0.004010564609675552 + R: 255 + G: 255 + B: 0 + weight: 0.024771204134236315 + diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 8d1e2658..86cbdc1d 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -6,8 +6,8 @@ from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat -from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt +from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs +from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt, get_image_dims from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule @@ -15,16 +15,6 @@ log = utils.get_logger(__name__) -@dataclass -class DatasetSpecs: - data_root: str - doc_dir: str - doc_names: str - gt_dir: str - gt_names: str - range_from: int - range_to: int - class DataModuleRolfFormat(AbstractDatamodule): def __init__(self, data_root: str, @@ -34,12 +24,41 @@ def __init__(self, data_root: str, shuffle: bool = True, drop_last: bool = True): super().__init__() - train_dataset_specs = 
[DatasetSpecs(data_root=data_root, **v) for k, v in train_specs.items()] - val_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in val_specs.items()] - test_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in test_specs.items()] + self.train_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in train_specs.items()] + self.val_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in val_specs.items()] + self.test_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in test_specs.items()] + + if image_analytics is None or classes is None or image_dims is None: + train_paths_data_gt = DatasetRolfFormat.get_gt_data_paths(list_specs=self.train_dataset_specs) + + if image_analytics is None: + analytics_data = get_analytics_data(data_gt_path_list=train_paths_data_gt) + self._print_analytics_data(analytics_data=analytics_data) + else: + analytics_data = {'mean': [image_analytics['mean']['R'], + image_analytics['mean']['G'], + image_analytics['mean']['B']], + 'std': [image_analytics['std']['R'], + image_analytics['std']['G'], + image_analytics['std']['B']]} + + if classes is None: + analytics_gt = get_analytics_gt(data_gt_path_list=train_paths_data_gt) + self._print_analytics_gt(analytics_gt=analytics_gt) + else: + analytics_gt = {'class_encodings': [], + 'class_weights': []} + for _, class_specs in classes.items(): + analytics_gt['class_encodings'].append([class_specs['color']['R'], + class_specs['color']['G'], + class_specs['color']['B']]) + analytics_gt['class_weights'].append(class_specs['weight']) + + if image_dims is None: + image_dims = get_image_dims(data_gt_path_list=train_paths_data_gt) + self._print_image_dims(image_dims=image_dims) - analytics_data = get_analytics_data() - analytics_gt = get_analytics_gt() + self.dims = (3, image_dims['width'], image_dims['height']) self.mean = analytics_data['mean'] self.std = analytics_data['std'] @@ -59,23 +78,74 @@ def __init__(self, data_root: str, self.shuffle = shuffle self.drop_last = drop_last - self.dims = (3, image_dims['width'], image_dims['height']) + def _print_analytics_data(self, analytics_data): + indent = 4 * ' ' + lines = [''] + lines.append(f'image_analytics:') + lines.append(f'{indent}mean:') + lines.append(f'{indent}{indent}R: {analytics_data["mean"][0]}') + lines.append(f'{indent}{indent}G: {analytics_data["mean"][1]}') + lines.append(f'{indent}{indent}B: {analytics_data["mean"][2]}') + lines.append(f'{indent}std:') + lines.append(f'{indent}{indent}R: {analytics_data["std"][0]}') + lines.append(f'{indent}{indent}G: {analytics_data["std"][1]}') + lines.append(f'{indent}{indent}B: {analytics_data["std"][2]}') + + print_string = '\n'.join(lines) + log.info(print_string) + + def _print_analytics_gt(self, analytics_gt): + indent = 4 * ' ' + lines = [''] + lines.append(f'classes:') + for i, class_specs in enumerate(zip(analytics_gt['class_encodings'], analytics_gt['class_weights'])): + lines.append(f'{indent}class{i}:') + lines.append(f'{indent}{indent}color:') + lines.append(f'{indent}{indent}{indent}R: {class_specs[0][0]}') + lines.append(f'{indent}{indent}{indent}G: {class_specs[0][1]}') + lines.append(f'{indent}{indent}{indent}B: {class_specs[0][2]}') + lines.append(f'{indent}{indent}weight: {class_specs[1]}') + + print_string = '\n'.join(lines) + log.info(print_string) + + def _print_image_dims(self, image_dims): + indent = 4 * ' ' + lines = [''] + lines.append(f'image_dims:') + lines.append(f'{indent}width: {image_dims["width"]}') + lines.append(f'{indent}height: 
{image_dims["height"]}')
+
+        print_string = '\n'.join(lines)
+        log.info(print_string)
 
     def setup(self, stage: Optional[str] = None):
         super().setup()
         if stage == 'fit' or stage is None:
-            self.train = DatasetRolfFormat(**self._create_dataset_parameters('train'), selection=self.selection_train)
-            self.val = DatasetRolfFormat(**self._create_dataset_parameters('val'), selection=self.selection_val)
-
-            self._check_min_num_samples(num_samples=len(self.train), data_split='train',
-                                        drop_last=self.drop_last)
-            self._check_min_num_samples(num_samples=len(self.val), data_split='val',
-                                        drop_last=self.drop_last)
+            self.train = DatasetRolfFormat(dataset_specs=self.train_dataset_specs,
+                                           is_test=False,
+                                           classes=self.class_encodings,
+                                           image_transform=self.image_transform,
+                                           target_transform=self.target_transform,
+                                           twin_transform=self.twin_transform)
+            self.val = DatasetRolfFormat(dataset_specs=self.val_dataset_specs,
+                                         is_test=False,
+                                         classes=self.class_encodings,
+                                         image_transform=self.image_transform,
+                                         target_transform=self.target_transform,
+                                         twin_transform=self.twin_transform)
+
+            self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last)
+            self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last)
 
         if stage == 'test' or stage is not None:
-            self.test = DatasetRolfFormat(**self._create_dataset_parameters('test'), selection=self.selection_test)
-            # self._check_min_num_samples(num_samples=len(self.test), data_split='test',
-            #                             drop_last=False)
+            self.test = DatasetRolfFormat(dataset_specs=self.test_dataset_specs,
+                                          is_test=True,
+                                          classes=self.class_encodings,
+                                          image_transform=self.image_transform,
+                                          target_transform=self.target_transform,
+                                          twin_transform=self.twin_transform)
+            # self._check_min_num_samples(num_samples=len(self.test), data_split='test', drop_last=False)
 
     def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool):
         num_processes = self.trainer.num_processes
@@ -119,20 +189,9 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
                           drop_last=False,
                           pin_memory=True)
 
-    def _create_dataset_parameters(self, dataset_type: str = 'train'):
-        is_test = dataset_type == 'test'
-        return {'path': self.data_dir / dataset_type,
-                'data_folder_name': self.data_folder_name,
-                'gt_folder_name': self.gt_folder_name,
-                'image_transform': self.image_transform,
-                'target_transform': self.target_transform,
-                'twin_transform': self.twin_transform,
-                'classes': self.class_encodings,
-                'is_test': is_test}
-
-    def get_img_name_coordinates(self, index):
+    def get_img_name(self, index):
         """
-        Returns the original filename of the crop and its coordinate based on the index.
+        Returns the original filename of the doc image.
+        This can only be used during testing!
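+        The returned name is the document stem stored in the img_paths_per_page tuples.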
:param index: :return: diff --git a/src/datamodules/RolfFormat/datasets/dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py index 6ed3776f..8079489a 100644 --- a/src/datamodules/RolfFormat/datasets/dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -6,6 +6,7 @@ import re from pathlib import Path from typing import List, Tuple, Union, Optional +from dataclasses import asdict, dataclass import torch.utils.data as data from omegaconf import ListConfig @@ -19,6 +20,16 @@ log = utils.get_logger(__name__) +@dataclass +class DatasetSpecs: + data_root: str + doc_dir: str + doc_names: str + gt_dir: str + gt_names: str + range_from: int + range_to: int + class DatasetRolfFormat(data.Dataset): """A generic data loader where the images are arranged in this way: :: @@ -32,8 +43,7 @@ class DatasetRolfFormat(data.Dataset): root/data/xxz.png """ - def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, - selection: Optional[Union[int, List[str]]] = None, + def __init__(self, dataset_specs: List[DatasetSpecs], is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): """ @@ -54,14 +64,10 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, A function to load an image given its path. """ - self.path = path - self.data_folder_name = data_folder_name - self.gt_folder_name = gt_folder_name - self.selection = selection + self.dataset_specs = dataset_specs # Init list self.classes = classes - # self.crops_per_image = crops_per_image # transformations self.image_transform = image_transform @@ -71,16 +77,11 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, self.is_test = is_test # List of tuples that contain the path to the gt and image that belong together - self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name, selection=self.selection) - - # TODO: make more fanzy stuff here - # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + self.img_paths_per_page = self.get_gt_data_paths(list_specs=self.dataset_specs) self.num_samples = len(self.img_paths_per_page) - if self.num_samples == 0: - raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format( - path, ",".join(IMG_EXTENSIONS))) + + assert self.num_samples > 0 def __len__(self): """ @@ -148,104 +149,59 @@ def _apply_transformation(self, img, gt): return img, gt @staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, - selection: Optional[Union[int, List[str]]] = None) \ - -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: - """ - Structure of the folder + def _get_paths_from_specs(data_root: str, + doc_dir: str, doc_names: str, + gt_dir: str, gt_names: str, + range_from: int, range_to: int): - directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png - directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png + path_root = Path(data_root) + path_doc_dir = path_root / doc_dir + path_gt_dir = path_root / gt_dir + + if not path_doc_dir.is_dir(): + log.error(f'Document directory not found ("{path_doc_dir}")!') + + if not path_gt_dir.is_dir(): + log.error(f'Ground Truth directory not found ("{path_gt_dir}")!') + + p = re.compile('#+') + + # assert that there is exactly one placeholder group + assert len(p.findall(doc_names)) == 1 + assert len(p.findall(gt_names)) == 1 + + search_doc_names = p.search(doc_names) + doc_prefix = 
doc_names[:search_doc_names.span(0)[0]] + doc_suffix = doc_names[search_doc_names.span(0)[1]:] + doc_number_length = len(search_doc_names.group(0)) + + search_gt_names = p.search(gt_names) + gt_prefix = gt_names[:search_gt_names.span(0)[0]] + gt_suffix = gt_names[search_gt_names.span(0)[1]:] + gt_number_length = len(search_gt_names.group(0)) - :param directory: - :param data_folder_name: - :param gt_folder_name: - :param selection: - :return: tuple - (path_data_file, path_gt_file, original_image_name, (x, y)) - """ paths = [] - directory = directory.expanduser() - - path_data_root = directory / data_folder_name - path_gt_root = directory / gt_folder_name - - if not (path_data_root.is_dir() or path_gt_root.is_dir()): - log.error("folder data or gt not found in " + str(directory)) - - # get all subitems (and files) sorted - subitems = sorted(path_data_root.iterdir()) - - # check the selection parameter - if selection: - subdirectories = [x.name for x in subitems if x.is_dir()] - - if isinstance(selection, int): - if selection < 0: - msg = f'Parameter "selection" is a negative integer ({selection}). ' \ - f'Negative values are not supported!' - log.error(msg) - raise ValueError(msg) - - elif selection == 0: - selection = None - - elif selection > len(subdirectories): - msg = f'Parameter "selection" is larger ({selection}) than ' \ - f'number of subdirectories ({len(subdirectories)}).' - log.error(msg) - raise ValueError(msg) - - elif isinstance(selection, ListConfig) or isinstance(selection, list): - if not all(x in subdirectories for x in selection): - msg = f'Parameter "selection" contains a non-existing subdirectory.)' - log.error(msg) - raise ValueError(msg) - - else: - msg = f'Parameter "selection" exists, but it is of unsupported type ({type(selection)})' - log.error(msg) - raise TypeError(msg) - - counter = 0 # Counter for subdirectories, needed for selection parameter - - for path_data_subdir in subitems: - if not path_data_subdir.is_dir(): - if has_extension(path_data_subdir.name, IMG_EXTENSIONS): - log.warning("image file found in data root: " + str(path_data_subdir)) - continue - - counter += 1 - - if selection: - if isinstance(selection, int): - if counter > selection: - break - - elif isinstance(selection, ListConfig) or isinstance(selection, list): - if path_data_subdir.name not in selection: - continue - - path_gt_subdir = path_gt_root / path_data_subdir.stem - assert path_gt_subdir.is_dir() - - for path_data_file, path_gt_file in zip(sorted(path_data_subdir.iterdir()), - sorted(path_gt_subdir.iterdir())): - assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ - has_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned with non-image file' - - if has_extension(path_data_file.name, IMG_EXTENSIONS) and has_extension(path_gt_file.name, - IMG_EXTENSIONS): - assert path_data_file.stem == path_gt_file.stem, \ - 'get_gt_data_paths(): mismatch between data filename and gt filename' - coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') - m = coordinates.match(path_data_file.name) - if m is None: - continue - x = int(m.group(1)) - y = int(m.group(2)) - # TODO check if we need x/y - paths.append((path_data_file, path_gt_file, path_data_subdir.stem, path_data_file.stem, (x, y))) + for i in range(range_from, range_to + 1): + doc_filename = f'{doc_prefix}{i:0{doc_number_length}d}{doc_suffix}' + path_doc_file = path_doc_dir / doc_filename + + gt_filename = f'{gt_prefix}{i:0{gt_number_length}d}{gt_suffix}' + path_gt_file = path_gt_dir / gt_filename + 
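+            # The page image and its ground truth must either both exist or both be
+            # missing; asserting this surfaces a half-missing pair in the numbered
+            # range instead of silently dropping it.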
+ assert path_doc_file.exists() == path_gt_file.exists() + + if path_doc_file.exists() and path_gt_file.exists(): + paths.append((path_doc_file, path_gt_file, path_doc_file.stem)) + + assert len(paths) > 0 + + return paths + + @staticmethod + def get_gt_data_paths(list_specs: List[DatasetSpecs]) -> List[Tuple[Path, Path, str]]: + paths = [] + + for specs in list_specs: + paths += DatasetRolfFormat._get_paths_from_specs(**asdict(specs)) return paths diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index 58d10ec8..c38bb233 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -17,7 +17,9 @@ from src.datamodules.RGB.utils.misc import pil_loader -def get_analytics_data(file_names_data, **kwargs): +def get_analytics_data(data_gt_path_list, **kwargs): + file_names_data = np.asarray([str(item[0]) for item in data_gt_path_list]) + mean, std = compute_mean_std(file_names=file_names_data, **kwargs) analytics_data = {'mean': mean.tolist(), 'std': std.tolist()} @@ -25,7 +27,9 @@ def get_analytics_data(file_names_data, **kwargs): return analytics_data -def get_analytics_gt(file_names_gt, **kwargs): +def get_analytics_gt(data_gt_path_list, **kwargs): + file_names_gt = np.asarray([str(item[1]) for item in data_gt_path_list]) + # Measure weights for class balancing logging.info(f'Measuring class weights') # create a list with all gt file paths @@ -37,6 +41,15 @@ def get_analytics_gt(file_names_gt, **kwargs): return analytics_gt +def get_image_dims(data_gt_path_list, **kwargs): + img = Image.open(data_gt_path_list[0][0]).convert('RGB') + + image_dims = {'width': img.width, + 'height': img.height} + + return image_dims + + def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): """ Computes mean and std of all images present at target folder. 
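A note on the name patterns consumed above: in a DatasetSpecs entry, doc_names and gt_names each contain exactly one run of '#' characters, which _get_paths_from_specs replaces with the zero-padded page number for every index in [range_from, range_to]. A minimal, self-contained sketch of that expansion (the function name and example values are illustrative, not part of the repository):

    import re

    def expand_pattern(pattern: str, page: int) -> str:
        # Exactly one run of '#' marks the zero-padded page-number field.
        match = re.search('#+', pattern)
        assert match is not None, 'pattern needs a #-placeholder'
        prefix, suffix = pattern[:match.start()], pattern[match.end():]
        return f'{prefix}{page:0{len(match.group(0))}d}{suffix}'

    # e.g. expand_pattern('A1-MD-page-####.jpg', 7) == 'A1-MD-page-0007.jpg'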
From 4a63234256a5a77064be426952fc69222950a98b Mon Sep 17 00:00:00 2001
From: Paul M
Date: Thu, 18 Nov 2021 12:44:23 +0100
Subject: [PATCH 065/108] :sparkles: Full Page works

---
 ..._baby_unet_rolf_layoutD_gtD_full_page.yaml |  66 ++++++++
 .../semantic_segmentation_RGB_full_page.yaml  |   2 +
 src/datamodules/RGB/datamodule_full_page.py   | 143 ++++++++++++++++++
 src/datamodules/RGB/utils/image_analytics.py  |   8 +-
 .../RGB/semantic_segmentation_full_page.py    | 120 +++++++++++++++
 5 files changed, 337 insertions(+), 2 deletions(-)
 create mode 100644 configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml
 create mode 100644 configs/task/semantic_segmentation_RGB_full_page.yaml
 create mode 100644 src/datamodules/RGB/datamodule_full_page.py
 create mode 100644 src/tasks/RGB/semantic_segmentation_full_page.py

diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml b/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml
new file mode 100644
index 00000000..75251bb9
--- /dev/null
+++ b/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml
@@ -0,0 +1,66 @@
+# @package _global_
+
+# to execute this experiment run:
+# python run.py +experiment=exp_example_full
+
+defaults:
+  - /plugins: default.yaml
+  - /task: semantic_segmentation_RGB_full_page.yaml
+  - /loss: crossentropyloss.yaml
+  - /metric: iou.yaml
+  - /model/backbone: baby_unet_model.yaml
+  - /model/header: identity.yaml
+  - /optimizer: adam.yaml
+  - /callbacks:
+      - check_compatibility.yaml
+      - model_checkpoint.yaml
+      - watch_model_wandb.yaml
+  - /logger:
+      - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+train: True
+test: True
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 2000
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 100
+  confusion_matrix_val: True
+  confusion_matrix_test: True
+
+datamodule:
+  _target_: src.datamodules.RGB.datamodule_full_page.DataModuleRGB
+
+  data_dir: /netscratch/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutD/split
+  num_workers: 4
+  batch_size: 2
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+  gt_folder_name: gtD
+
+callbacks:
+  model_checkpoint:
+    monitor: "val/iou"
+    mode: "max"
+    filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
+  watch_model:
+    log_freq: 100
+
+logger:
+  wandb:
+    name: 'synthetic-baby-unet-rolf-layoutD-gtD-full-page'
+    tags: [ "best_model", "synthetic", "layoutD", "gtD", "Rolf", "full_page" ]
+    group: 'synthetic'
diff --git a/configs/task/semantic_segmentation_RGB_full_page.yaml b/configs/task/semantic_segmentation_RGB_full_page.yaml
new file mode 100644
index 00000000..406872ed
--- /dev/null
+++ b/configs/task/semantic_segmentation_RGB_full_page.yaml
@@ -0,0 +1,2 @@
+_target_: src.tasks.RGB.semantic_segmentation_full_page.SemanticSegmentationFullPageRGB
+
diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py
new file mode 100644
index 00000000..4eef0ca9
--- /dev/null
+++ b/src/datamodules/RGB/datamodule_full_page.py
@@ -0,0 +1,143 @@
+from pathlib import Path
+from typing import Union, List, Optional
+
+import torch
+from torch.utils.data import DataLoader
+from 
torchvision import transforms + +from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB +from src.datamodules.RGB.utils.image_analytics import get_analytics +from src.datamodules.RGB.utils.misc import validate_path_for_segmentation +from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ + IntegerEncoding +from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.base_datamodule import AbstractDatamodule +from src.utils import utils + +log = utils.get_logger(__name__) + + +class DataModuleRGB(AbstractDatamodule): + def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, + selection_train: Optional[Union[int, List[str]]] = None, + selection_val: Optional[Union[int, List[str]]] = None, + selection_test: Optional[Union[int, List[str]]] = None, + num_workers: int = 4, batch_size: int = 8, + shuffle: bool = True, drop_last: bool = True): + super().__init__() + + self.data_folder_name = data_folder_name + self.gt_folder_name = gt_folder_name + + analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), + data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, + get_gt_data_paths_func=DatasetRGB.get_gt_data_paths) + + self.dims = (3, analytics_data['width'], analytics_data['height']) + + self.mean = analytics_data['mean'] + self.std = analytics_data['std'] + self.class_encodings = analytics_gt['class_encodings'] + self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 + self.num_classes = len(self.class_encodings) + self.class_weights = analytics_gt['class_weights'] + + self.twin_transform = None + self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std)])) + self.target_transform = OnlyTarget(IntegerEncoding(class_encodings=self.class_encodings_tensor)) + + self.num_workers = num_workers + self.batch_size = batch_size + + self.shuffle = shuffle + self.drop_last = drop_last + + self.data_dir = validate_path_for_segmentation(data_dir=data_dir, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name) + + self.selection_train = selection_train + self.selection_val = selection_val + self.selection_test = selection_test + + def setup(self, stage: Optional[str] = None): + super().setup() + if stage == 'fit' or stage is None: + self.train = DatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.val = DatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + + self._check_min_num_samples(num_samples=len(self.train), data_split='train', + drop_last=self.drop_last) + self._check_min_num_samples(num_samples=len(self.val), data_split='val', + drop_last=self.drop_last) + + if stage == 'test' or stage is not None: + self.test = DatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + # self._check_min_num_samples(num_samples=len(self.test), data_split='test', + # drop_last=False) + + def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): + num_processes = self.trainer.num_processes + batch_size = self.batch_size + if drop_last: + if num_samples < (self.trainer.num_processes * self.batch_size): + log.error( + f'#samples ({num_samples}) in "{data_split}" smaller than ' + f'#processes({num_processes}) times batch size ({batch_size}). 
'
+                    f'This only works if drop_last is false!')
+                raise ValueError()
+        else:
+            if num_samples < (self.trainer.num_processes * self.batch_size):
+                log.warning(
+                    f'#samples ({num_samples}) in "{data_split}" smaller than '
+                    f'#processes ({num_processes}) times batch size ({batch_size}). '
+                    f'This works due to drop_last=False, however samples will occur multiple times. '
+                    f'Check if this behavior is intended!')
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(self.train,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.val,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=self.shuffle,
+                          drop_last=self.drop_last,
+                          pin_memory=True)
+
+    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self.test,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers,
+                          shuffle=False,
+                          drop_last=False,
+                          pin_memory=True)
+
+    def _create_dataset_parameters(self, dataset_type: str = 'train'):
+        is_test = dataset_type == 'test'
+        return {'path': self.data_dir / dataset_type,
+                'data_folder_name': self.data_folder_name,
+                'gt_folder_name': self.gt_folder_name,
+                'image_transform': self.image_transform,
+                'target_transform': self.target_transform,
+                'twin_transform': self.twin_transform,
+                'classes': self.class_encodings,
+                'is_test': is_test}
+
+    def get_img_name(self, index):
+        """
+        Returns the original filename of the doc image.
+        This can only be used during testing!
+        :param index:
+        :return:
+        """
+        if not hasattr(self, 'test'):
+            raise Exception('This method can only be called during testing')
+
+        return self.test.img_paths_per_page[index][2:]
diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py
index 839636a0..264e9772 100644
--- a/src/datamodules/RGB/utils/image_analytics.py
+++ b/src/datamodules/RGB/utils/image_analytics.py
@@ -26,7 +26,7 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str,
     Returns
     -------
     """
-    expected_keys_data = ['mean', 'std']
+    expected_keys_data = ['mean', 'std', 'width', 'height']
     expected_keys_gt = ['class_weights', 'class_encodings']
 
     analytics_path_data = input_path / f'analytics.data.{data_folder_name}.json'
@@ -61,8 +61,12 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str,
     if missing_analytics_data:
         mean, std = compute_mean_std(file_names=file_names_data, **kwargs)
 
+        img = Image.open(file_names_data[0]).convert('RGB')
+
         analytics_data = {'mean': mean.tolist(),
-                          'std': std.tolist()}
+                          'std': std.tolist(),
+                          'width': img.width,
+                          'height': img.height}
         # save json
         try:
             with analytics_path_data.open(mode='w') as f:
diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py
new file mode 100644
index 00000000..b242298c
--- /dev/null
+++ b/src/tasks/RGB/semantic_segmentation_full_page.py
@@ -0,0 +1,120 @@
+from pathlib import Path
+from typing import Optional, Callable, Union
+
+import numpy as np
+import torch.nn as nn
+import torch.optim
+import torchmetrics
+
+from src.tasks.base_task import AbstractTask
+from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax
+from src.utils import utils
+from src.tasks.utils.outputs import OutputKeys, reduce_dict
+
+log = utils.get_logger(__name__)
+
+
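+# Task for full-page RGB semantic segmentation: whole pages (no crops) are fed
+# through the network, and the raw per-page prediction is written to disk
+# during testing.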
+class SemanticSegmentationFullPageRGB(AbstractTask):
+
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss_fn: Optional[Callable] = None,
+                 metric_train: Optional[torchmetrics.Metric] = None,
+                 metric_val: Optional[torchmetrics.Metric] = None,
+                 metric_test: Optional[torchmetrics.Metric] = None,
+                 test_output_path: Optional[Union[str, Path]] = 'predictions',
+                 confusion_matrix_val: Optional[bool] = False,
+                 confusion_matrix_test: Optional[bool] = False,
+                 confusion_matrix_log_every_n_epoch: Optional[int] = 1,
+                 lr: float = 1e-3
+                 ) -> None:
+        """
+        Pixelwise semantic segmentation on full pages. During testing, the raw network output is saved for each page.
+
+        :param model: torch.nn.Module
+            The network for the segmentation, e.g. a UNet
+        :param test_output_path: str
+            Path to the output folder for the test predictions
+        """
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            loss_fn=loss_fn,
+            metric_train=metric_train,
+            metric_val=metric_val,
+            metric_test=metric_test,
+            test_output_path=test_output_path,
+            lr=lr,
+            confusion_matrix_val=confusion_matrix_val,
+            confusion_matrix_test=confusion_matrix_test,
+            confusion_matrix_log_every_n_epoch=confusion_matrix_log_every_n_epoch,
+        )
+        self.save_hyperparameters()
+
+    def setup(self, stage: str) -> None:
+        super().setup(stage)
+
+        if not hasattr(self.trainer.datamodule, 'get_img_name'):
+            raise NotImplementedError('DataModule needs to implement get_img_name function')
+
+        log.info("Setup done!")
+
+    def forward(self, x):
+        return self.model(x)
+
+    @staticmethod
+    def to_metrics_format(x: torch.Tensor, **kwargs) -> torch.Tensor:
+        return _get_argmax(x, **kwargs)
+
+    #############################################################################################
+    ########################################### TRAIN ###########################################
+    #############################################################################################
+    def training_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch = batch
+        output = super().training_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+        return reduce_dict(input_dict=output, key_list=[OutputKeys.LOSS])
+
+    #############################################################################################
+    ############################################ VAL ############################################
+    #############################################################################################
+
+    def validation_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch = batch
+        output = super().validation_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+        return reduce_dict(input_dict=output, key_list=[])
+
+    #############################################################################################
+    ########################################### TEST ############################################
+    #############################################################################################
+
+    def test_step(self, batch, batch_idx, **kwargs):
+        input_batch, target_batch, input_idx = batch
+        output = super().test_step(batch=(input_batch, target_batch), batch_idx=batch_idx)
+
+        if not hasattr(self.trainer.datamodule, 'get_img_name'):
+            raise NotImplementedError('Datamodule does not provide detailed information of the crop')
+
+        for patch, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(),
+                              input_idx.detach().cpu().numpy()):
+            patch_info = self.trainer.datamodule.get_img_name_coordinates(idx)
+            
img_name = patch_info[0] + dest_folder = self.test_output_path / 'preds' + dest_folder.mkdir(parents=True, exist_ok=True) + dest_filename = dest_folder / f'{img_name}.npy' + + np.save(file=str(dest_filename), arr=patch) + + return reduce_dict(input_dict=output, key_list=[]) + + def on_test_end(self) -> None: + datamodule_path = self.trainer.datamodule.data_dir + prediction_path = (self.test_output_path / 'patches').absolute() + output_path = (self.test_output_path / 'result').absolute() + + data_folder_name = self.trainer.datamodule.data_folder_name + gt_folder_name = self.trainer.datamodule.gt_folder_name + + log.info(f'To run the merging of patches:') + log.info(f'python tools/merge_cropped_output_RGB.py -d {datamodule_path} -p {prediction_path} -o {output_path} ' + f'-df {data_folder_name} -gf {gt_folder_name}') From 70ca31f1cb95b0e394c6f32a0e8d579bd1554134 Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 18 Nov 2021 16:22:23 +0100 Subject: [PATCH 066/108] :sparkles: :wrench: Rolf format works --- configs/datamodule/rolf_format_dev.yaml | 2 +- .../datamodule/rolf_format_layoutD_gtD.yaml | 135 ++++++++++++++++++ ..._full_page.yaml => dev_rgb_full_page.yaml} | 13 +- configs/experiment/dev_rolf_format.yaml | 57 ++++++++ ...l => synthetic_baby_unet_layoutD_gtD.yaml} | 8 +- ...ic_baby_unet_layoutD_gtD_rolf_format.yaml} | 12 +- ...l => synthetic_baby_unet_layoutR_gtD.yaml} | 8 +- .../DivaHisDB/datamodule_cropped.py | 2 +- src/datamodules/RGB/datamodule_cropped.py | 2 +- src/datamodules/RGB/datamodule_full_page.py | 2 +- .../RGB/datasets/full_page_dataset.py | 2 +- src/datamodules/RolfFormat/datamodule.py | 10 +- src/datamodules/RotNet/datamodule_cropped.py | 2 +- .../RGB/semantic_segmentation_full_page.py | 26 ++-- 14 files changed, 235 insertions(+), 46 deletions(-) create mode 100644 configs/datamodule/rolf_format_layoutD_gtD.yaml rename configs/experiment/{synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml => dev_rgb_full_page.yaml} (85%) create mode 100644 configs/experiment/dev_rolf_format.yaml rename configs/experiment/{synthetic_baby_unet_rolf_layoutD_gtD.yaml => synthetic_baby_unet_layoutD_gtD.yaml} (92%) rename configs/experiment/{synthetic_baby_unet_rolf_format.yaml => synthetic_baby_unet_layoutD_gtD_rolf_format.yaml} (83%) rename configs/experiment/{synthetic_baby_unet_rolf_layoutR_gtD.yaml => synthetic_baby_unet_layoutR_gtD.yaml} (92%) diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml index 9c7551a0..9e6d51bd 100644 --- a/configs/datamodule/rolf_format_dev.yaml +++ b/configs/datamodule/rolf_format_dev.yaml @@ -1,7 +1,7 @@ _target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat num_workers: 4 -batch_size: 16 +batch_size: 8 shuffle: True drop_last: True diff --git a/configs/datamodule/rolf_format_layoutD_gtD.yaml b/configs/datamodule/rolf_format_layoutD_gtD.yaml new file mode 100644 index 00000000..60e11bbb --- /dev/null +++ b/configs/datamodule/rolf_format_layoutD_gtD.yaml @@ -0,0 +1,135 @@ +_target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat + +num_workers: 4 +batch_size: 8 +shuffle: True +drop_last: False + +data_root: /netscratch/datasets/semantic_segmentation/rolf_format + +train_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1000 + range_to: 1059 + +val_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: 
"SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1060 + range_to: 1079 + +test_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1080 + range_to: 1099 + +image_dims: + width: 640 + height: 896 + +image_analytics: + mean: + R: 0.8616756883580258 + G: 0.7419672402489641 + B: 0.6295439441727211 + std: + R: 0.21909338297170539 + G: 0.2076260211193138 + B: 0.1875025535444422 + +classes: + class0: + color: + R: 0 + G: 0 + B: 0 + weight: 0.00047694816117914033 + class1: + color: + R: 0 + G: 102 + B: 0 + weight: 0.2750549630858548 + class2: + color: + R: 0 + G: 102 + B: 102 + weight: 0.0032424343955529127 + class3: + color: + R: 0 + G: 153 + B: 153 + weight: 0.011400733756796401 + class4: + color: + R: 0 + G: 255 + B: 0 + weight: 0.13777394656361366 + class5: + color: + R: 0 + G: 255 + B: 255 + weight: 0.008088433280055035 + class6: + color: + R: 102 + G: 0 + B: 0 + weight: 0.09080998823458127 + class7: + color: + R: 102 + G: 0 + B: 102 + weight: 0.05538750877701472 + class8: + color: + R: 102 + G: 102 + B: 0 + weight: 0.0019513173654070824 + class9: + color: + R: 153 + G: 0 + B: 153 + weight: 0.21392691331701744 + class10: + color: + R: 153 + G: 153 + B: 0 + weight: 0.007439598839196634 + class11: + color: + R: 255 + G: 0 + B: 0 + weight: 0.06261423912267086 + class12: + color: + R: 255 + G: 0 + B: 255 + weight: 0.12729126435641605 + class13: + color: + R: 255 + G: 255 + B: 0 + weight: 0.004541710744643828 diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml similarity index 85% rename from configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml rename to configs/experiment/dev_rgb_full_page.yaml index 75251bb9..332706ca 100644 --- a/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -31,12 +31,12 @@ trainer: gpus: -1 accelerator: 'ddp' min_epochs: 1 - max_epochs: 2000 + max_epochs: 2 weights_summary: full precision: 16 task: - confusion_matrix_log_every_n_epoch: 100 + confusion_matrix_log_every_n_epoch: 1 confusion_matrix_val: True confusion_matrix_test: True @@ -57,10 +57,11 @@ callbacks: mode: "max" filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data watch_model: - log_freq: 100 + log_freq: 1 logger: wandb: - name: 'synthetic-baby-unet-rolf-layoutD-gtD-full-page' - tags: [ "best_model", "synthetic", "layoutD", "gtD", "Rolf", "full_page" ] - group: 'synthetic' + name: 'dev-RGB-full-page' + tags: [ "best_model", "synthetic", "RGB", "Rolf", "full_page" ] + group: 'dev-runs' + notes: "Testing" diff --git a/configs/experiment/dev_rolf_format.yaml b/configs/experiment/dev_rolf_format.yaml new file mode 100644 index 00000000..8bd25c2f --- /dev/null +++ b/configs/experiment/dev_rolf_format.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_RGB_full_page.yaml + - /datamodule: rolf_format_dev.yaml + - /loss: crossentropyloss.yaml + - /metric: iou.yaml + - /model/backbone: baby_unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+train: True
+test: True
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 2
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 1
+  confusion_matrix_val: True
+  confusion_matrix_test: True
+
+callbacks:
+  model_checkpoint:
+    monitor: "val/iou"
+    mode: "max"
+    filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
+  watch_model:
+    log_freq: 1
+
+logger:
+  wandb:
+    name: 'dev-rolf-format'
+    tags: [ "best_model", "synthetic", "rolf_format" ]
+    group: 'dev-runs'
+    notes: "Testing"
diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml
similarity index 92%
rename from configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml
rename to configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml
index 78b65fe4..efa7f52f 100644
--- a/configs/experiment/synthetic_baby_unet_rolf_layoutD_gtD.yaml
+++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml
@@ -31,12 +31,12 @@ trainer:
   gpus: -1
   accelerator: 'ddp'
   min_epochs: 1
-  max_epochs: 2000
+  max_epochs: 200
   weights_summary: full
   precision: 16
 
 task:
-  confusion_matrix_log_every_n_epoch: 100
+  confusion_matrix_log_every_n_epoch: 50
   confusion_matrix_val: True
   confusion_matrix_test: True
 
@@ -58,10 +58,10 @@ callbacks:
     mode: "max"
     filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
   watch_model:
-    log_freq: 100
+    log_freq: 50
 
 logger:
   wandb:
-    name: 'synthetic-baby-unet-rolf-layoutD-gtD'
+    name: 'synthetic-baby-unet-layoutD-gtD'
     tags: [ "best_model", "synthetic", "layoutD", "gtD", "Rolf" ]
     group: 'synthetic'
diff --git a/configs/experiment/synthetic_baby_unet_rolf_format.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml
similarity index 83%
rename from configs/experiment/synthetic_baby_unet_rolf_format.yaml
rename to configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml
index 7391f8fe..52833445 100644
--- a/configs/experiment/synthetic_baby_unet_rolf_format.yaml
+++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml
@@ -5,8 +5,8 @@
 
 defaults:
   - /plugins: default.yaml
-  - /task: semantic_segmentation_RGB.yaml
-  - /datamodule: rolf_format_dev.yaml
+  - /task: semantic_segmentation_RGB_full_page.yaml
+  - /datamodule: rolf_format_layoutD_gtD.yaml
   - /loss: crossentropyloss.yaml
   - /metric: iou.yaml
   - /model/backbone: baby_unet_model.yaml
@@ -32,12 +32,12 @@ trainer:
   gpus: -1
   accelerator: 'ddp'
   min_epochs: 1
-  max_epochs: 2000
+  max_epochs: 200
   weights_summary: full
   precision: 16
 
 task:
-  confusion_matrix_log_every_n_epoch: 100
+  confusion_matrix_log_every_n_epoch: 50
   confusion_matrix_val: True
   confusion_matrix_test: True
 
@@ -47,10 +47,10 @@ callbacks:
     mode: "max"
     filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
   watch_model:
-    log_freq: 100
+    log_freq: 50
 
 logger:
   wandb:
-    name: 'synthetic-baby-unet-rolf-format'
+    name: 'synthetic-baby-unet-layoutD-gtD-rolf-format'
     tags: [ "best_model", "synthetic", "RolfFormat" ]
     group: 'synthetic'
diff --git a/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml
similarity index 92%
rename from configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml
rename to configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml index 1ab54d6a..e9869db1 100644 --- a/configs/experiment/synthetic_baby_unet_rolf_layoutR_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml @@ -31,12 +31,12 @@ trainer: gpus: -1 accelerator: 'ddp' min_epochs: 1 - max_epochs: 2000 + max_epochs: 200 weights_summary: full precision: 16 task: - confusion_matrix_log_every_n_epoch: 100 + confusion_matrix_log_every_n_epoch: 50 confusion_matrix_val: True confusion_matrix_test: True @@ -58,10 +58,10 @@ callbacks: mode: "max" filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data watch_model: - log_freq: 100 + log_freq: 50 logger: wandb: - name: 'synthetic-baby-unet-rolf-layoutR-gtD' + name: 'synthetic-baby-unet-layoutR-gtD' tags: [ "best_model", "synthetic", "layoutR", "gtD", "Rolf" ] group: 'synthetic' diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 40c42ff5..60600289 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -92,7 +92,7 @@ def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: b log.warning( f'#samples ({num_samples}) in "{data_split}" smaller than ' f'#processes ({num_processes}) times batch size ({batch_size}). ' - f'This works due to drop_last=False, however samples will occur multiple times. ' + f'This works due to drop_last=False, however samples might occur multiple times. ' f'Check if this behavior is intended!') def train_dataloader(self, *args, **kwargs) -> DataLoader: diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index e85a14a5..01731b74 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -92,7 +92,7 @@ def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: b log.warning( f'#samples ({num_samples}) in "{data_split}" smaller than ' f'#processes ({num_processes}) times batch size ({batch_size}). ' - f'This works due to drop_last=False, however samples will occur multiple times. ' + f'This works due to drop_last=False, however samples might occur multiple times. ' f'Check if this behavior is intended!') def train_dataloader(self, *args, **kwargs) -> DataLoader: diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 4eef0ca9..11077523 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -92,7 +92,7 @@ def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: b log.warning( f'#samples ({num_samples}) in "{data_split}" smaller than ' f'#processes ({num_processes}) times batch size ({batch_size}). ' - f'This works due to drop_last=False, however samples will occur multiple times. ' + f'This works due to drop_last=False, however samples might occur multiple times. 
' f'Check if this behavior is intended!') def train_dataloader(self, *args, **kwargs) -> DataLoader: diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 410d0e76..408beb37 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -227,6 +227,6 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st assert path_data_file.stem == path_gt_file.stem, \ 'get_gt_data_paths(): mismatch between data filename and gt filename' # TODO check if we need x/y - paths.append((path_data_file, path_gt_file)) + paths.append((path_data_file, path_gt_file, path_data_file.stem)) return paths diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 86cbdc1d..8eac6b4a 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -31,6 +31,10 @@ def __init__(self, data_root: str, if image_analytics is None or classes is None or image_dims is None: train_paths_data_gt = DatasetRolfFormat.get_gt_data_paths(list_specs=self.train_dataset_specs) + if image_dims is None: + image_dims = get_image_dims(data_gt_path_list=train_paths_data_gt) + self._print_image_dims(image_dims=image_dims) + if image_analytics is None: analytics_data = get_analytics_data(data_gt_path_list=train_paths_data_gt) self._print_analytics_data(analytics_data=analytics_data) @@ -54,10 +58,6 @@ def __init__(self, data_root: str, class_specs['color']['B']]) analytics_gt['class_weights'].append(class_specs['weight']) - if image_dims is None: - image_dims = get_image_dims(data_gt_path_list=train_paths_data_gt) - self._print_image_dims(image_dims=image_dims) - self.dims = (3, image_dims['width'], image_dims['height']) self.mean = analytics_data['mean'] @@ -162,7 +162,7 @@ def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: b log.warning( f'#samples ({num_samples}) in "{data_split}" smaller than ' f'#processes ({num_processes}) times batch size ({batch_size}). ' - f'This works due to drop_last=False, however samples will occur multiple times. ' + f'This works due to drop_last=False, however samples might occur multiple times. ' f'Check if this behavior is intended!') def train_dataloader(self, *args, **kwargs) -> DataLoader: diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index a1dfb2bd..7365ed2c 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -83,7 +83,7 @@ def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: b log.warning( f'#samples ({num_samples}) in "{data_split}" smaller than ' f'#processes ({num_processes}) times batch size ({batch_size}). ' - f'This works due to drop_last=False, however samples will occur multiple times. ' + f'This works due to drop_last=False, however samples might occur multiple times. 
' f'Check if this behavior is intended!') def train_dataloader(self, *args, **kwargs) -> DataLoader: diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index b242298c..d8bc2214 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -6,6 +6,7 @@ import torch.optim import torchmetrics +from src.datamodules.RolfFormat.utils.output_tools import save_output_page_image from src.tasks.base_task import AbstractTask from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax from src.utils import utils @@ -95,26 +96,21 @@ def test_step(self, batch, batch_idx, **kwargs): if not hasattr(self.trainer.datamodule, 'get_img_name'): raise NotImplementedError('Datamodule does not provide detailed information of the crop') - for patch, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), - input_idx.detach().cpu().numpy()): - patch_info = self.trainer.datamodule.get_img_name_coordinates(idx) + for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), + input_idx.detach().cpu().numpy()): + patch_info = self.trainer.datamodule.get_img_name(idx) img_name = patch_info[0] - dest_folder = self.test_output_path / 'preds' + dest_folder = self.test_output_path / 'preds_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' + np.save(file=str(dest_filename), arr=pred_raw) - np.save(file=str(dest_filename), arr=patch) + dest_folder = self.test_output_path / 'preds' + dest_folder.mkdir(parents=True, exist_ok=True) + save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, + output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) return reduce_dict(input_dict=output, key_list=[]) def on_test_end(self) -> None: - datamodule_path = self.trainer.datamodule.data_dir - prediction_path = (self.test_output_path / 'patches').absolute() - output_path = (self.test_output_path / 'result').absolute() - - data_folder_name = self.trainer.datamodule.data_folder_name - gt_folder_name = self.trainer.datamodule.gt_folder_name - - log.info(f'To run the merging of patches:') - log.info(f'python tools/merge_cropped_output_RGB.py -d {datamodule_path} -p {prediction_path} -o {output_path} ' - f'-df {data_folder_name} -gf {gt_folder_name}') + pass From 31e58b62a9644bc176fff32f8b2dce3a2d8878fe Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 18 Nov 2021 17:11:17 +0100 Subject: [PATCH 067/108] :art: Check image dimensions for full page data --- configs/datamodule/rolf_format_dev.yaml | 2 +- src/datamodules/RGB/datamodule_full_page.py | 9 ++--- .../RGB/datasets/full_page_dataset.py | 14 ++++++++ src/datamodules/RolfFormat/datamodule.py | 35 +++++++++---------- .../RolfFormat/datasets/dataset.py | 11 +++++- .../RolfFormat/utils/image_analytics.py | 4 +-- 6 files changed, 48 insertions(+), 27 deletions(-) diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml index 9e6d51bd..9db9f874 100644 --- a/configs/datamodule/rolf_format_dev.yaml +++ b/configs/datamodule/rolf_format_dev.yaml @@ -3,7 +3,7 @@ _target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat num_workers: 4 batch_size: 8 shuffle: True -drop_last: True +drop_last: False data_root: /netscratch/datasets/semantic_segmentation/rolf_format diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 11077523..437439d7 100644 --- 
a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -5,11 +5,10 @@ from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB +from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB, ImageDimensions from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ - IntegerEncoding +from src.datamodules.RGB.utils.twin_transforms import IntegerEncoding from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule from src.utils import utils @@ -34,7 +33,8 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, gt_folder_name=self.gt_folder_name, get_gt_data_paths_func=DatasetRGB.get_gt_data_paths) - self.dims = (3, analytics_data['width'], analytics_data['height']) + self.image_dims = ImageDimensions(width=analytics_data['width'], height=analytics_data['height']) + self.dims = (3, self.image_dims.width, self.image_dims.height) self.mean = analytics_data['mean'] self.std = analytics_data['std'] @@ -124,6 +124,7 @@ def _create_dataset_parameters(self, dataset_type: str = 'train'): return {'path': self.data_dir / dataset_type, 'data_folder_name': self.data_folder_name, 'gt_folder_name': self.gt_folder_name, + 'image_dims': self.image_dims, 'image_transform': self.image_transform, 'target_transform': self.target_transform, 'twin_transform': self.twin_transform, diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 408beb37..95cc7a9a 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -8,6 +8,7 @@ from typing import List, Tuple, Union, Optional, Any import torch.utils.data as data +from dataclasses import dataclass from omegaconf import ListConfig from torch import is_tensor from torchvision.transforms import ToTensor @@ -20,6 +21,12 @@ log = utils.get_logger(__name__) +@dataclass +class ImageDimensions: + width: int + height: int + + class DatasetRGB(data.Dataset): """A generic data loader where the images are arranged in this way: :: @@ -33,6 +40,7 @@ class DatasetRGB(data.Dataset): """ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, + image_dims: ImageDimensions, selection: Optional[Union[int, List[str]]] = None, is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): @@ -59,6 +67,8 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, self.gt_folder_name = gt_folder_name self.selection = selection + self.image_dims = image_dims + # Init list self.classes = classes # self.crops_per_image = crops_per_image @@ -103,12 +113,16 @@ def _get_test_items(self, index): data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) assert img.shape[-2:] == gt.shape[-2:] + return img, gt, index def _load_data_and_gt(self, index): data_img = pil_loader(self.img_paths_per_page[index][0]) gt_img = pil_loader(self.img_paths_per_page[index][1]) + assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width + assert gt_img.height == self.image_dims.height and 
gt_img.width == self.image_dims.width + return data_img, gt_img def _apply_transformation(self, img, gt): diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 8eac6b4a..200fa617 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -1,12 +1,10 @@ -from pathlib import Path from typing import Union, List, Optional import torch -from dataclasses import dataclass from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs +from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs, ImageDimensions from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt, get_image_dims from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget @@ -58,7 +56,8 @@ def __init__(self, data_root: str, class_specs['color']['B']]) analytics_gt['class_weights'].append(class_specs['weight']) - self.dims = (3, image_dims['width'], image_dims['height']) + self.image_dims = image_dims + self.dims = (3, self.image_dims.width, self.image_dims.height) self.mean = analytics_data['mean'] self.std = analytics_data['std'] @@ -109,31 +108,32 @@ def _print_analytics_gt(self, analytics_gt): print_string = '\n'.join(lines) log.info(print_string) - def _print_image_dims(self, image_dims): + def _print_image_dims(self, image_dims: ImageDimensions): indent = 4 * ' ' lines = [''] lines.append(f'image_dims:') - lines.append(f'{indent}width: {image_dims["width"]}') - lines.append(f'{indent}height: {image_dims["height"]}') + lines.append(f'{indent}width: {image_dims.width}') + lines.append(f'{indent}height: {image_dims.height}') print_string = '\n'.join(lines) log.info(print_string) def setup(self, stage: Optional[str] = None): super().setup() + + common_kwargs = {'classes': self.class_encodings, + 'image_dims': self.image_dims, + 'image_transform': self.image_transform, + 'target_transform': self.target_transform, + 'twin_transform': self.twin_transform} + if stage == 'fit' or stage is None: self.train = DatasetRolfFormat(dataset_specs=self.train_dataset_specs, is_test=False, - classes=self.class_encodings, - image_transform=self.image_transform, - target_transform=self.target_transform, - twin_transform=self.twin_transform) + **common_kwargs) self.val = DatasetRolfFormat(dataset_specs=self.val_dataset_specs, is_test=False, - classes=self.class_encodings, - image_transform=self.image_transform, - target_transform=self.target_transform, - twin_transform=self.twin_transform) + **common_kwargs) self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) @@ -141,10 +141,7 @@ def setup(self, stage: Optional[str] = None): if stage == 'test' or stage is not None: self.test = DatasetRolfFormat(dataset_specs=self.test_dataset_specs, is_test=True, - classes=self.class_encodings, - image_transform=self.image_transform, - target_transform=self.target_transform, - twin_transform=self.twin_transform) + **common_kwargs) # self._check_min_num_samples(num_samples=len(self.test), data_split='test', drop_last=False) def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): diff --git 
a/src/datamodules/RolfFormat/datasets/dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py index 8079489a..6534d90d 100644 --- a/src/datamodules/RolfFormat/datasets/dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -30,6 +30,10 @@ class DatasetSpecs: range_from: int range_to: int +@dataclass +class ImageDimensions: + width: int + height: int class DatasetRolfFormat(data.Dataset): """A generic data loader where the images are arranged in this way: :: @@ -43,7 +47,7 @@ class DatasetRolfFormat(data.Dataset): root/data/xxz.png """ - def __init__(self, dataset_specs: List[DatasetSpecs], + def __init__(self, dataset_specs: List[DatasetSpecs], image_dims: ImageDimensions, is_test=False, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): """ @@ -66,6 +70,8 @@ def __init__(self, dataset_specs: List[DatasetSpecs], self.dataset_specs = dataset_specs + self.image_dims = image_dims + # Init list self.classes = classes @@ -111,6 +117,9 @@ def _load_data_and_gt(self, index): data_img = pil_loader(self.img_paths_per_page[index][0]) gt_img = pil_loader(self.img_paths_per_page[index][1]) + assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width + assert gt_img.height == self.image_dims.height and gt_img.width == self.image_dims.width + return data_img, gt_img def _apply_transformation(self, img, gt): diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index c38bb233..ed332a48 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -15,6 +15,7 @@ from PIL import Image from src.datamodules.RGB.utils.misc import pil_loader +from src.datamodules.RolfFormat.datasets.dataset import ImageDimensions def get_analytics_data(data_gt_path_list, **kwargs): @@ -44,8 +45,7 @@ def get_analytics_gt(data_gt_path_list, **kwargs): def get_image_dims(data_gt_path_list, **kwargs): img = Image.open(data_gt_path_list[0][0]).convert('RGB') - image_dims = {'width': img.width, - 'height': img.height} + image_dims = ImageDimensions(width=img.width, height=img.height) return image_dims From 089ef235369a8d47d050a31326c63d707e4795e7 Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 18 Nov 2021 17:19:40 +0100 Subject: [PATCH 068/108] :wrench: added seed --- configs/experiment/dev_rgb_full_page.yaml | 2 ++ configs/experiment/dev_rolf_format.yaml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/configs/experiment/dev_rgb_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml index 332706ca..c635992e 100644 --- a/configs/experiment/dev_rgb_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -23,6 +23,8 @@ defaults: # instead we define all modules and their paths directly in this config, # so everything is stored in one place for more readibility +seed: 42 + train: True test: True diff --git a/configs/experiment/dev_rolf_format.yaml b/configs/experiment/dev_rolf_format.yaml index 8bd25c2f..fe7c5b6b 100644 --- a/configs/experiment/dev_rolf_format.yaml +++ b/configs/experiment/dev_rolf_format.yaml @@ -24,6 +24,8 @@ defaults: # instead we define all modules and their paths directly in this config, # so everything is stored in one place for more readibility +seed: 42 + train: True test: True From 4575dd3769fca89d9e3a667ee01c474d53d57c2d Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 18 Nov 2021 17:40:07 +0100 Subject: [PATCH 069/108] :white_check_mark: fixed test_full_page_dataset.py --- 
tests/datamodules/RGB/test_full_page_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/datamodules/RGB/test_full_page_dataset.py b/tests/datamodules/RGB/test_full_page_dataset.py index d2ffe681..ed8aa656 100644 --- a/tests/datamodules/RGB/test_full_page_dataset.py +++ b/tests/datamodules/RGB/test_full_page_dataset.py @@ -1,19 +1,21 @@ import pytest -from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB +from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB, ImageDimensions from tests.test_data.dummy_data_hisdb.dummy_data import data_dir @pytest.fixture def dataset_train(data_dir): - return DatasetRGB(path=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') + return DatasetRGB(path=data_dir / 'train', data_folder_name='data', gt_folder_name='gt', + image_dims=ImageDimensions(width=487, height=649)) def test_get_gt_data_paths(data_dir): file_list = DatasetRGB.get_gt_data_paths(directory=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') assert len(file_list) == 1 assert file_list[0] == (data_dir / 'train' / 'data' / 'e-codices_fmb-cb-0055_0098v_max.jpg', - data_dir / 'train' / 'gt' / 'e-codices_fmb-cb-0055_0098v_max.png') + data_dir / 'train' / 'gt' / 'e-codices_fmb-cb-0055_0098v_max.png', + 'e-codices_fmb-cb-0055_0098v_max') def test_dataset_rgb(dataset_train): From bfee708440164a935d13518181bb3b3ba5c8bc31 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 10:21:36 +0100 Subject: [PATCH 070/108] :art: adapted return type --- src/datamodules/RGB/datasets/full_page_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 95cc7a9a..7e69b7ab 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -163,7 +163,7 @@ def _apply_transformation(self, img, gt): @staticmethod def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ - -> List[Tuple[Union[Path, Any], Path]]: + -> List[Tuple[Path, Path, str]]: """ Structure of the folder From da33c4a18ecb6cff5caab2d321970555ef292752 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 10:29:58 +0100 Subject: [PATCH 071/108] :fire: removed graph image analytic methods --- .../DivaHisDB/utils/image_analytics.py | 74 ------------------ src/datamodules/RGB/utils/image_analytics.py | 74 ------------------ .../RolfFormat/utils/image_analytics.py | 76 ------------------- .../RotNet/utils/image_analytics.py | 74 ------------------ 4 files changed, 298 deletions(-) diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index 87042a59..100e53ab 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -256,80 +256,6 @@ def get_class_weights(input_folder, workers=4, **kwargs): return class_weights -def compute_mean_std_graphs(dataset, **kwargs): - """ - Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py). 
- - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - node_features : {"mean": list, "std": list} - Mean and std value of all node features in the input dataset - edge_features : {"mean": list, "std": list} - Mean and std value of all edge features in the input dataset - """ - if dataset.data.x is not None: - logging.info('Begin computing the node feature mean and std') - nodes = _get_feature_mean_std(dataset.data.x) - logging.info('Finished computing the node feature mean and std') - else: - nodes = {} - logging.info('No node features present') - - if dataset.data.edge_attr is not None: - logging.info('Begin computing the edge feature mean and std') - edges = _get_feature_mean_std(dataset.data.edge_attr) - logging.info('Finished computing the edge feature mean and std') - else: - edges = {} - logging.info('No edge features present') - - return nodes, edges - - -def _get_feature_mean_std(torch_array): - array = np.array(torch_array) - return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} - - -def get_class_weights_graphs(dataset, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. - The vector sums up to 1 - - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - class_frequencies = np.array(dataset.config['class_freq'][1]) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights)}') - - return class_weights - - def _get_class_frequencies_weights_segmentation_hisdb(gt_images, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index 264e9772..b9b0b989 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -262,80 +262,6 @@ def get_class_weights(input_folder, workers=4, **kwargs): return class_weights -def compute_mean_std_graphs(dataset, **kwargs): - """ - Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py). 
- - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - node_features : {"mean": list, "std": list} - Mean and std value of all node features in the input dataset - edge_features : {"mean": list, "std": list} - Mean and std value of all edge features in the input dataset - """ - if dataset.data.x is not None: - logging.info('Begin computing the node feature mean and std') - nodes = _get_feature_mean_std(dataset.data.x) - logging.info('Finished computing the node feature mean and std') - else: - nodes = {} - logging.info('No node features present') - - if dataset.data.edge_attr is not None: - logging.info('Begin computing the edge feature mean and std') - edges = _get_feature_mean_std(dataset.data.edge_attr) - logging.info('Finished computing the edge feature mean and std') - else: - edges = {} - logging.info('No edge features present') - - return nodes, edges - - -def _get_feature_mean_std(torch_array): - array = np.array(torch_array) - return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} - - -def get_class_weights_graphs(dataset, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. - The vector sums up to 1 - - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - class_frequencies = np.array(dataset.config['class_freq'][1]) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights)}') - - return class_weights - - def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index ed332a48..f883ebcb 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -1,6 +1,4 @@ # Utils -import errno -import json import logging import os from multiprocessing import Pool @@ -214,80 +212,6 @@ def get_class_weights(input_folder, workers=4, **kwargs): return class_weights -def compute_mean_std_graphs(dataset, **kwargs): - """ - Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py). 
- - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - node_features : {"mean": list, "std": list} - Mean and std value of all node features in the input dataset - edge_features : {"mean": list, "std": list} - Mean and std value of all edge features in the input dataset - """ - if dataset.data.x is not None: - logging.info('Begin computing the node feature mean and std') - nodes = _get_feature_mean_std(dataset.data.x) - logging.info('Finished computing the node feature mean and std') - else: - nodes = {} - logging.info('No node features present') - - if dataset.data.edge_attr is not None: - logging.info('Begin computing the edge feature mean and std') - edges = _get_feature_mean_std(dataset.data.edge_attr) - logging.info('Finished computing the edge feature mean and std') - else: - edges = {} - logging.info('No edge features present') - - return nodes, edges - - -def _get_feature_mean_std(torch_array): - array = np.array(torch_array) - return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} - - -def get_class_weights_graphs(dataset, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. - The vector sums up to 1 - - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - class_frequencies = np.array(dataset.config['class_freq'][1]) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights)}') - - return class_weights - - def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py index f62ebbec..b131466f 100644 --- a/src/datamodules/RotNet/utils/image_analytics.py +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -224,80 +224,6 @@ def get_class_weights(input_folder, workers=4, **kwargs): return class_weights -def compute_mean_std_graphs(dataset, **kwargs): - """ - Computes mean and std of all node and edge features present in the given ParsedGxlDataset (see gxl_parser.py). 
- - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - node_features : {"mean": list, "std": list} - Mean and std value of all node features in the input dataset - edge_features : {"mean": list, "std": list} - Mean and std value of all edge features in the input dataset - """ - if dataset.data.x is not None: - logging.info('Begin computing the node feature mean and std') - nodes = _get_feature_mean_std(dataset.data.x) - logging.info('Finished computing the node feature mean and std') - else: - nodes = {} - logging.info('No node features present') - - if dataset.data.edge_attr is not None: - logging.info('Begin computing the edge feature mean and std') - edges = _get_feature_mean_std(dataset.data.edge_attr) - logging.info('Finished computing the edge feature mean and std') - else: - edges = {} - logging.info('No edge features present') - - return nodes, edges - - -def _get_feature_mean_std(torch_array): - array = np.array(torch_array) - return {'mean': [np.mean(col) for col in array.T], 'std': [np.std(col) for col in array.T]} - - -def get_class_weights_graphs(dataset, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. - The vector sums up to 1 - - Parameters - ---------- - input_folder : ParsedGxlDataset - Dataset object (see above for details) - - # TODO implement online version - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - class_frequencies = np.array(dataset.config['class_freq'][1]) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights)}') - - return class_weights - - def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. 
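Note: the four deleted `get_class_weights_graphs` copies all implemented the same inverse-class-frequency weighting, which survives in the `_get_class_frequencies_weights_segmentation` helpers kept below each deletion. A minimal sketch of that weighting, NumPy only (the frequency vector here is illustrative, not from the repository):

import numpy as np

def inverse_frequency_weights(class_frequencies):
    # Weights are the inverse of the class frequencies,
    # normalized so the vector sums to 1.0 (in case the loss does not do it).
    weights = 1.0 / np.asarray(class_frequencies, dtype=float)
    return weights / weights.sum()

print(inverse_frequency_weights([0.7, 0.2, 0.1]))  # rarest class gets the largest weight
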
From f3b1da6331373807e015dc9ae4f9ca73157dba68 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 11:10:26 +0100 Subject: [PATCH 072/108] :art: check if we are using a wandb logger before using a wandb object --- src/execute.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/execute.py b/src/execute.py index 17661fb4..0e346417 100644 --- a/src/execute.py +++ b/src/execute.py @@ -112,9 +112,11 @@ def execute(config: DictConfig) -> Optional[float]: if trainer.is_global_zero: with open('config.yaml', mode='w') as fp: OmegaConf.save(config=config, f=fp) - run_config_folder_path = Path(wandb.run.dir) / 'run_config' - run_config_folder_path.mkdir(exist_ok=True) - shutil.copyfile('config.yaml', str(run_config_folder_path / 'config.yaml')) + if config.get('logger') is not None and 'wandb' in config.get('logger'): + if '_target_' in config.logger.wandb: + run_config_folder_path = Path(wandb.run.dir) / 'run_config' + run_config_folder_path.mkdir(exist_ok=True) + shutil.copyfile('config.yaml', str(run_config_folder_path / 'config.yaml')) if config.train: # Train the model From 3c952725531a95a3fe3ace6c6267dc81ea9d0077 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 14:57:43 +0100 Subject: [PATCH 073/108] :wrench: introduce flag predict --- configs/experiment/cb55_full_run_unet.yaml | 1 + configs/experiment/cb55_select_train15_unet.yaml | 1 + configs/experiment/cb55_select_train1_val1_unet.yaml | 1 + configs/experiment/dev_rgb_full_page.yaml | 1 + configs/experiment/dev_rolf_format.yaml | 1 + configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml | 1 + .../dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml | 1 + configs/experiment/dev_rotnet_resnet18_cb55_10.yaml | 1 + configs/experiment/dev_rotnet_resnet50_cb55_10.yaml | 1 + configs/experiment/development_baby_unet_cb55_10.yaml | 1 + configs/experiment/development_baby_unet_rgb_data.yaml | 1 + configs/experiment/rotnet_resnet18_cb55_full.yaml | 1 + configs/experiment/rotnet_resnet18_cb55_train10_last.yaml | 1 + configs/experiment/rotnet_resnet18_cb55_train19_last.yaml | 1 + configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml | 1 + .../synthetic_baby_unet_layoutD_gtD_rolf_format.yaml | 1 + configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml | 1 + src/execute.py | 7 +++++++ 18 files changed, 24 insertions(+) diff --git a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index e6cc1894..f42a9817 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ b/configs/experiment/cb55_full_run_unet.yaml @@ -24,6 +24,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/cb55_select_train15_unet.yaml b/configs/experiment/cb55_select_train15_unet.yaml index 400b4b97..550931b1 100644 --- a/configs/experiment/cb55_select_train15_unet.yaml +++ b/configs/experiment/cb55_select_train15_unet.yaml @@ -24,6 +24,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/cb55_select_train1_val1_unet.yaml b/configs/experiment/cb55_select_train1_val1_unet.yaml index fbc81849..407f304e 100644 --- a/configs/experiment/cb55_select_train1_val1_unet.yaml +++ b/configs/experiment/cb55_select_train1_val1_unet.yaml @@ -24,6 +24,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rgb_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml index 
c635992e..8cf2fd1f 100644 --- a/configs/experiment/dev_rgb_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rolf_format.yaml b/configs/experiment/dev_rolf_format.yaml index fe7c5b6b..aa25bb26 100644 --- a/configs/experiment/dev_rolf_format.yaml +++ b/configs/experiment/dev_rolf_format.yaml @@ -28,6 +28,7 @@ seed: 42 train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml index 3d2e9735..55dc405d 100644 --- a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml index 5d64686a..4f565411 100644 --- a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml +++ b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml index 7cc285ca..7ec3295e 100644 --- a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml index 3c65fe6c..5d927551 100644 --- a/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index 5ac3bcdf..b7432b00 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: True +predict: True trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/development_baby_unet_rgb_data.yaml b/configs/experiment/development_baby_unet_rgb_data.yaml index 64c21f4d..cf62491c 100644 --- a/configs/experiment/development_baby_unet_rgb_data.yaml +++ b/configs/experiment/development_baby_unet_rgb_data.yaml @@ -27,6 +27,7 @@ seed: 42 train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/rotnet_resnet18_cb55_full.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml index 7990c297..fe4ffca3 100644 --- a/configs/experiment/rotnet_resnet18_cb55_full.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_full.yaml @@ -25,6 +25,7 @@ defaults: train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml index d997230b..92b988f1 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml +++ 
b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml @@ -25,6 +25,7 @@ defaults: train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml index 86c0d443..ac9b1547 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml @@ -25,6 +25,7 @@ defaults: train: True test: False +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml index efa7f52f..85fb90aa 100644 --- a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml @@ -25,6 +25,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml index 52833445..d83bd2a1 100644 --- a/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml @@ -26,6 +26,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml index e9869db1..938e809f 100644 --- a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml @@ -25,6 +25,7 @@ defaults: train: True test: True +predict: False trainer: _target_: pytorch_lightning.Trainer diff --git a/src/execute.py b/src/execute.py index 0e346417..5bd04fdb 100644 --- a/src/execute.py +++ b/src/execute.py @@ -129,6 +129,10 @@ def execute(config: DictConfig) -> Optional[float]: results = trainer.test(model=task, datamodule=datamodule) log.info(f'Test output: {results}') + if config.predict: + log.info("Starting prediction!") + trainer.predict(model=task, datamodule=datamodule) + # Make sure everything closed properly log.info("Finalizing!") utils.finish( @@ -181,6 +185,9 @@ def _load_model_part(config: DictConfig, part_name: str): if config.test and not config.train: log.warn(f"You are just testing without a trained {part_name} model! " "Use 'path_to_weights' in your model to load a trained model") + if config.predict and not config.train: + log.warn(f"You are just predicting without a trained {part_name} model! 
" + "Use 'path_to_weights' in your model to load a trained model") part: LightningModule = hydra.utils.instantiate(config.model.get(part_name)) return part From 52650457f17f54588be52e0948e96f6a5d1ee4d3 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 15:48:18 +0100 Subject: [PATCH 074/108] :sparkles: dataset done --- src/datamodules/utils/predict_dataset.py | 81 +++++++++++++++++++ tests/datamodules/util/__init__.py | 0 .../datamodules/util/test_predict_dataset.py | 31 +++++++ 3 files changed, 112 insertions(+) create mode 100644 src/datamodules/utils/predict_dataset.py create mode 100644 tests/datamodules/util/__init__.py create mode 100644 tests/datamodules/util/test_predict_dataset.py diff --git a/src/datamodules/utils/predict_dataset.py b/src/datamodules/utils/predict_dataset.py new file mode 100644 index 00000000..44bc34af --- /dev/null +++ b/src/datamodules/utils/predict_dataset.py @@ -0,0 +1,81 @@ +from typing import List + +import torch.utils.data as data +from torch import is_tensor +from torchvision.datasets.folder import pil_loader +from torchvision.transforms import ToTensor + +from src.utils import utils + +log = utils.get_logger(__name__) +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] + + +class PredictDataset(data.Dataset): + + def __init__(self, image_path_list: List[str], + image_transform=None, target_transform=None, twin_transform=None, + classes=None, **kwargs): + """ + Parameters + ---------- + classes : + image_transform : callable + target_transform : callable + twin_transform : callable + """ + + self.image_path_list = image_path_list + + # Init list + self.classes = classes + # self.crops_per_image = crops_per_image + + # transformations + self.image_transform = image_transform + self.target_transform = target_transform + self.twin_transform = twin_transform + + self.num_samples = len(self.image_path_list) + if self.num_samples == 0: + raise RuntimeError(f'List of image paths is empty!') + + def __len__(self): + """ + This function returns the length of an epoch so the data loader knows when to stop. + The length is different during train/val and test, because we process the whole image during testing, + and only sample from the images during train/val. + """ + return self.num_samples + + def __getitem__(self, index): + data_img = self._load_data_and_gt(index=index) + data_tensor = self._apply_transformation(img=data_img) + return data_tensor + + def _load_data_and_gt(self, index): + data_img = pil_loader(self.image_path_list[index]) + return data_img + + def _apply_transformation(self, img): + """ + Applies the transformations that have been defined in the setup (setup.py). If no transformations + have been defined, the PIL image is returned instead. 
+ + Parameters + ---------- + img: PIL image + image data + Returns + ------- + tuple + img and gt after transformations + """ + if self.image_transform is not None: + # perform transformations + img, _ = self.image_transform(img, None) + + if not is_tensor(img): + img = ToTensor()(img) + + return img diff --git a/tests/datamodules/util/__init__.py b/tests/datamodules/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datamodules/util/test_predict_dataset.py b/tests/datamodules/util/test_predict_dataset.py new file mode 100644 index 00000000..16118978 --- /dev/null +++ b/tests/datamodules/util/test_predict_dataset.py @@ -0,0 +1,31 @@ +import pytest +import torch +from torchvision.transforms import ToTensor + +from src.datamodules.utils.predict_dataset import PredictDataset +from tests.test_data.dummy_data_hisdb.dummy_data import data_dir + + +@pytest.fixture +def file_path_list(data_dir): + test_data_path = data_dir / 'test' / 'data' + return list(test_data_path.iterdir()) + + +@pytest.fixture +def predict_dataset(file_path_list): + return PredictDataset(image_path_list=file_path_list) + + +def test__load_data_and_gt(predict_dataset): + img = predict_dataset._load_data_and_gt(index=0) + assert img.size == (487, 649) + assert img.mode == 'RGB' + assert ToTensor()(img) == predict_dataset[0] + + +def test__apply_transformation(predict_dataset): + img = predict_dataset._load_data_and_gt(index=0) + img_tensor = predict_dataset._apply_transformation(img) + assert torch.equal(img_tensor, predict_dataset[0]) + assert img_tensor.shape == torch.Size((3, 649, 487)) From 55cf724e4a927956e7aa8f43f886cf10c79062dd Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Fri, 19 Nov 2021 15:56:58 +0100 Subject: [PATCH 075/108] :recycle: renamed the dataset --- .../utils/{predict_dataset.py => dataset_predict.py} | 2 +- tests/datamodules/util/test_predict_dataset.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/datamodules/utils/{predict_dataset.py => dataset_predict.py} (98%) diff --git a/src/datamodules/utils/predict_dataset.py b/src/datamodules/utils/dataset_predict.py similarity index 98% rename from src/datamodules/utils/predict_dataset.py rename to src/datamodules/utils/dataset_predict.py index 44bc34af..8fc282de 100644 --- a/src/datamodules/utils/predict_dataset.py +++ b/src/datamodules/utils/dataset_predict.py @@ -11,7 +11,7 @@ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] -class PredictDataset(data.Dataset): +class DatasetPredict(data.Dataset): def __init__(self, image_path_list: List[str], image_transform=None, target_transform=None, twin_transform=None, diff --git a/tests/datamodules/util/test_predict_dataset.py b/tests/datamodules/util/test_predict_dataset.py index 16118978..d8f5b159 100644 --- a/tests/datamodules/util/test_predict_dataset.py +++ b/tests/datamodules/util/test_predict_dataset.py @@ -2,7 +2,7 @@ import torch from torchvision.transforms import ToTensor -from src.datamodules.utils.predict_dataset import PredictDataset +from src.datamodules.utils.dataset_predict import DatasetPredict from tests.test_data.dummy_data_hisdb.dummy_data import data_dir @@ -14,14 +14,14 @@ def file_path_list(data_dir): @pytest.fixture def predict_dataset(file_path_list): - return PredictDataset(image_path_list=file_path_list) + return DatasetPredict(image_path_list=file_path_list) def test__load_data_and_gt(predict_dataset): img = predict_dataset._load_data_and_gt(index=0) assert img.size == (487, 649) assert img.mode == 'RGB' 
- assert ToTensor()(img) == predict_dataset[0] + assert torch.equal(ToTensor()(img), predict_dataset[0]) def test__apply_transformation(predict_dataset): From f56582bc8bab1093b13c1533762654fb8518c338 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 10:33:35 +0100 Subject: [PATCH 076/108] :sparkles: added predict_step in base and rgb full page segmentation --- .../RGB/semantic_segmentation_full_page.py | 29 ++++++++++++++++++- src/tasks/base_task.py | 4 +++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index d8bc2214..adac15d9 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Callable, Union +from typing import Optional, Callable, Union, Any import numpy as np import torch.nn as nn @@ -114,3 +114,30 @@ def test_step(self, batch, batch_idx, **kwargs): def on_test_end(self) -> None: pass + + ############################################################################################# + ######################################### PREDICT ########################################### + ############################################################################################# + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: + input_batch, input_idx = batch + output = super().predict_step(batch=input_batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) + + if not hasattr(self.trainer.datamodule, 'get_img_name'): + raise NotImplementedError('Datamodule does not provide detailed information of the crop') + + for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), + input_idx.detach().cpu().numpy()): + patch_info = self.trainer.datamodule.get_img_name(idx) + img_name = patch_info[0] + dest_folder = self.test_output_path / 'preds_raw' + dest_folder.mkdir(parents=True, exist_ok=True) + dest_filename = dest_folder / f'{img_name}.npy' + np.save(file=str(dest_filename), arr=pred_raw) + + dest_folder = self.test_output_path / 'preds' + dest_folder.mkdir(parents=True, exist_ok=True) + save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, + output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) + + return reduce_dict(input_dict=output, key_list=[]) diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py index 976f20eb..595361dc 100644 --- a/src/tasks/base_task.py +++ b/src/tasks/base_task.py @@ -231,6 +231,10 @@ def test_epoch_end(self, outputs: Any) -> None: self.metric_conf_mat_test.reset() + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: + y_hat = self(batch) + return {OutputKeys.PREDICTION: y_hat} + def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: optimizer = self.optimizer if not isinstance(self.optimizer, Optimizer): From 8a56050ca418908d711f1e381dad0d4164bf60d7 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 10:36:20 +0100 Subject: [PATCH 077/108] :books: added testing badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3771ecb7..e7bcd591 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ Lightning Config: Hydra Template
+![tests](https://github.com/DIVA-DIA/unsupervised_learning/actions/workflows/ci-testing.yml/badge.svg) [comment]: <> ([![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)) From e65a227dd523fb29cb92eebea76cf694065474c4 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 12:38:26 +0100 Subject: [PATCH 078/108] :sparkles: :wrench: added prediction. It takes a list of files to predict from --- .../rolf_format_dev_prediction.yaml | 164 ++++++++++++++++++ .../dev_rolf_format_prediction.yaml | 64 +++++++ src/datamodules/RolfFormat/datamodule.py | 39 ++++- src/datamodules/utils/dataset_predict.py | 4 +- .../RGB/semantic_segmentation_full_page.py | 7 +- src/tasks/base_task.py | 3 + 6 files changed, 271 insertions(+), 10 deletions(-) create mode 100644 configs/datamodule/rolf_format_dev_prediction.yaml create mode 100644 configs/experiment/dev_rolf_format_prediction.yaml diff --git a/configs/datamodule/rolf_format_dev_prediction.yaml b/configs/datamodule/rolf_format_dev_prediction.yaml new file mode 100644 index 00000000..24fcf872 --- /dev/null +++ b/configs/datamodule/rolf_format_dev_prediction.yaml @@ -0,0 +1,164 @@ +_target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat + +num_workers: 4 +batch_size: 8 +shuffle: True +drop_last: False + +data_root: /netscratch/datasets/semantic_segmentation/rolf_format + +train_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1000 + range_to: 1059 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1000 + range_to: 1059 + +val_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1060 + range_to: 1079 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1060 + range_to: 1079 + +test_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1080 + range_to: 1099 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1080 + range_to: 1099 + +pred_file_path_list: + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1061.jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1062.jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1063.jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1064.jpg" + +image_dims: + width: 640 + height: 896 + +image_analytics: + mean: + R: 0.8664800196201524 + G: 0.7408864118075618 + B: 0.6299955083595935 + std: + R: 0.2156624188591712 + G: 0.20890185198454636 + B: 0.1870731300038113 + +classes: + class0: + color: + R: 0 + G: 0 + B: 0 + weight: 0.0016602289391364547 + class1: + color: + R: 0 + G: 102 + B: 0 + weight: 0.22360020547468618 + class2: + color: + 
R: 0 + G: 102 + B: 102 + weight: 0.014794833923108578 + class3: + color: + R: 0 + G: 153 + B: 153 + weight: 0.05384506923533185 + class4: + color: + R: 0 + G: 255 + B: 0 + weight: 0.1115978481679602 + class5: + color: + R: 0 + G: 255 + B: 255 + weight: 0.037436533973406926 + class6: + color: + R: 102 + G: 0 + B: 0 + weight: 0.12569866772812885 + class7: + color: + R: 102 + G: 0 + B: 102 + weight: 0.03591164457353043 + class8: + color: + R: 102 + G: 102 + B: 0 + weight: 0.01062086078798502 + class9: + color: + R: 153 + G: 0 + B: 153 + weight: 0.1491578366712268 + class10: + color: + R: 153 + G: 153 + B: 0 + weight: 0.0414074692141804 + class11: + color: + R: 255 + G: 0 + B: 0 + weight: 0.08600602291055298 + class12: + color: + R: 255 + G: 0 + B: 255 + weight: 0.08349157426652898 + class13: + color: + R: 255 + G: 255 + B: 0 + weight: 0.024771204134236315 + + diff --git a/configs/experiment/dev_rolf_format_prediction.yaml b/configs/experiment/dev_rolf_format_prediction.yaml new file mode 100644 index 00000000..37654f42 --- /dev/null +++ b/configs/experiment/dev_rolf_format_prediction.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py +experiment=exp_example_full + +defaults: + - /plugins: default.yaml + - /task: semantic_segmentation_RGB_full_page.yaml + - /datamodule: rolf_format_dev_prediction.yaml + - /loss: crossentropyloss.yaml + - /metric: iou.yaml + - /model/backbone: baby_unet_model.yaml + - /model/header: identity.yaml + - /optimizer: adam.yaml + - /callbacks: + - check_compatibility.yaml + - model_checkpoint.yaml + - watch_model_wandb.yaml + - /logger: + - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) + - csv.yaml + +# we override default configurations with nulls to prevent them from loading at all +# instead we define all modules and their paths directly in this config, +# so everything is stored in one place for more readibility + +seed: 42 + +train: False +test: False +predict: True + +model: + backbone: + path_to_weights: /netscratch/experiments_lars_paul/lars/2021-11-22/11-52-43/checkpoints/epoch=1/backbone.pth + +trainer: + _target_: pytorch_lightning.Trainer + gpus: -1 + accelerator: 'ddp' + min_epochs: 1 + max_epochs: 2 + weights_summary: full + precision: 16 + +task: + confusion_matrix_log_every_n_epoch: 1 + confusion_matrix_val: True + confusion_matrix_test: True + +callbacks: + model_checkpoint: + monitor: "val/iou" + mode: "max" + filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data + watch_model: + log_freq: 1 + +logger: + wandb: + name: 'dev-rolf-format' + tags: [ "best_model", "synthetic", "rolf_format" ] + group: 'dev-runs' + notes: "Testing" diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 200fa617..8f09e82d 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Union, List, Optional import torch @@ -9,6 +10,7 @@ from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.utils.dataset_predict import DatasetPredict from src.utils import utils log = utils.get_logger(__name__) @@ -16,15 +18,20 @@ class DataModuleRolfFormat(AbstractDatamodule): def __init__(self, data_root: str, - train_specs=None, val_specs=None, test_specs=None, + 
train_specs=None, val_specs=None, test_specs=None, pred_file_path_list: List[str] = None, image_analytics=None, classes=None, image_dims=None, num_workers: int = 4, batch_size: int = 8, shuffle: bool = True, drop_last: bool = True): super().__init__() - self.train_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in train_specs.items()] - self.val_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in val_specs.items()] - self.test_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in test_specs.items()] + if train_specs is not None: + self.train_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in train_specs.items()] + if val_specs is not None: + self.val_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in val_specs.items()] + if test_specs is not None: + self.test_dataset_specs = [DatasetSpecs(data_root=data_root, **v) for k, v in test_specs.items()] + if pred_file_path_list is not None: + self.pred_file_path_list = pred_file_path_list if image_analytics is None or classes is None or image_dims is None: train_paths_data_gt = DatasetRolfFormat.get_gt_data_paths(list_specs=self.train_dataset_specs) @@ -144,6 +151,10 @@ def setup(self, stage: Optional[str] = None): **common_kwargs) # self._check_min_num_samples(num_samples=len(self.test), data_split='test', drop_last=False) + if stage == 'predict': + self.predict = DatasetPredict(image_path_list=self.pred_file_path_list, + **common_kwargs) + def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): num_processes = self.trainer.num_processes batch_size = self.batch_size @@ -186,6 +197,14 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] drop_last=False, pin_memory=True) + def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + return DataLoader(self.predict, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + drop_last=False, + pin_memory=True) + def get_img_name(self, index): """ Returns the original filename of the doc image. @@ -197,3 +216,15 @@ def get_img_name(self, index): raise Exception('This method can just be called during testing') return self.test.img_paths_per_page[index][2:] + + def get_img_name_prediction(self, index): + """ + Returns the original filename of the doc image. + You can just use this during testing! 
+ :param index: + :return: + """ + if not hasattr(self, 'predict'): + raise Exception('This method can just be called during prediction') + + return Path(self.predict.image_path_list[index]).stem diff --git a/src/datamodules/utils/dataset_predict.py b/src/datamodules/utils/dataset_predict.py index 8fc282de..ad4a02a6 100644 --- a/src/datamodules/utils/dataset_predict.py +++ b/src/datamodules/utils/dataset_predict.py @@ -25,7 +25,7 @@ def __init__(self, image_path_list: List[str], twin_transform : callable """ - self.image_path_list = image_path_list + self.image_path_list = list(image_path_list) # Init list self.classes = classes @@ -51,7 +51,7 @@ def __len__(self): def __getitem__(self, index): data_img = self._load_data_and_gt(index=index) data_tensor = self._apply_transformation(img=data_img) - return data_tensor + return data_tensor, index def _load_data_and_gt(self, index): data_img = pil_loader(self.image_path_list[index]) diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index adac15d9..f4f1f26c 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -128,14 +128,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), input_idx.detach().cpu().numpy()): - patch_info = self.trainer.datamodule.get_img_name(idx) - img_name = patch_info[0] - dest_folder = self.test_output_path / 'preds_raw' + img_name = self.trainer.datamodule.get_img_name_prediction(idx) + dest_folder = self.test_output_path / 'prediction_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' np.save(file=str(dest_filename), arr=pred_raw) - dest_folder = self.test_output_path / 'preds' + dest_folder = self.test_output_path / 'prediction' dest_folder.mkdir(parents=True, exist_ok=True) save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py index 595361dc..cbb27dbb 100644 --- a/src/tasks/base_task.py +++ b/src/tasks/base_task.py @@ -105,6 +105,9 @@ def setup(self, stage: str): elif stage == 'test': num_samples = len(self.trainer.datamodule.test) datasplit_name = 'test' + elif stage == 'predict': + num_samples = len(self.trainer.datamodule.predict) + datasplit_name = 'predict' else: log.warn(f'Unknown stage ({stage}) during setup!') num_samples = -1 From e235e0a81efac4204a83d58575491a76f1f17e60 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 15:09:31 +0100 Subject: [PATCH 079/108] :recycle: first little clean up. 
removed pil_loader and used default one as well as file extension checker --- .../DivaHisDB/datamodule_cropped.py | 5 +- .../DivaHisDB/datasets/cropped_dataset.py | 19 +-- src/datamodules/DivaHisDB/utils/misc.py | 45 ------ src/datamodules/RGB/datamodule_cropped.py | 6 +- src/datamodules/RGB/datamodule_full_page.py | 2 +- .../RGB/datasets/cropped_dataset.py | 18 +-- .../RGB/datasets/full_page_dataset.py | 13 +- src/datamodules/RGB/utils/image_analytics.py | 3 +- src/datamodules/RGB/utils/misc.py | 40 ----- .../RGB/utils/single_transforms.py | 144 ------------------ src/datamodules/RGB/utils/twin_transforms.py | 100 ------------ .../RGB/utils/wrapper_transforms.py | 37 ----- .../RolfFormat/datasets/dataset.py | 9 +- .../RolfFormat/utils/functional.py | 63 -------- .../RolfFormat/utils/image_analytics.py | 2 +- src/datamodules/RolfFormat/utils/misc.py | 49 ------ .../RolfFormat/utils/twin_transforms.py | 115 -------------- .../RolfFormat/utils/wrapper_transforms.py | 37 ----- src/datamodules/RotNet/datamodule_cropped.py | 2 +- .../RotNet/datasets/cropped_dataset.py | 11 +- .../RotNet/utils/image_analytics.py | 37 ----- src/datamodules/RotNet/utils/misc.py | 40 ----- .../RotNet/utils/wrapper_transforms.py | 37 ----- src/datamodules/utils/dataset_predict.py | 1 - .../utils/single_transforms.py | 0 .../{DivaHisDB => }/utils/twin_transforms.py | 0 .../utils/wrapper_transforms.py | 0 .../datamodules/RGB/test_full_page_dataset.py | 2 - .../RotNet/datasets/test_cropped_dataset.py | 1 - .../datamodules/util/test_predict_dataset.py | 4 +- tools/generate_cropped_dataset.py | 56 +------ 31 files changed, 49 insertions(+), 849 deletions(-) delete mode 100644 src/datamodules/RGB/utils/single_transforms.py delete mode 100644 src/datamodules/RGB/utils/wrapper_transforms.py delete mode 100644 src/datamodules/RolfFormat/utils/functional.py delete mode 100644 src/datamodules/RolfFormat/utils/misc.py delete mode 100644 src/datamodules/RolfFormat/utils/twin_transforms.py delete mode 100644 src/datamodules/RolfFormat/utils/wrapper_transforms.py delete mode 100644 src/datamodules/RotNet/utils/wrapper_transforms.py rename src/datamodules/{DivaHisDB => }/utils/single_transforms.py (100%) rename src/datamodules/{DivaHisDB => }/utils/twin_transforms.py (100%) rename src/datamodules/{DivaHisDB => }/utils/wrapper_transforms.py (100%) diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 60600289..e3c52a23 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -8,9 +8,8 @@ from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation -from src.datamodules.DivaHisDB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ - IntegerEncoding -from src.datamodules.DivaHisDB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.utils.twin_transforms import TwinRandomCrop, IntegerEncoding +from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils log = utils.get_logger(__name__) diff --git a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py index 59ea9c37..b5f3f5fb 100644 --- a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py +++ 
b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py @@ -10,15 +10,16 @@ import torch.utils.data as data from omegaconf import ListConfig from torch import is_tensor +from torchvision.datasets.folder import pil_loader, has_file_allowed_extension from torchvision.transforms import ToTensor -from src.datamodules.DivaHisDB.utils.misc import has_extension, pil_loader from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm') log = utils.get_logger(__name__) + class CroppedHisDBDataset(data.Dataset): """A generic data loader where the images are arranged in this way: :: @@ -150,7 +151,7 @@ def _apply_transformation(self, img, gt): @staticmethod def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ - -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: + -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: """ Structure of the folder @@ -210,7 +211,7 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st for path_data_subdir in subitems: if not path_data_subdir.is_dir(): - if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_subdir.name, IMG_EXTENSIONS): log.warning("image file found in data root: " + str(path_data_subdir)) continue @@ -230,12 +231,12 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st for path_data_file, path_gt_file in zip(sorted(path_data_subdir.iterdir()), sorted(path_gt_subdir.iterdir())): - assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ - has_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned with non-image file' + assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ + 'get_gt_data_paths(): image file aligned with non-image file' - if has_extension(path_data_file.name, IMG_EXTENSIONS) and has_extension(path_gt_file.name, - IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) and \ + has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS): assert path_data_file.stem == path_gt_file.stem, \ 'get_gt_data_paths(): mismatch between data filename and gt filename' coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') diff --git a/src/datamodules/DivaHisDB/utils/misc.py b/src/datamodules/DivaHisDB/utils/misc.py index c0de22d4..622a001f 100644 --- a/src/datamodules/DivaHisDB/utils/misc.py +++ b/src/datamodules/DivaHisDB/utils/misc.py @@ -1,52 +1,7 @@ -""" -General purpose utility functions. 
- -""" - from pathlib import Path -# Utils -import numpy as np -from PIL import Image - from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir -try: - import accimage -except ImportError: - accimage = None - - -def has_extension(filename, extensions): - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - -def pil_loader(path, to_rgb=True): - pic = Image.open(path) - if to_rgb: - pic = convert_to_rgb(pic) - return pic - - -def convert_to_rgb(pic): - if pic.mode == "RGB": - pass - elif pic.mode in ("CMYK", "RGBA", "P"): - pic = pic.convert('RGB') - elif pic.mode == "I": - img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "I;16": - img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "L": - img = np.array(pic).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - else: - raise TypeError(f"unsupported image type {pic.mode}") - return pic - def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): if data_dir is None: diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 01731b74..f05a2825 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -8,10 +8,10 @@ from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import TwinRandomCrop, OneHotEncoding, OneHotToPixelLabelling, \ - IntegerEncoding -from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget +from src.datamodules.RGB.utils.twin_transforms import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.utils.twin_transforms import TwinRandomCrop +from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils log = utils.get_logger(__name__) diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 437439d7..607742e6 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -9,8 +9,8 @@ from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation from src.datamodules.RGB.utils.twin_transforms import IntegerEncoding -from src.datamodules.RGB.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils log = utils.get_logger(__name__) diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py index b31a2aae..525e1851 100644 --- a/src/datamodules/RGB/datasets/cropped_dataset.py +++ b/src/datamodules/RGB/datasets/cropped_dataset.py @@ -10,12 +10,12 @@ import torch.utils.data as data from omegaconf import ListConfig from torch import is_tensor +from torchvision.datasets.folder import pil_loader, has_file_allowed_extension from torchvision.transforms import ToTensor -from 
src.datamodules.RGB.utils.misc import has_extension, pil_loader from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') log = utils.get_logger(__name__) @@ -150,7 +150,7 @@ def _apply_transformation(self, img, gt): @staticmethod def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ - -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: + -> List[Tuple[Path, Path, str, str, Tuple[int, int]]]: """ Structure of the folder @@ -211,7 +211,7 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st for path_data_subdir in subitems: if not path_data_subdir.is_dir(): - if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_subdir.name, IMG_EXTENSIONS): log.warning("image file found in data root: " + str(path_data_subdir)) continue @@ -231,12 +231,12 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st for path_data_file, path_gt_file in zip(sorted(path_data_subdir.iterdir()), sorted(path_gt_subdir.iterdir())): - assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ - has_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned with non-image file' + assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ + 'get_gt_data_paths(): image file aligned with non-image file' - if has_extension(path_data_file.name, IMG_EXTENSIONS) and has_extension(path_gt_file.name, - IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) and \ + has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS): assert path_data_file.stem == path_gt_file.stem, \ 'get_gt_data_paths(): mismatch between data filename and gt filename' coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 7e69b7ab..b9f8293b 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -2,21 +2,20 @@ Load a dataset of historic documents by specifying the folder where its located. 
""" +from dataclasses import dataclass # Utils -import re from pathlib import Path -from typing import List, Tuple, Union, Optional, Any +from typing import List, Tuple, Union, Optional import torch.utils.data as data -from dataclasses import dataclass from omegaconf import ListConfig from torch import is_tensor +from torchvision.datasets.folder import pil_loader, has_file_allowed_extension from torchvision.transforms import ToTensor -from src.datamodules.RGB.utils.misc import has_extension, pil_loader from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') log = utils.get_logger(__name__) @@ -234,8 +233,8 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st if path_data_file.name not in selection: continue - assert has_extension(path_data_file.name, IMG_EXTENSIONS) == \ - has_extension(path_gt_file.name, IMG_EXTENSIONS), \ + assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ + has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ 'get_gt_data_paths(): image file aligned with non-image file' assert path_data_file.stem == path_gt_file.stem, \ diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index b9b0b989..3bffa8a5 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -13,8 +13,7 @@ import torchvision.datasets as datasets import torchvision.transforms as transforms from PIL import Image - -from src.datamodules.RGB.utils.misc import pil_loader +from torchvision.datasets.folder import pil_loader def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): diff --git a/src/datamodules/RGB/utils/misc.py b/src/datamodules/RGB/utils/misc.py index c0de22d4..3fb47da7 100644 --- a/src/datamodules/RGB/utils/misc.py +++ b/src/datamodules/RGB/utils/misc.py @@ -5,48 +5,8 @@ from pathlib import Path -# Utils -import numpy as np -from PIL import Image - from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir -try: - import accimage -except ImportError: - accimage = None - - -def has_extension(filename, extensions): - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - -def pil_loader(path, to_rgb=True): - pic = Image.open(path) - if to_rgb: - pic = convert_to_rgb(pic) - return pic - - -def convert_to_rgb(pic): - if pic.mode == "RGB": - pass - elif pic.mode in ("CMYK", "RGBA", "P"): - pic = pic.convert('RGB') - elif pic.mode == "I": - img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "I;16": - img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "L": - img = np.array(pic).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - else: - raise TypeError(f"unsupported image type {pic.mode}") - return pic - def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): if data_dir is None: diff --git a/src/datamodules/RGB/utils/single_transforms.py b/src/datamodules/RGB/utils/single_transforms.py deleted file mode 100644 index dc48230b..00000000 --- a/src/datamodules/RGB/utils/single_transforms.py 
+++ /dev/null @@ -1,144 +0,0 @@ -import math -import random - -import torch -from PIL import Image -from torchvision.transforms import Pad - - -class ResizePad(object): - """ - Perform resizing keeping the aspect ratio of the image --padding type: continuous (black). - Expects PIL image and int value as target_size - (It can be extended to perform other transforms on both PIL image and object boxes.) - - Example: - target_size = 200 - # im: numpy array - img = Image.fromarray(im.astype('uint8'), 'RGB') - img = ResizePad(target_size)(img) - """ - - def __init__(self, target_size): - self.target_size = target_size - self.boxes = torch.Tensor([[0, 0, 0, 0]]) - - def resize(self, img, boxes, size, max_size=1000): - '''Resize the input PIL image to the given size. - Args: - img: (PIL.Image) image to be resized. - boxes: (tensor) object boxes, sized [#ojb,4]. - size: (tuple or int) - - if is tuple, resize image to the size. - - if is int, resize the shorter side to the size while maintaining the aspect ratio. - max_size: (int) when size is int, limit the image longer size to max_size. - This is essential to limit the usage of GPU memory. - Returns: - img: (PIL.Image) resized image. - boxes: (tensor) resized boxes. - ''' - w, h = img.size - if isinstance(size, int): - size_min = min(w, h) - size_max = max(w, h) - sw = sh = float(size) / size_min - if sw * size_max > max_size: - sw = sh = float(max_size) / size_max - ow = int(w * sw + 0.5) - oh = int(h * sh + 0.5) - else: - ow, oh = size - sw = float(ow) / w - sh = float(oh) / h - return img.resize((ow, oh), Image.BILINEAR), \ - boxes * torch.Tensor([sw, sh, sw, sh]) - - def random_crop(self, img, boxes): - '''Crop the given PIL image to a random size and aspect ratio. - A crop of random size of (0.08 to 1.0) of the original size and a random - aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. - Args: - img: (PIL.Image) image to be cropped. - boxes: (tensor) object boxes, sized [#ojb,4]. - Returns: - img: (PIL.Image) randomly cropped image. - boxes: (tensor) randomly cropped boxes. - ''' - success = False - for attempt in range(10): - area = img.size[0] * img.size[1] - target_area = random.uniform(0.56, 1.0) * area - aspect_ratio = random.uniform(3. / 4, 4. / 3) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if random.random() < 0.5: - w, h = h, w - - if w <= img.size[0] and h <= img.size[1]: - x = random.randint(0, img.size[0] - w) - y = random.randint(0, img.size[1] - h) - success = True - break - - # Fallback - if not success: - w = h = min(img.size[0], img.size[1]) - x = (img.size[0] - w) // 2 - y = (img.size[1] - h) // 2 - - img = img.crop((x, y, x + w, y + h)) - boxes -= torch.Tensor([x, y, x, y]) - boxes[:, 0::2].clamp_(min=0, max=w - 1) - boxes[:, 1::2].clamp_(min=0, max=h - 1) - return img, boxes - - def center_crop(self, img, boxes, size): - '''Crops the given PIL Image at the center. - Args: - img: (PIL.Image) image to be cropped. - boxes: (tensor) object boxes, sized [#ojb,4]. - size (tuple): desired output size of (w,h). - Returns: - img: (PIL.Image) center cropped image. - boxes: (tensor) center cropped boxes. 
- ''' - w, h = img.size - ow, oh = size - i = int(round((h - oh) / 2.)) - j = int(round((w - ow) / 2.)) - img = img.crop((j, i, j + ow, i + oh)) - boxes -= torch.Tensor([j, i, j, i]) - boxes[:, 0::2].clamp_(min=0, max=ow - 1) - boxes[:, 1::2].clamp_(min=0, max=oh - 1) - return img, boxes - - def random_flip(self, img, boxes): - '''Randomly flip the given PIL Image. - Args: - img: (PIL Image) image to be flipped. - boxes: (tensor) object boxes, sized [#ojb,4]. - Returns: - img: (PIL.Image) randomly flipped image. - boxes: (tensor) randomly flipped boxes. - ''' - if random.random() < 0.5: - img = img.transpose(Image.FLIP_LEFT_RIGHT) - w = img.width - xmin = w - boxes[:, 2] - xmax = w - boxes[:, 0] - boxes[:, 0] = xmin - boxes[:, 2] = xmax - return img, boxes - - def resize_with_padding(self, img, target_size): - img, boxes = self.resize(img, self.boxes, target_size, max_size=target_size) - padding = (max(0, target_size - img.size[0]) // 2, max(0, target_size - img.size[1]) // 2) - img = Pad(padding)(img) - - return img - - def __call__(self, img): - img = self.resize_with_padding(img, self.target_size) - return img \ No newline at end of file diff --git a/src/datamodules/RGB/utils/twin_transforms.py b/src/datamodules/RGB/utils/twin_transforms.py index 5ba68cad..b0a212c1 100644 --- a/src/datamodules/RGB/utils/twin_transforms.py +++ b/src/datamodules/RGB/utils/twin_transforms.py @@ -1,106 +1,6 @@ -import random - -from torchvision.transforms import functional as F - from src.datamodules.RGB.utils import functional as F_custom -class TwinCompose(object): - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, img, gt): - for t in self.transforms: - img, gt = t(img, gt) - return img, gt - - -class TwinRandomCrop(object): - """Crop the given PIL Images at the same random location""" - - def __init__(self, crop_size): - self.crop_size = crop_size - - def get_params(self, img_size): - """Get parameters for ``crop`` for a random crop""" - w, h = img_size - th = self.crop_size - tw = self.crop_size - - assert w >= tw and h >= th - - if w == tw and h == th: - return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) - return i, j, th, tw - - def __call__(self, img, gt): - i, j, h, w = self.get_params(img.size) - return F.crop(img, i, j, h, w), F.crop(gt, i, j, h, w) - - -class TwinImageToTensor(object): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. - Converts a PIL Image or numpy.ndarray (W x H x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. - """ - - def __call__(self, img, gt): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - Returns: - Tensor: Converted image. - """ - return F.to_tensor(img), F.to_tensor(gt) - - -class ToTensorSlidingWindowCrop(object): - """ - Crop the data and ground truth image at the specified coordinates to the specified size and convert - them to a tensor. - """ - - def __init__(self, crop_size): - self.crop_size = crop_size - - def __call__(self, img, gt, coordinates): - """ - Args: - img (PIL Image): Data image to be cropped and converted to tensor. - gt (PIL Image): Ground truth image to be cropped and converted to tensor. 
- - Returns: - Data tensor, gt tensor (tuple of tensors): cropped and converted images - - """ - x_position = coordinates[0] - y_position = coordinates[1] - - return F.to_tensor(F.crop(img, x_position, y_position, self.crop_size, self.crop_size)), \ - F.to_tensor(F.crop(gt, x_position, y_position, self.crop_size, self.crop_size)) - - -class OneHotToPixelLabelling(object): - def __call__(self, tensor): - return F_custom.argmax_onehot(tensor) - - -class OneHotEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_one_hot(gt, self.class_encodings) - - class IntegerEncoding(object): def __init__(self, class_encodings): self.class_encodings = class_encodings diff --git a/src/datamodules/RGB/utils/wrapper_transforms.py b/src/datamodules/RGB/utils/wrapper_transforms.py deleted file mode 100644 index eaa5e437..00000000 --- a/src/datamodules/RGB/utils/wrapper_transforms.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Callable - - -class OnlyImage(object): - """Wrapper function around a single parameter transform. It will be cast only on image""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. - - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return self.transform(image), target - - -class OnlyTarget(object): - """Wrapper function around a single parameter transform. It will be cast only on target""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. 
- - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return image, self.transform(target) \ No newline at end of file diff --git a/src/datamodules/RolfFormat/datasets/dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py index 6534d90d..8e8866bd 100644 --- a/src/datamodules/RolfFormat/datasets/dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -4,19 +4,18 @@ # Utils import re -from pathlib import Path -from typing import List, Tuple, Union, Optional from dataclasses import asdict, dataclass +from pathlib import Path +from typing import List, Tuple import torch.utils.data as data -from omegaconf import ListConfig from torch import is_tensor +from torchvision.datasets.folder import pil_loader from torchvision.transforms import ToTensor -from src.datamodules.RGB.utils.misc import has_extension, pil_loader from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') log = utils.get_logger(__name__) diff --git a/src/datamodules/RolfFormat/utils/functional.py b/src/datamodules/RolfFormat/utils/functional.py deleted file mode 100644 index b052d201..00000000 --- a/src/datamodules/RolfFormat/utils/functional.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import List - -import torch -from torch.nn.functional import one_hot - - -def gt_to_int_encoding(matrix: torch.Tensor, class_encodings: torch.Tensor): - """ - Convert ground truth tensor or numpy matrix to one-hot encoded matrix - - Parameters - ------- - matrix: float tensor from to_tensor() or numpy array - shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR - class_encodings: List of int - Blue channel values that encode the different classes - Returns - ------- - torch.LongTensor of size [#C x H x W] - sparse one-hot encoded multi-class matrix, where #C is the number of classes - """ - integer_encoded = torch.full(size=matrix[0].shape, fill_value=-1, dtype=torch.long) - for index, encoding in enumerate(class_encodings): - mask = torch.logical_and(torch.logical_and( - torch.where(matrix[0] == encoding[0], True, False), - torch.where(matrix[1] == encoding[1], True, False)), - torch.where(matrix[2] == encoding[2], True, False)) - integer_encoded[mask] = index - - return integer_encoded - - -def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor): - """ - Convert ground truth tensor or numpy matrix to one-hot encoded matrix - - Parameters - ------- - matrix: float tensor from to_tensor() or numpy array - shape (C x H x W) in the range [0.0, 1.0] or shape (H x W x C) BGR - class_encodings: List of int - Blue channel values that encode the different classes - Returns - ------- - torch.LongTensor of size [#C x H x W] - sparse one-hot encoded multi-class matrix, where #C is the number of classes - """ - integer_encoded = gt_to_int_encoding(matrix=matrix, class_encodings=class_encodings) - - num_classes = class_encodings.shape[0] - - onehot_encoded = one_hot(input=integer_encoded, num_classes=num_classes) - onehot_encoded = onehot_encoded.swapaxes(1, 2).swapaxes(0, 1) # changes axis from (0, 1, 2) to (2, 0, 1) - - return onehot_encoded - - -def argmax_onehot(tensor: torch.Tensor): - """ - # TODO - """ - output = torch.LongTensor(torch.argmax(tensor, dim=0)) - return output diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py 
b/src/datamodules/RolfFormat/utils/image_analytics.py index f883ebcb..278c032f 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -11,8 +11,8 @@ import torchvision.datasets as datasets import torchvision.transforms as transforms from PIL import Image +from torchvision.datasets.folder import pil_loader -from src.datamodules.RGB.utils.misc import pil_loader from src.datamodules.RolfFormat.datasets.dataset import ImageDimensions diff --git a/src/datamodules/RolfFormat/utils/misc.py b/src/datamodules/RolfFormat/utils/misc.py deleted file mode 100644 index 1ccdad41..00000000 --- a/src/datamodules/RolfFormat/utils/misc.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -General purpose utility functions. - -""" - -from pathlib import Path - -# Utils -import numpy as np -from PIL import Image - -from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir - -try: - import accimage -except ImportError: - accimage = None - - -def has_extension(filename, extensions): - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - -def pil_loader(path, to_rgb=True): - pic = Image.open(path) - if to_rgb: - pic = convert_to_rgb(pic) - return pic - - -def convert_to_rgb(pic): - if pic.mode == "RGB": - pass - elif pic.mode in ("CMYK", "RGBA", "P"): - pic = pic.convert('RGB') - elif pic.mode == "I": - img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "I;16": - img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "L": - img = np.array(pic).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - else: - raise TypeError(f"unsupported image type {pic.mode}") - return pic - diff --git a/src/datamodules/RolfFormat/utils/twin_transforms.py b/src/datamodules/RolfFormat/utils/twin_transforms.py deleted file mode 100644 index 5ba68cad..00000000 --- a/src/datamodules/RolfFormat/utils/twin_transforms.py +++ /dev/null @@ -1,115 +0,0 @@ -import random - -from torchvision.transforms import functional as F - -from src.datamodules.RGB.utils import functional as F_custom - - -class TwinCompose(object): - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, img, gt): - for t in self.transforms: - img, gt = t(img, gt) - return img, gt - - -class TwinRandomCrop(object): - """Crop the given PIL Images at the same random location""" - - def __init__(self, crop_size): - self.crop_size = crop_size - - def get_params(self, img_size): - """Get parameters for ``crop`` for a random crop""" - w, h = img_size - th = self.crop_size - tw = self.crop_size - - assert w >= tw and h >= th - - if w == tw and h == th: - return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) - return i, j, th, tw - - def __call__(self, img, gt): - i, j, h, w = self.get_params(img.size) - return F.crop(img, i, j, h, w), F.crop(gt, i, j, h, w) - - -class TwinImageToTensor(object): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. - Converts a PIL Image or numpy.ndarray (W x H x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. - """ - - def __call__(self, img, gt): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 
- Returns: - Tensor: Converted image. - """ - return F.to_tensor(img), F.to_tensor(gt) - - -class ToTensorSlidingWindowCrop(object): - """ - Crop the data and ground truth image at the specified coordinates to the specified size and convert - them to a tensor. - """ - - def __init__(self, crop_size): - self.crop_size = crop_size - - def __call__(self, img, gt, coordinates): - """ - Args: - img (PIL Image): Data image to be cropped and converted to tensor. - gt (PIL Image): Ground truth image to be cropped and converted to tensor. - - Returns: - Data tensor, gt tensor (tuple of tensors): cropped and converted images - - """ - x_position = coordinates[0] - y_position = coordinates[1] - - return F.to_tensor(F.crop(img, x_position, y_position, self.crop_size, self.crop_size)), \ - F.to_tensor(F.crop(gt, x_position, y_position, self.crop_size, self.crop_size)) - - -class OneHotToPixelLabelling(object): - def __call__(self, tensor): - return F_custom.argmax_onehot(tensor) - - -class OneHotEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_one_hot(gt, self.class_encodings) - - -class IntegerEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_int_encoding(gt, self.class_encodings) diff --git a/src/datamodules/RolfFormat/utils/wrapper_transforms.py b/src/datamodules/RolfFormat/utils/wrapper_transforms.py deleted file mode 100644 index eaa5e437..00000000 --- a/src/datamodules/RolfFormat/utils/wrapper_transforms.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Callable - - -class OnlyImage(object): - """Wrapper function around a single parameter transform. It will be cast only on image""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. - - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return self.transform(image), target - - -class OnlyTarget(object): - """Wrapper function around a single parameter transform. It will be cast only on target""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. 
- - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return image, self.transform(target) \ No newline at end of file diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index 7365ed2c..b91be885 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -8,7 +8,7 @@ from src.datamodules.RotNet.utils.image_analytics import get_analytics_data from src.datamodules.RotNet.datasets.cropped_dataset import CroppedRotNet, ROTATION_ANGLES from src.datamodules.RotNet.utils.misc import validate_path_for_self_supervised -from src.datamodules.RotNet.utils.wrapper_transforms import OnlyImage +from src.datamodules.utils.wrapper_transforms import OnlyImage from src.datamodules.base_datamodule import AbstractDatamodule from src.utils import utils diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index 55602e00..ead5a7b4 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -6,18 +6,17 @@ from pathlib import Path from typing import List, Union, Optional -import numpy as np -import torch import torchvision.transforms.functional from omegaconf import ListConfig from torch import is_tensor +from torchvision.datasets.folder import has_file_allowed_extension, pil_loader from torchvision.transforms import ToTensor from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset -from src.datamodules.RotNet.utils.misc import has_extension, pil_loader +from src.datamodules.RotNet.utils.misc import has_extension from src.utils import utils -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm') ROTATION_ANGLES = [0, 90, 180, 270] log = utils.get_logger(__name__) @@ -63,7 +62,7 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str = None **kwargs) def __getitem__(self, index): - data_img = self._load_data_and_gt(index=int(index/len(ROTATION_ANGLES))) + data_img = self._load_data_and_gt(index=int(index / len(ROTATION_ANGLES))) img, gt = self._apply_transformation(data_img, index=index) return img, gt @@ -192,7 +191,7 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st continue for path_data_file in sorted(path_data_subdir.iterdir()): - if has_extension(path_data_file.name, IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS): paths.append(path_data_file) return paths diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py index b131466f..fcf5c9b9 100644 --- a/src/datamodules/RotNet/utils/image_analytics.py +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -224,42 +224,5 @@ def get_class_weights(input_folder, workers=4, **kwargs): return class_weights -def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. 
- The vector sums up to 1 - - Parameters - ---------- - gt_images: list of strings - Path to all ground truth images, which contain the pixel-wise label - workers: int - Number of workers to use for the mean/std computation - - Returns - ------- - ndarray[double] of size (num_classes) and ints the classes are represented as - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - total_num_pixels = 0 - label_counter = {} - - for path in gt_images: - img = np.array(Image.open(path))[:, :, 2].flatten() - total_num_pixels += len(img) - for i, j in zip(*np.unique(img, return_counts=True)): - label_counter[i] = label_counter.get(i, 0) + j - - classes = np.array(sorted(label_counter.keys())) - num_samples_per_class = np.array([label_counter[k] for k in classes]) - class_frequencies = (num_samples_per_class / total_num_pixels) - logging.info('Finished computing class frequencies weights') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - return (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()), classes - - if __name__ == '__main__': print(get_analytics_data(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RotNet/utils/misc.py b/src/datamodules/RotNet/utils/misc.py index 51655d2d..9c1a3170 100644 --- a/src/datamodules/RotNet/utils/misc.py +++ b/src/datamodules/RotNet/utils/misc.py @@ -1,53 +1,13 @@ -""" -General purpose utility functions. - -""" - from pathlib import Path -# Utils -import numpy as np -from PIL import Image - from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir -try: - import accimage -except ImportError: - accimage = None - def has_extension(filename, extensions): filename_lower = filename.lower() return any(filename_lower.endswith(ext) for ext in extensions) -def pil_loader(path, to_rgb=True): - pic = Image.open(path) - if to_rgb: - pic = convert_to_rgb(pic) - return pic - - -def convert_to_rgb(pic): - if pic.mode == "RGB": - pass - elif pic.mode in ("CMYK", "RGBA", "P"): - pic = pic.convert('RGB') - elif pic.mode == "I": - img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "I;16": - img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "L": - img = np.array(pic).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - else: - raise TypeError(f"unsupported image type {pic.mode}") - return pic - - def validate_path_for_self_supervised(data_dir, data_folder_name: str): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " diff --git a/src/datamodules/RotNet/utils/wrapper_transforms.py b/src/datamodules/RotNet/utils/wrapper_transforms.py deleted file mode 100644 index eaa5e437..00000000 --- a/src/datamodules/RotNet/utils/wrapper_transforms.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Callable - - -class OnlyImage(object): - """Wrapper function around a single parameter transform. It will be cast only on image""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. 
- - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return self.transform(image), target - - -class OnlyTarget(object): - """Wrapper function around a single parameter transform. It will be cast only on target""" - - def __init__(self, transform: Callable): - """Initialize the transformation with the transformation to be called. - Could be a compose. - - Parameters - ---------- - transform : torchvision.transforms.transforms - Transformation to wrap - """ - self.transform = transform - - def __call__(self, image, target): - return image, self.transform(target) \ No newline at end of file diff --git a/src/datamodules/utils/dataset_predict.py b/src/datamodules/utils/dataset_predict.py index ad4a02a6..edf95e00 100644 --- a/src/datamodules/utils/dataset_predict.py +++ b/src/datamodules/utils/dataset_predict.py @@ -8,7 +8,6 @@ from src.utils import utils log = utils.get_logger(__name__) -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] class DatasetPredict(data.Dataset): diff --git a/src/datamodules/DivaHisDB/utils/single_transforms.py b/src/datamodules/utils/single_transforms.py similarity index 100% rename from src/datamodules/DivaHisDB/utils/single_transforms.py rename to src/datamodules/utils/single_transforms.py diff --git a/src/datamodules/DivaHisDB/utils/twin_transforms.py b/src/datamodules/utils/twin_transforms.py similarity index 100% rename from src/datamodules/DivaHisDB/utils/twin_transforms.py rename to src/datamodules/utils/twin_transforms.py diff --git a/src/datamodules/DivaHisDB/utils/wrapper_transforms.py b/src/datamodules/utils/wrapper_transforms.py similarity index 100% rename from src/datamodules/DivaHisDB/utils/wrapper_transforms.py rename to src/datamodules/utils/wrapper_transforms.py diff --git a/tests/datamodules/RGB/test_full_page_dataset.py b/tests/datamodules/RGB/test_full_page_dataset.py index ed8aa656..bfc97363 100644 --- a/tests/datamodules/RGB/test_full_page_dataset.py +++ b/tests/datamodules/RGB/test_full_page_dataset.py @@ -28,7 +28,5 @@ def test_dataset_rgb(dataset_train): def test__load_data_and_gt(dataset_train): data_img, gt_img = dataset_train._load_data_and_gt(index=0) assert data_img.size == gt_img.size - assert data_img.format == 'JPEG' assert data_img.mode == 'RGB' - assert gt_img.format == 'PNG' assert gt_img.mode == 'RGB' diff --git a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py index c27d3814..5d821f35 100644 --- a/tests/datamodules/RotNet/datasets/test_cropped_dataset.py +++ b/tests/datamodules/RotNet/datasets/test_cropped_dataset.py @@ -23,7 +23,6 @@ def dataset_train(data_dir_cropped): def test__load_data_and_gt(dataset_train): img = dataset_train._load_data_and_gt(0) assert img.size == (300, 300) - assert img.format == 'PNG' assert np.array_equal(np.array(img)[150][150], np.array([97, 72, 32])) diff --git a/tests/datamodules/util/test_predict_dataset.py b/tests/datamodules/util/test_predict_dataset.py index d8f5b159..8406d350 100644 --- a/tests/datamodules/util/test_predict_dataset.py +++ b/tests/datamodules/util/test_predict_dataset.py @@ -21,11 +21,11 @@ def test__load_data_and_gt(predict_dataset): img = predict_dataset._load_data_and_gt(index=0) assert img.size == (487, 649) assert img.mode == 'RGB' - assert torch.equal(ToTensor()(img), predict_dataset[0]) + assert torch.equal(ToTensor()(img), predict_dataset[0][0]) def 
test__apply_transformation(predict_dataset): img = predict_dataset._load_data_and_gt(index=0) img_tensor = predict_dataset._apply_transformation(img) - assert torch.equal(img_tensor, predict_dataset[0]) + assert torch.equal(img_tensor, predict_dataset[0][0]) assert img_tensor.shape == torch.Size((3, 649, 487)) diff --git a/tools/generate_cropped_dataset.py b/tools/generate_cropped_dataset.py index 8d5791df..737e64cf 100644 --- a/tools/generate_cropped_dataset.py +++ b/tools/generate_cropped_dataset.py @@ -10,59 +10,12 @@ from datetime import datetime from pathlib import Path -import numpy as np -from PIL import Image +from torchvision.datasets.folder import has_file_allowed_extension, pil_loader from torchvision.transforms import functional as F -from torchvision.utils import save_image from tqdm import tqdm -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif'] -JPG_EXTENSIONS = ['.jpg', '.jpeg'] - -def has_extension(filename, extensions): - """Checks if a file is an allowed extension. - - Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py. - - Parameters - ---------- - filename : string - path to a file - extensions : list - extensions to match against - Returns - ------- - bool - True if the filename ends with one of given extensions, false otherwise. - """ - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - -def pil_loader(path, to_rgb=True): - pic = Image.open(path) - if to_rgb: - pic = convert_to_rgb(pic) - return pic - - -def convert_to_rgb(pic): - if pic.mode == "RGB": - pass - elif pic.mode in ("CMYK", "RGBA", "P"): - pic = pic.convert('RGB') - elif pic.mode == "I": - img = (np.divide(np.array(pic, np.int32), 2 ** 16 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "I;16": - img = (np.divide(np.array(pic, np.int16), 2 ** 8 - 1) * 255).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - elif pic.mode == "L": - img = np.array(pic).astype(np.uint8) - pic = Image.fromarray(np.stack((img, img, img), axis=2)) - else: - raise TypeError(f"unsupported image type {pic.mode}") - return pic +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') +JPG_EXTENSIONS = ('.jpg', '.jpeg') def get_img_paths_uncropped(directory): @@ -89,7 +42,7 @@ def get_img_paths_uncropped(directory): continue for img_name in sorted(subdir.iterdir()): - if has_extension(str(img_name), IMG_EXTENSIONS): + if has_file_allowed_extension(str(img_name), IMG_EXTENSIONS): paths.append((subdir / img_name, str(subdir.stem))) return paths @@ -255,7 +208,6 @@ def write_crops(self): # save_image(img, dest_filename) pil_img.save(dest_filename) - def _load_image(self, img_index): """ Inits the variables responsible of tracking which crop should be taken next, the current images and the like. 
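The patches up to this point all make the same two swaps: the hand-rolled has_extension and pil_loader helpers give way to torchvision.datasets.folder.has_file_allowed_extension and pil_loader, and IMG_EXTENSIONS changes from a list to a tuple. A minimal sketch of the replacement API follows; the file names in it are made up for illustration.

from torchvision.datasets.folder import has_file_allowed_extension, pil_loader

# IMG_EXTENSIONS must now be a tuple: has_file_allowed_extension delegates
# to str.endswith, which accepts a str or a tuple of suffixes but raises
# TypeError when given a list.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif')

# Matching is case-insensitive; the helper lowercases the filename first.
assert has_file_allowed_extension('page_x0_y0.PNG', IMG_EXTENSIONS)
assert not has_file_allowed_extension('analytics.json', IMG_EXTENSIONS)

# torchvision's pil_loader always returns an RGB image, covering the L, P,
# RGBA and CMYK cases that the removed convert_to_rgb handled mode by mode.
img = pil_loader('page_x0_y0.png')  # hypothetical path
assert img.mode == 'RGB'
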
From 1f2ff1e2af40f23f6e8d281c5950f891ae9a2401 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 15:32:24 +0100 Subject: [PATCH 080/108] :recycle: rename of the output --- src/tasks/RGB/semantic_segmentation_full_page.py | 8 ++++---- src/tasks/base_task.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index f4f1f26c..99ee0fb0 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -100,12 +100,12 @@ def test_step(self, batch, batch_idx, **kwargs): input_idx.detach().cpu().numpy()): patch_info = self.trainer.datamodule.get_img_name(idx) img_name = patch_info[0] - dest_folder = self.test_output_path / 'preds_raw' + dest_folder = self.test_output_path / 'pred_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' np.save(file=str(dest_filename), arr=pred_raw) - dest_folder = self.test_output_path / 'preds' + dest_folder = self.test_output_path / 'pred' dest_folder.mkdir(parents=True, exist_ok=True) save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) @@ -129,12 +129,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), input_idx.detach().cpu().numpy()): img_name = self.trainer.datamodule.get_img_name_prediction(idx) - dest_folder = self.test_output_path / 'prediction_raw' + dest_folder = self.predict_output_path / 'pred_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' np.save(file=str(dest_filename), arr=pred_raw) - dest_folder = self.test_output_path / 'prediction' + dest_folder = self.predict_output_path / 'pred' dest_folder.mkdir(parents=True, exist_ok=True) save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py index cbb27dbb..3173248c 100644 --- a/src/tasks/base_task.py +++ b/src/tasks/base_task.py @@ -60,7 +60,9 @@ def __init__( confusion_matrix_test: Optional[bool] = False, confusion_matrix_log_every_n_epoch: Optional[int] = 1, lr: float = 1e-3, - test_output_path: Optional[Union[str, Path]] = 'predictions' + test_output_path: Optional[Union[str, Path]] = 'test_output', + predict_output_path: Optional[Union[str, Path]] = 'predict_output' + ): super().__init__() @@ -88,6 +90,7 @@ def __init__( self.confusion_matrix_log_every_n_epoch = confusion_matrix_log_every_n_epoch self.lr = lr self.test_output_path = Path(test_output_path) + self.predict_output_path = Path(predict_output_path) self.save_hyperparameters() def setup(self, stage: str): From c4d7653d10f377268777e87991f864307c9a03ab Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 15:38:52 +0100 Subject: [PATCH 081/108] :truck: moved single image transformations --- .../DivaHisDB/datamodule_cropped.py | 3 +- src/datamodules/RGB/datamodule_cropped.py | 2 +- src/datamodules/RGB/datamodule_full_page.py | 2 +- src/datamodules/RGB/utils/twin_transforms.py | 15 -------- src/datamodules/utils/single_transforms.py | 37 ++++++++++++++++++- src/datamodules/utils/twin_transforms.py | 35 ------------------ 6 files changed, 40 insertions(+), 54 deletions(-) 
delete mode 100644 src/datamodules/RGB/utils/twin_transforms.py diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index e3c52a23..79b70245 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -8,7 +8,8 @@ from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation -from src.datamodules.utils.twin_transforms import TwinRandomCrop, IntegerEncoding +from src.datamodules.utils.twin_transforms import TwinRandomCrop +from src.datamodules.utils.single_transforms import IntegerEncoding from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index f05a2825..9be579bc 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -8,9 +8,9 @@ from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule from src.datamodules.utils.twin_transforms import TwinRandomCrop +from src.datamodules.utils.single_transforms import IntegerEncoding from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 607742e6..3a2ff935 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -8,8 +8,8 @@ from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB, ImageDimensions from src.datamodules.RGB.utils.image_analytics import get_analytics from src.datamodules.RGB.utils.misc import validate_path_for_segmentation -from src.datamodules.RGB.utils.twin_transforms import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.utils.single_transforms import IntegerEncoding from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/RGB/utils/twin_transforms.py b/src/datamodules/RGB/utils/twin_transforms.py deleted file mode 100644 index b0a212c1..00000000 --- a/src/datamodules/RGB/utils/twin_transforms.py +++ /dev/null @@ -1,15 +0,0 @@ -from src.datamodules.RGB.utils import functional as F_custom - - -class IntegerEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_int_encoding(gt, self.class_encodings) diff --git a/src/datamodules/utils/single_transforms.py b/src/datamodules/utils/single_transforms.py index dc48230b..0aaa4a8b 100644 --- a/src/datamodules/utils/single_transforms.py +++ b/src/datamodules/utils/single_transforms.py @@ -5,6 +5,8 @@ from PIL import Image from torchvision.transforms import Pad +from src.datamodules.DivaHisDB.utils import functional as F_custom + class ResizePad(object): """ @@ -141,4 +143,37 @@ def resize_with_padding(self, img, target_size): def __call__(self, img): img = 
self.resize_with_padding(img, self.target_size) - return img \ No newline at end of file + return img + + +class OneHotToPixelLabelling(object): + def __call__(self, tensor): + return F_custom.argmax_onehot(tensor) + + +class OneHotEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_one_hot(gt, self.class_encodings) + + +class IntegerEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_int_encoding(gt, self.class_encodings) diff --git a/src/datamodules/utils/twin_transforms.py b/src/datamodules/utils/twin_transforms.py index 9b984508..4346167f 100644 --- a/src/datamodules/utils/twin_transforms.py +++ b/src/datamodules/utils/twin_transforms.py @@ -2,8 +2,6 @@ from torchvision.transforms import functional as F -from src.datamodules.DivaHisDB.utils import functional as F_custom - class TwinCompose(object): def __init__(self, transforms): @@ -80,36 +78,3 @@ def __call__(self, img, gt, coordinates): return F.to_tensor(F.crop(img, x_position, y_position, self.crop_size, self.crop_size)), \ F.to_tensor(F.crop(gt, x_position, y_position, self.crop_size, self.crop_size)) - - -class OneHotToPixelLabelling(object): - def __call__(self, tensor): - return F_custom.argmax_onehot(tensor) - - -class OneHotEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_one_hot(gt, self.class_encodings) - - -class IntegerEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_int_encoding(gt, self.class_encodings) From 53c0f0715ae012d9fc1bfb4b9182fd05b13f1a62 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Mon, 22 Nov 2021 16:20:12 +0100 Subject: [PATCH 082/108] :truck: :fire: :recycle: moved and improved the usage of code --- .../DivaHisDB/datamodule_cropped.py | 4 +- src/datamodules/DivaHisDB/utils/functional.py | 5 - .../DivaHisDB/utils/image_analytics.py | 115 +--------- src/datamodules/DivaHisDB/utils/misc.py | 30 --- .../DivaHisDB/utils/output_tools.py | 52 ----- .../DivaHisDB/utils/single_transform.py | 29 +++ src/datamodules/RGB/datamodule_cropped.py | 4 +- src/datamodules/RGB/datamodule_full_page.py | 4 +- src/datamodules/RGB/utils/functional.py | 8 - src/datamodules/RGB/utils/image_analytics.py | 115 +--------- src/datamodules/RGB/utils/output_tools.py | 53 +---- src/datamodules/RGB/utils/single_transform.py | 29 +++ .../RolfFormat/utils/image_analytics.py | 212 +----------------- .../RolfFormat/utils/output_tools.py | 119 ---------- .../RotNet/datasets/cropped_dataset.py | 3 +- .../RotNet/utils/image_analytics.py | 178 +-------------- src/datamodules/RotNet/utils/misc.py | 5 - src/datamodules/utils/functional.py | 8 + src/datamodules/utils/image_analytics.py | 116 ++++++++++ src/datamodules/{RGB => }/utils/misc.py | 46 +++- src/datamodules/utils/output_tools.py | 36 +++ src/datamodules/utils/single_transforms.py | 29 +-- src/tasks/DivaHisDB/semantic_segmentation.py | 2 +- src/tasks/RGB/semantic_segmentation.py | 2 +- .../RGB/semantic_segmentation_full_page.py | 4 +- tests/datamodules/DivaHisDB/test_misc.py | 2 +- tests/tasks/sem_seg/test_output_tools.py | 6 +- 
tools/merge_cropped_output_HisDB.py | 3 +- tools/merge_cropped_output_RGB.py | 6 +- 29 files changed, 290 insertions(+), 935 deletions(-) delete mode 100644 src/datamodules/DivaHisDB/utils/misc.py create mode 100644 src/datamodules/DivaHisDB/utils/single_transform.py create mode 100644 src/datamodules/RGB/utils/single_transform.py delete mode 100644 src/datamodules/RolfFormat/utils/output_tools.py create mode 100644 src/datamodules/utils/functional.py create mode 100644 src/datamodules/utils/image_analytics.py rename src/datamodules/{RGB => }/utils/misc.py (55%) create mode 100644 src/datamodules/utils/output_tools.py diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 79b70245..ab0979bc 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -4,12 +4,12 @@ from torch.utils.data import DataLoader from torchvision import transforms +from src.datamodules.DivaHisDB.utils.single_transform import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset from src.datamodules.DivaHisDB.utils.image_analytics import get_analytics -from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation +from src.datamodules.utils.misc import validate_path_for_segmentation from src.datamodules.utils.twin_transforms import TwinRandomCrop -from src.datamodules.utils.single_transforms import IntegerEncoding from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/DivaHisDB/utils/functional.py b/src/datamodules/DivaHisDB/utils/functional.py index 33b31346..128c1546 100644 --- a/src/datamodules/DivaHisDB/utils/functional.py +++ b/src/datamodules/DivaHisDB/utils/functional.py @@ -65,8 +65,3 @@ def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]): return torch.LongTensor(one_hot_matrix.transpose((2, 0, 1))) -def argmax_onehot(tensor: torch.Tensor): - """ - # TODO - """ - return torch.LongTensor(torch.argmax(tensor, dim=0)) diff --git a/src/datamodules/DivaHisDB/utils/image_analytics.py b/src/datamodules/DivaHisDB/utils/image_analytics.py index 100e53ab..67caf3ee 100644 --- a/src/datamodules/DivaHisDB/utils/image_analytics.py +++ b/src/datamodules/DivaHisDB/utils/image_analytics.py @@ -3,9 +3,7 @@ import json import logging import os -from multiprocessing import Pool from pathlib import Path -from typing import List import numpy as np # Torch related stuff @@ -14,6 +12,8 @@ import torchvision.transforms as transforms from PIL import Image +from src.datamodules.utils.image_analytics import compute_mean_std + def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): """ @@ -92,117 +92,6 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, return analytics_data, analytics_gt -def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): - """ - Computes mean and std of all images present at target folder. - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - inmem : Boolean - Specifies whether is should be computed i nan online of offline fashion. 
- workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : float - Mean value of all pixels of the images in the input folder - std : float - Standard deviation of all pixels of the images in the input folder - """ - file_names_np = np.array(list(map(str, file_names))) - # Compute mean and std - mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers) - return mean, std - - -def _cms_online(file_names, workers=4): - """ - Computes mean and image_classification deviation in an online fashion. - This is useful when the dataset is too big to be allocated in memory. - - Parameters - ---------- - file_names : List of String - List of file names of the dataset - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : double - std : double - """ - logging.info('Begin computing the mean') - - # Set up a pool of workers - pool = Pool(workers + 1) - - # Online mean - results = pool.map(_return_mean, file_names) - mean_sum = np.sum(np.array(results), axis=0) - - # Divide by number of samples in train set - mean = mean_sum / file_names.size - - logging.info('Finished computing the mean') - logging.info('Begin computing the std') - - # Online image_classification deviation - results = pool.starmap(_return_std, [[item, mean] for item in file_names]) - std_sum = np.sum(np.array([item[0] for item in results]), axis=0) - total_pixel_count = np.sum(np.array([item[1] for item in results])) - std = np.sqrt(std_sum / total_pixel_count) - logging.info('Finished computing the std') - - # Shut down the pool - pool.close() - - return mean, std - - -# Loads an image with OpenCV and returns the channel wise means of the image. -def _return_mean(image_path): - img = np.array(Image.open(image_path).convert('RGB')) - mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0 - return mean - - -# Loads an image with OpenCV and returns the channel wise std of the image. -def _return_std(image_path, mean): - img = np.array(Image.open(image_path).convert('RGB')) / 255.0 - m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]])) - return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0 - - -def _cms_inmem(file_names): - """ - Computes mean and image_classification deviation in an offline fashion. This is possible only when the dataset can - be allocated in memory. - - Parameters - ---------- - file_names: List of String - List of file names of the dataset - Returns - ------- - mean : double - std : double - """ - img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) - - # Load all samples - for i, sample in enumerate(file_names): - img[i] = np.array(Image.open(sample).convert('RGB')) - - mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 - std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 - - return mean, std - - def get_class_weights(input_folder, workers=4, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. 
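The hunk above sets the pattern for the rest of this patch: each datamodule's private copy of compute_mean_std and its _cms_online/_cms_inmem helpers is deleted in favor of the single shared implementation created in src/datamodules/utils/image_analytics.py. A short usage sketch, assuming the shared helper keeps the signature of the copies removed above; the glob path is hypothetical.

from pathlib import Path

from src.datamodules.utils.image_analytics import compute_mean_std

# Hypothetical image list; in the datamodules it comes from the dataset's
# get_gt_data_paths-style functions.
file_names = sorted(Path('train/data').glob('*.png'))

# inmem=False streams the images through a worker pool instead of loading
# the whole dataset into memory at once.
mean, std = compute_mean_std(file_names=file_names, inmem=False, workers=4)

# Channel-wise means and standard deviations in [0, 1], ready for
# torchvision.transforms.Normalize.
print(mean, std)
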
diff --git a/src/datamodules/DivaHisDB/utils/misc.py b/src/datamodules/DivaHisDB/utils/misc.py deleted file mode 100644 index 622a001f..00000000 --- a/src/datamodules/DivaHisDB/utils/misc.py +++ /dev/null @@ -1,30 +0,0 @@ -from pathlib import Path - -from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir - - -def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): - if data_dir is None: - raise PathNone("Please provide the path to root dir of the dataset " - "(folder containing the train/val/test folder)") - else: - split_names = ['train', 'val', 'test'] - type_names = [data_folder_name, gt_folder_name] - - data_folder = Path(data_dir) - if not data_folder.is_dir(): - raise PathNotDir("Please provide the path to root dir of the dataset " - "(folder containing the train/val/test folder)") - split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] - if len(split_folders) != 3: - raise PathMissingSplitDir(f'Your path needs to contain train/val/test and ' - f'each of them a folder {data_folder_name} and {gt_folder_name}') - - # check if we have train/test/val - for split in split_folders: - type_folders = [d for d in split.iterdir() if d.is_dir() and d.name in type_names] - # check if we have data/gt - if len(type_folders) != 2: - raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' - f'and {gt_folder_name} folder') - return Path(data_dir) diff --git a/src/datamodules/DivaHisDB/utils/output_tools.py b/src/datamodules/DivaHisDB/utils/output_tools.py index 06fe4441..d5ca90ac 100644 --- a/src/datamodules/DivaHisDB/utils/output_tools.py +++ b/src/datamodules/DivaHisDB/utils/output_tools.py @@ -1,61 +1,9 @@ from pathlib import Path -from typing import Union import numpy as np -import torch from PIL import Image -def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): - """ - takes the biggest value from a pixel across all classes - :param output: (Batch_size x num_classes x W x H) - matrix with the given attributes - :return: (Batch_size x W x H) - matrix with the hisdb class number for each pixel - """ - if isinstance(output, torch.Tensor): - return torch.argmax(output, dim=dim) - if isinstance(output, np.ndarray): - return np.argmax(output, axis=dim) - return output - - -def merge_patches(patch, coordinates, full_output): - """ - This function merges the patch into the full output image - Overlapping values are resolved by taking the max. - - Parameters - ---------- - patch: numpy matrix of size [#classes x crop_size x crop_size] - a patch from the larger image - coordinates: tuple of ints - top left coordinates of the patch within the larger image for all patches in a batch - full_output: numpy matrix of size [#C x H x W] - output image at full size - Returns - ------- - full_output: numpy matrix [#C x Htot x Wtot] - """ - assert len(full_output.shape) == 3 - assert full_output.size != 0 - - # Resolve patch coordinates - x1, y1 = coordinates - x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1] - - # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen! 
- assert x2 <= full_output.shape[2] - assert y2 <= full_output.shape[1] - - mask = np.isnan(full_output[:, y1:y2, x1:x2]) - # if still NaN in full_output just insert value from crop, if there is a value then take max - full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2])) - - return full_output - - def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding): """ Helper function to save the output during testing in the DIVAHisDB format diff --git a/src/datamodules/DivaHisDB/utils/single_transform.py b/src/datamodules/DivaHisDB/utils/single_transform.py new file mode 100644 index 00000000..56710f9b --- /dev/null +++ b/src/datamodules/DivaHisDB/utils/single_transform.py @@ -0,0 +1,29 @@ +import src.datamodules.DivaHisDB.utils.functional as F_custom + + +class IntegerEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_int_encoding(gt, self.class_encodings) + + +class OneHotEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_one_hot(gt, self.class_encodings) diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 9be579bc..de6b7f06 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -7,10 +7,10 @@ from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB from src.datamodules.RGB.utils.image_analytics import get_analytics -from src.datamodules.RGB.utils.misc import validate_path_for_segmentation +from src.datamodules.RGB.utils.single_transform import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule +from src.datamodules.utils.misc import validate_path_for_segmentation from src.datamodules.utils.twin_transforms import TwinRandomCrop -from src.datamodules.utils.single_transforms import IntegerEncoding from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 3a2ff935..39d31400 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -7,9 +7,9 @@ from src.datamodules.RGB.datasets.full_page_dataset import DatasetRGB, ImageDimensions from src.datamodules.RGB.utils.image_analytics import get_analytics -from src.datamodules.RGB.utils.misc import validate_path_for_segmentation +from src.datamodules.RGB.utils.single_transform import IntegerEncoding from src.datamodules.base_datamodule import AbstractDatamodule -from src.datamodules.utils.single_transforms import IntegerEncoding +from src.datamodules.utils.misc import validate_path_for_segmentation from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils diff --git a/src/datamodules/RGB/utils/functional.py b/src/datamodules/RGB/utils/functional.py index b052d201..d99d572b 100644 --- a/src/datamodules/RGB/utils/functional.py +++ b/src/datamodules/RGB/utils/functional.py @@ -53,11 +53,3 @@ def gt_to_one_hot(matrix: torch.Tensor, class_encodings: torch.Tensor): onehot_encoded = onehot_encoded.swapaxes(1, 2).swapaxes(0, 1) # changes axis from (0, 1, 2) to (2, 0, 1) return onehot_encoded - - -def argmax_onehot(tensor: 
torch.Tensor): - """ - # TODO - """ - output = torch.LongTensor(torch.argmax(tensor, dim=0)) - return output diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index 3bffa8a5..4fe4c515 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -3,9 +3,7 @@ import json import logging import os -from multiprocessing import Pool from pathlib import Path -from typing import List import numpy as np # Torch related stuff @@ -15,6 +13,8 @@ from PIL import Image from torchvision.datasets.folder import pil_loader +from src.datamodules.utils.image_analytics import compute_mean_std + def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): """ @@ -97,117 +97,6 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, return analytics_data, analytics_gt -def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): - """ - Computes mean and std of all images present at target folder. - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - inmem : Boolean - Specifies whether is should be computed i nan online of offline fashion. - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : float - Mean value of all pixels of the images in the input folder - std : float - Standard deviation of all pixels of the images in the input folder - """ - file_names_np = np.array(list(map(str, file_names))) - # Compute mean and std - mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers) - return mean, std - - -def _cms_online(file_names, workers=4): - """ - Computes mean and image_classification deviation in an online fashion. - This is useful when the dataset is too big to be allocated in memory. - - Parameters - ---------- - file_names : List of String - List of file names of the dataset - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : double - std : double - """ - logging.info('Begin computing the mean') - - # Set up a pool of workers - pool = Pool(workers + 1) - - # Online mean - results = pool.map(_return_mean, file_names) - mean_sum = np.sum(np.array(results), axis=0) - - # Divide by number of samples in train set - mean = mean_sum / file_names.size - - logging.info('Finished computing the mean') - logging.info('Begin computing the std') - - # Online image_classification deviation - results = pool.starmap(_return_std, [[item, mean] for item in file_names]) - std_sum = np.sum(np.array([item[0] for item in results]), axis=0) - total_pixel_count = np.sum(np.array([item[1] for item in results])) - std = np.sqrt(std_sum / total_pixel_count) - logging.info('Finished computing the std') - - # Shut down the pool - pool.close() - - return mean, std - - -# Loads an image with OpenCV and returns the channel wise means of the image. -def _return_mean(image_path): - img = np.array(Image.open(image_path).convert('RGB')) - mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0 - return mean - - -# Loads an image with OpenCV and returns the channel wise std of the image. 
-def _return_std(image_path, mean): - img = np.array(Image.open(image_path).convert('RGB')) / 255.0 - m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]])) - return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0 - - -def _cms_inmem(file_names): - """ - Computes mean and image_classification deviation in an offline fashion. This is possible only when the dataset can - be allocated in memory. - - Parameters - ---------- - file_names: List of String - List of file names of the dataset - Returns - ------- - mean : double - std : double - """ - img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) - - # Load all samples - for i, sample in enumerate(file_names): - img[i] = np.array(Image.open(sample).convert('RGB')) - - mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 - std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 - - return mean, std - - def get_class_weights(input_folder, workers=4, **kwargs): """ Get the weights proportional to the inverse of their class frequencies. diff --git a/src/datamodules/RGB/utils/output_tools.py b/src/datamodules/RGB/utils/output_tools.py index 6a472955..100ed1a8 100644 --- a/src/datamodules/RGB/utils/output_tools.py +++ b/src/datamodules/RGB/utils/output_tools.py @@ -1,61 +1,10 @@ from pathlib import Path -from typing import Union, Tuple, List +from typing import Tuple, List import numpy as np -import torch from PIL import Image -def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): - """ - takes the biggest value from a pixel across all classes - :param output: (Batch_size x num_classes x W x H) - matrix with the given attributes - :return: (Batch_size x W x H) - matrix with the hisdb class number for each pixel - """ - if isinstance(output, torch.Tensor): - return torch.argmax(output, dim=dim) - if isinstance(output, np.ndarray): - return np.argmax(output, axis=dim) - return output - - -def merge_patches(patch, coordinates, full_output): - """ - This function merges the patch into the full output image - Overlapping values are resolved by taking the max. - - Parameters - ---------- - patch: numpy matrix of size [#classes x crop_size x crop_size] - a patch from the larger image - coordinates: tuple of ints - top left coordinates of the patch within the larger image for all patches in a batch - full_output: numpy matrix of size [#C x H x W] - output image at full size - Returns - ------- - full_output: numpy matrix [#C x Htot x Wtot] - """ - assert len(full_output.shape) == 3 - assert full_output.size != 0 - - # Resolve patch coordinates - x1, y1 = coordinates - x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1] - - # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen! 
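
The _get_argmax helper being deleted here reappears later in this series as a shared utility in src.datamodules.utils.misc; its contract is simply an argmax over the class dimension that accepts either backend. An equivalent standalone sketch:

    import numpy as np
    import torch

    logits = torch.randn(2, 4, 32, 32)             # batch x classes x H x W
    labels = torch.argmax(logits, dim=1)           # tensor path
    labels_np = np.argmax(logits.numpy(), axis=1)  # numpy path, same result
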
- assert x2 <= full_output.shape[2] - assert y2 <= full_output.shape[1] - - mask = np.isnan(full_output[:, y1:y2, x1:x2]) - # if still NaN in full_output just insert value from crop, if there is a value then take max - full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2])) - - return full_output - - def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding: List[Tuple[int]]): """ Helper function to save the output during testing in the DIVAHisDB format diff --git a/src/datamodules/RGB/utils/single_transform.py b/src/datamodules/RGB/utils/single_transform.py new file mode 100644 index 00000000..34b48bb3 --- /dev/null +++ b/src/datamodules/RGB/utils/single_transform.py @@ -0,0 +1,29 @@ +import src.datamodules.RGB.utils.functional as F_custom + + +class IntegerEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_int_encoding(gt, self.class_encodings) + + +class OneHotEncoding(object): + def __init__(self, class_encodings): + self.class_encodings = class_encodings + + def __call__(self, gt): + """ + Args: + + Returns: + + """ + return F_custom.gt_to_one_hot(gt, self.class_encodings) diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index 278c032f..3f96202b 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -1,19 +1,13 @@ # Utils import logging -import os -from multiprocessing import Pool -from pathlib import Path -from typing import List import numpy as np # Torch related stuff -import torch -import torchvision.datasets as datasets -import torchvision.transforms as transforms from PIL import Image -from torchvision.datasets.folder import pil_loader +from src.datamodules.RGB.utils.image_analytics import _get_class_frequencies_weights_segmentation from src.datamodules.RolfFormat.datasets.dataset import ImageDimensions +from src.datamodules.utils.image_analytics import compute_mean_std def get_analytics_data(data_gt_path_list, **kwargs): @@ -46,205 +40,3 @@ def get_image_dims(data_gt_path_list, **kwargs): image_dims = ImageDimensions(width=img.width, height=img.height) return image_dims - - -def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): - """ - Computes mean and std of all images present at target folder. - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - inmem : Boolean - Specifies whether is should be computed i nan online of offline fashion. - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : float - Mean value of all pixels of the images in the input folder - std : float - Standard deviation of all pixels of the images in the input folder - """ - file_names_np = np.array(list(map(str, file_names))) - # Compute mean and std - mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers) - return mean, std - - -def _cms_online(file_names, workers=4): - """ - Computes mean and image_classification deviation in an online fashion. - This is useful when the dataset is too big to be allocated in memory. 
- - Parameters - ---------- - file_names : List of String - List of file names of the dataset - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : double - std : double - """ - logging.info('Begin computing the mean') - - # Set up a pool of workers - pool = Pool(workers + 1) - - # Online mean - results = pool.map(_return_mean, file_names) - mean_sum = np.sum(np.array(results), axis=0) - - # Divide by number of samples in train set - mean = mean_sum / file_names.size - - logging.info('Finished computing the mean') - logging.info('Begin computing the std') - - # Online image_classification deviation - results = pool.starmap(_return_std, [[item, mean] for item in file_names]) - std_sum = np.sum(np.array([item[0] for item in results]), axis=0) - total_pixel_count = np.sum(np.array([item[1] for item in results])) - std = np.sqrt(std_sum / total_pixel_count) - logging.info('Finished computing the std') - - # Shut down the pool - pool.close() - - return mean, std - - -# Loads an image with OpenCV and returns the channel wise means of the image. -def _return_mean(image_path): - img = np.array(Image.open(image_path).convert('RGB')) - mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0 - return mean - - -# Loads an image with OpenCV and returns the channel wise std of the image. -def _return_std(image_path, mean): - img = np.array(Image.open(image_path).convert('RGB')) / 255.0 - m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]])) - return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0 - - -def _cms_inmem(file_names): - """ - Computes mean and image_classification deviation in an offline fashion. This is possible only when the dataset can - be allocated in memory. - - Parameters - ---------- - file_names: List of String - List of file names of the dataset - Returns - ------- - mean : double - std : double - """ - img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) - - # Load all samples - for i, sample in enumerate(file_names): - img[i] = np.array(Image.open(sample).convert('RGB')) - - mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 - std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 - - return mean, std - - -def get_class_weights(input_folder, workers=4, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. 
- The vector sums up to 1 - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - # Sanity check on the folder - if not os.path.isdir(input_folder): - logging.error(f"Folder {input_folder} does not exist") - raise FileNotFoundError - - # Load the dataset - ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()])) - - logging.info('Begin computing class frequencies weights') - - if hasattr(ds, 'targets'): - labels = ds.targets - elif hasattr(ds, 'labels'): - labels = ds.labels - else: - # This is a fail-safe net in case a custom dataset changed the name of the internal variables - data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers) - labels = [] - for target, label in data_loader: - labels.append(label) - labels = np.concatenate(labels).reshape(len(ds)) - - class_support = np.unique(labels, return_counts=True)[1] - class_frequencies = class_support / len(labels) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}') - - return class_weights - - -def _get_class_frequencies_weights_segmentation(gt_images, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. 
- The vector sums up to 1 - - Parameters - ---------- - gt_images: list of strings - Path to all ground truth images, which contain the pixel-wise label - workers: int - Number of workers to use for the mean/std computation - - Returns - ------- - ndarray[double] of size (num_classes) and ints the classes are represented as - The weights vector as a 1D array normalized (sum up to 1) - """ - logging.info('Begin computing class frequencies weights') - - total_num_pixels = 0 - label_counter = {} - - for path in gt_images: - img_raw = pil_loader(path) - colors = img_raw.getcolors() - - for count, color in colors: - total_num_pixels += count - label_counter[color] = label_counter.get(color, 0) + count - - classes = sorted(label_counter.keys()) - num_samples_per_class = np.asarray([label_counter[k] for k in classes]) - logging.info('Finished computing class frequencies weights') - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights = (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()) - return class_weights.tolist(), classes diff --git a/src/datamodules/RolfFormat/utils/output_tools.py b/src/datamodules/RolfFormat/utils/output_tools.py deleted file mode 100644 index 6a472955..00000000 --- a/src/datamodules/RolfFormat/utils/output_tools.py +++ /dev/null @@ -1,119 +0,0 @@ -from pathlib import Path -from typing import Union, Tuple, List - -import numpy as np -import torch -from PIL import Image - - -def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): - """ - takes the biggest value from a pixel across all classes - :param output: (Batch_size x num_classes x W x H) - matrix with the given attributes - :return: (Batch_size x W x H) - matrix with the hisdb class number for each pixel - """ - if isinstance(output, torch.Tensor): - return torch.argmax(output, dim=dim) - if isinstance(output, np.ndarray): - return np.argmax(output, axis=dim) - return output - - -def merge_patches(patch, coordinates, full_output): - """ - This function merges the patch into the full output image - Overlapping values are resolved by taking the max. - - Parameters - ---------- - patch: numpy matrix of size [#classes x crop_size x crop_size] - a patch from the larger image - coordinates: tuple of ints - top left coordinates of the patch within the larger image for all patches in a batch - full_output: numpy matrix of size [#C x H x W] - output image at full size - Returns - ------- - full_output: numpy matrix [#C x Htot x Wtot] - """ - assert len(full_output.shape) == 3 - assert full_output.size != 0 - - # Resolve patch coordinates - x1, y1 = coordinates - x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1] - - # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen! 
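
The class-weighting code removed in this patch is the same inverse-frequency normalization kept in the RGB module (which RolfFormat now imports): count pixels per class, invert, and normalize so the weights sum to 1. A self-contained sketch of that arithmetic (the pixel counts are made up):

    import numpy as np

    # hypothetical pixel counts per class from the ground-truth pages
    num_samples_per_class = np.array([90_000, 9_000, 1_000])
    inv_freq = 1.0 / num_samples_per_class
    class_weights = inv_freq / inv_freq.sum()   # normalized to sum to 1
    print(class_weights.round(4))               # rarest class gets the largest weight
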
- assert x2 <= full_output.shape[2] - assert y2 <= full_output.shape[1] - - mask = np.isnan(full_output[:, y1:y2, x1:x2]) - # if still NaN in full_output just insert value from crop, if there is a value then take max - full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2])) - - return full_output - - -def save_output_page_image(image_name, output_image, output_folder: Path, class_encoding: List[Tuple[int]]): - """ - Helper function to save the output during testing in the DIVAHisDB format - - Parameters - ---------- - image_name: str - name of the image that is saved - output_image: numpy matrix of size [#C x H x W] - output image at full size - output_folder: Path - path to the output folder for the test data - class_encoding: list(tuple(int)) - list with the class encodings - - Returns - ------- - mean_iu : float - mean iu of this image - """ - - output_encoded = output_to_class_encodings(output_image, class_encoding) - - dest_folder = output_folder - dest_folder.mkdir(parents=True, exist_ok=True) - dest_filename = dest_folder / image_name - - # Save the output - Image.fromarray(output_encoded.astype(np.uint8)).save(str(dest_filename)) - - -def output_to_class_encodings(output, class_encodings): - """ - This function converts the output prediction matrix to an image like it was provided in the ground truth - - Parameters - ------- - output : np.array of size [#C x H x W] - output prediction of the network for a full-size image, where #C is the number of classes - class_encodings : List - Contains the range of encoded classes - perform_argmax : bool - perform argmax on input data - Returns - ------- - numpy array of size [C x H x W] (BGR) - """ - - integer_encoded = np.argmax(output, axis=0) - - num_classes = len(class_encodings) - - masks = [integer_encoded == class_index for class_index in range(num_classes)] - - rgb = np.full((*integer_encoded.shape, 3), -1) - for mask, color in zip(masks, class_encodings): - rgb[:, :, 0] = np.where(mask, color[0], rgb[:, :, 0]) - rgb[:, :, 1] = np.where(mask, color[1], rgb[:, :, 1]) - rgb[:, :, 2] = np.where(mask, color[2], rgb[:, :, 2]) - - return rgb diff --git a/src/datamodules/RotNet/datasets/cropped_dataset.py b/src/datamodules/RotNet/datasets/cropped_dataset.py index ead5a7b4..4c201a49 100644 --- a/src/datamodules/RotNet/datasets/cropped_dataset.py +++ b/src/datamodules/RotNet/datasets/cropped_dataset.py @@ -13,7 +13,6 @@ from torchvision.transforms import ToTensor from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset -from src.datamodules.RotNet.utils.misc import has_extension from src.utils import utils IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm') @@ -175,7 +174,7 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st for path_data_subdir in subitems: if not path_data_subdir.is_dir(): - if has_extension(path_data_subdir.name, IMG_EXTENSIONS): + if has_file_allowed_extension(path_data_subdir.name, IMG_EXTENSIONS): log.warning("image file found in data root: " + str(path_data_subdir)) continue diff --git a/src/datamodules/RotNet/utils/image_analytics.py b/src/datamodules/RotNet/utils/image_analytics.py index fcf5c9b9..5c9f0e59 100644 --- a/src/datamodules/RotNet/utils/image_analytics.py +++ b/src/datamodules/RotNet/utils/image_analytics.py @@ -2,17 +2,11 @@ import errno import json import logging -import os -from multiprocessing import Pool from pathlib import Path -from typing import List import numpy as np -# Torch 
related stuff -import torch -import torchvision.datasets as datasets -import torchvision.transforms as transforms -from PIL import Image + +from src.datamodules.utils.image_analytics import compute_mean_std def get_analytics_data(input_path: Path, data_folder_name: str, get_gt_data_paths_func, **kwargs): @@ -58,171 +52,3 @@ def get_analytics_data(input_path: Path, data_folder_name: str, get_gt_data_path raise return analytics_data - - -def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs): - """ - Computes mean and std of all images present at target folder. - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - inmem : Boolean - Specifies whether is should be computed i nan online of offline fashion. - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : float - Mean value of all pixels of the images in the input folder - std : float - Standard deviation of all pixels of the images in the input folder - """ - file_names_np = np.array(list(map(str, file_names))) - # Compute mean and std - mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers) - return mean, std - - -def _cms_online(file_names, workers=4): - """ - Computes mean and image_classification deviation in an online fashion. - This is useful when the dataset is too big to be allocated in memory. - - Parameters - ---------- - file_names : List of String - List of file names of the dataset - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - mean : double - std : double - """ - logging.info('Begin computing the mean') - - # Set up a pool of workers - pool = Pool(workers + 1) - - # Online mean - results = pool.map(_return_mean, file_names) - mean_sum = np.sum(np.array(results), axis=0) - - # Divide by number of samples in train set - mean = mean_sum / file_names.size - - logging.info('Finished computing the mean') - logging.info('Begin computing the std') - - # Online image_classification deviation - results = pool.starmap(_return_std, [[item, mean] for item in file_names]) - std_sum = np.sum(np.array([item[0] for item in results]), axis=0) - total_pixel_count = np.sum(np.array([item[1] for item in results])) - std = np.sqrt(std_sum / total_pixel_count) - logging.info('Finished computing the std') - - # Shut down the pool - pool.close() - - return mean, std - - -# Loads an image with OpenCV and returns the channel wise means of the image. -def _return_mean(image_path): - img = np.array(Image.open(image_path).convert('RGB')) - mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0 - return mean - - -# Loads an image with OpenCV and returns the channel wise std of the image. -def _return_std(image_path, mean): - img = np.array(Image.open(image_path).convert('RGB')) / 255.0 - m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]])) - return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0 - - -def _cms_inmem(file_names): - """ - Computes mean and image_classification deviation in an offline fashion. This is possible only when the dataset can - be allocated in memory. 
- - Parameters - ---------- - file_names: List of String - List of file names of the dataset - Returns - ------- - mean : double - std : double - """ - img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) - - # Load all samples - for i, sample in enumerate(file_names): - img[i] = np.array(Image.open(sample).convert('RGB')) - - mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 - std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 - - return mean, std - - -def get_class_weights(input_folder, workers=4, **kwargs): - """ - Get the weights proportional to the inverse of their class frequencies. - The vector sums up to 1 - - Parameters - ---------- - input_folder : String (path) - Path to the dataset folder (see above for details) - workers : int - Number of workers to use for the mean/std computation - - Returns - ------- - ndarray[double] of size (num_classes) - The weights vector as a 1D array normalized (sum up to 1) - """ - # Sanity check on the folder - if not os.path.isdir(input_folder): - logging.error(f"Folder {input_folder} does not exist") - raise FileNotFoundError - - # Load the dataset - ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()])) - - logging.info('Begin computing class frequencies weights') - - if hasattr(ds, 'targets'): - labels = ds.targets - elif hasattr(ds, 'labels'): - labels = ds.labels - else: - # This is a fail-safe net in case a custom dataset changed the name of the internal variables - data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers) - labels = [] - for target, label in data_loader: - labels.append(label) - labels = np.concatenate(labels).reshape(len(ds)) - - class_support = np.unique(labels, return_counts=True)[1] - class_frequencies = class_support / len(labels) - # Class weights are the inverse of the class frequencies - class_weights = 1 / class_frequencies - # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) - class_weights /= class_weights.sum() - - logging.info('Finished computing class frequencies weights ') - logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') - logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}') - - return class_weights - - -if __name__ == '__main__': - print(get_analytics_data(input_path=Path('tests/dummy_data/dummy_dataset'))) diff --git a/src/datamodules/RotNet/utils/misc.py b/src/datamodules/RotNet/utils/misc.py index 9c1a3170..eada0b9b 100644 --- a/src/datamodules/RotNet/utils/misc.py +++ b/src/datamodules/RotNet/utils/misc.py @@ -3,11 +3,6 @@ from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir -def has_extension(filename, extensions): - filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) - - def validate_path_for_self_supervised(data_dir, data_folder_name: str): if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " diff --git a/src/datamodules/utils/functional.py b/src/datamodules/utils/functional.py new file mode 100644 index 00000000..882b3651 --- /dev/null +++ b/src/datamodules/utils/functional.py @@ -0,0 +1,8 @@ +import torch + + +def argmax_onehot(tensor: torch.Tensor): + """ + # TODO + """ + return torch.LongTensor(torch.argmax(tensor, dim=0)) \ No newline at 
end of file
diff --git a/src/datamodules/utils/image_analytics.py b/src/datamodules/utils/image_analytics.py
new file mode 100644
index 00000000..64f27374
--- /dev/null
+++ b/src/datamodules/utils/image_analytics.py
@@ -0,0 +1,116 @@
+import logging
+from multiprocessing import Pool
+from pathlib import Path
+from typing import List
+
+import numpy as np
+from PIL import Image
+
+
+def compute_mean_std(file_names: List[Path], inmem=False, workers=4, **kwargs):
+    """
+    Computes mean and std of all images in the given file list.
+
+    Parameters
+    ----------
+    file_names : List[Path]
+        List of paths to the image files
+    inmem : Boolean
+        Specifies whether it should be computed in an online or offline fashion.
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : float
+        Mean value of all pixels of the images in the file list
+    std : float
+        Standard deviation of all pixels of the images in the file list
+    """
+    file_names_np = np.array(list(map(str, file_names)))
+    # Compute mean and std
+    mean, std = _cms_inmem(file_names_np) if inmem else _cms_online(file_names_np, workers)
+    return mean, std
+
+
+def _cms_online(file_names, workers=4):
+    """
+    Computes mean and standard deviation in an online fashion.
+    This is useful when the dataset is too big to be allocated in memory.
+
+    Parameters
+    ----------
+    file_names : List of String
+        List of file names of the dataset
+    workers : int
+        Number of workers to use for the mean/std computation
+
+    Returns
+    -------
+    mean : double
+    std : double
+    """
+    logging.info('Begin computing the mean')
+
+    # Set up a pool of workers
+    pool = Pool(workers + 1)
+
+    # Online mean
+    results = pool.map(_return_mean, file_names)
+    mean_sum = np.sum(np.array(results), axis=0)
+
+    # Divide by number of samples in train set
+    mean = mean_sum / file_names.size
+
+    logging.info('Finished computing the mean')
+    logging.info('Begin computing the std')
+
+    # Online standard deviation
+    results = pool.starmap(_return_std, [[item, mean] for item in file_names])
+    std_sum = np.sum(np.array([item[0] for item in results]), axis=0)
+    total_pixel_count = np.sum(np.array([item[1] for item in results]))
+    std = np.sqrt(std_sum / total_pixel_count)
+    logging.info('Finished computing the std')
+
+    # Shut down the pool
+    pool.close()
+
+    return mean, std
+
+
+def _return_mean(image_path):
+    img = np.array(Image.open(image_path).convert('RGB'))
+    mean = np.array([np.mean(img[:, :, 0]), np.mean(img[:, :, 1]), np.mean(img[:, :, 2])]) / 255.0
+    return mean
+
+
+def _return_std(image_path, mean):
+    img = np.array(Image.open(image_path).convert('RGB')) / 255.0
+    m2 = np.square(np.array([img[:, :, 0] - mean[0], img[:, :, 1] - mean[1], img[:, :, 2] - mean[2]]))
+    return np.sum(np.sum(m2, axis=1), 1), m2.size / 3.0
+
+
+def _cms_inmem(file_names):
+    """
+    Computes mean and standard deviation in an offline fashion. This is possible only when the dataset can
+    be allocated in memory.
+ + Parameters + ---------- + file_names: List of String + List of file names of the dataset + Returns + ------- + mean : double + std : double + """ + img = np.zeros([file_names.size] + list(np.array(Image.open(file_names[0]).convert('RGB')).shape)) + + # Load all samples + for i, sample in enumerate(file_names): + img[i] = np.array(Image.open(sample).convert('RGB')) + + mean = np.array([np.mean(img[:, :, :, 0]), np.mean(img[:, :, :, 1]), np.mean(img[:, :, :, 2])]) / 255.0 + std = np.array([np.std(img[:, :, :, 0]), np.std(img[:, :, :, 1]), np.std(img[:, :, :, 2])]) / 255.0 + + return mean, std \ No newline at end of file diff --git a/src/datamodules/RGB/utils/misc.py b/src/datamodules/utils/misc.py similarity index 55% rename from src/datamodules/RGB/utils/misc.py rename to src/datamodules/utils/misc.py index 3fb47da7..824a008e 100644 --- a/src/datamodules/RGB/utils/misc.py +++ b/src/datamodules/utils/misc.py @@ -1,14 +1,48 @@ -""" -General purpose utility functions. +from pathlib import Path +from typing import Union -""" +import numpy as np +import torch + +from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir -from pathlib import Path -from src.datamodules.utils.exceptions import PathMissingDirinSplitDir, PathNone, PathNotDir, PathMissingSplitDir +def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): + """ + takes the biggest value from a pixel across all classes + :param output: (Batch_size x num_classes x W x H) + matrix with the given attributes + :return: (Batch_size x W x H) + matrix with the hisdb class number for each pixel + """ + if isinstance(output, torch.Tensor): + return torch.argmax(output, dim=dim) + if isinstance(output, np.ndarray): + return np.argmax(output, axis=dim) + return output def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_name: str): + """ + Checks if the data_dir folder has the following structure: + + {data_dir} + - train + - {data_folder_name} + - {gt_folder_name} + - val + - {data_folder_name} + - {gt_folder_name} + - test + - {data_folder_name} + - {gt_folder_name} + + + :param data_dir: + :param data_folder_name: + :param gt_folder_name: + :return: + """ if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val/test folder)") @@ -32,4 +66,4 @@ def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_na if len(type_folders) != 2: raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' f'and {gt_folder_name} folder') - return Path(data_dir) + return Path(data_dir) \ No newline at end of file diff --git a/src/datamodules/utils/output_tools.py b/src/datamodules/utils/output_tools.py new file mode 100644 index 00000000..5b31a40c --- /dev/null +++ b/src/datamodules/utils/output_tools.py @@ -0,0 +1,36 @@ +import numpy as np + + +def merge_patches(patch, coordinates, full_output): + """ + This function merges the patch into the full output image + Overlapping values are resolved by taking the max. 
+ + Parameters + ---------- + patch: numpy matrix of size [#classes x crop_size x crop_size] + a patch from the larger image + coordinates: tuple of ints + top left coordinates of the patch within the larger image for all patches in a batch + full_output: numpy matrix of size [#C x H x W] + output image at full size + Returns + ------- + full_output: numpy matrix [#C x Htot x Wtot] + """ + assert len(full_output.shape) == 3 + assert full_output.size != 0 + + # Resolve patch coordinates + x1, y1 = coordinates + x2, y2 = x1 + patch.shape[2], y1 + patch.shape[1] + + # If this triggers it means that a patch is 'out-of-bounds' of the image and that should never happen! + assert x2 <= full_output.shape[2] + assert y2 <= full_output.shape[1] + + mask = np.isnan(full_output[:, y1:y2, x1:x2]) + # if still NaN in full_output just insert value from crop, if there is a value then take max + full_output[:, y1:y2, x1:x2] = np.where(mask, patch, np.maximum(patch, full_output[:, y1:y2, x1:x2])) + + return full_output \ No newline at end of file diff --git a/src/datamodules/utils/single_transforms.py b/src/datamodules/utils/single_transforms.py index 0aaa4a8b..bc1125f1 100644 --- a/src/datamodules/utils/single_transforms.py +++ b/src/datamodules/utils/single_transforms.py @@ -5,7 +5,7 @@ from PIL import Image from torchvision.transforms import Pad -from src.datamodules.DivaHisDB.utils import functional as F_custom +import src.datamodules.utils.functional class ResizePad(object): @@ -148,32 +148,7 @@ def __call__(self, img): class OneHotToPixelLabelling(object): def __call__(self, tensor): - return F_custom.argmax_onehot(tensor) + return src.datamodules.utils.functional.argmax_onehot(tensor) -class OneHotEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_one_hot(gt, self.class_encodings) - - -class IntegerEncoding(object): - def __init__(self, class_encodings): - self.class_encodings = class_encodings - - def __call__(self, gt): - """ - Args: - - Returns: - - """ - return F_custom.gt_to_int_encoding(gt, self.class_encodings) diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py index 8860c9a1..feac1b2e 100644 --- a/src/tasks/DivaHisDB/semantic_segmentation.py +++ b/src/tasks/DivaHisDB/semantic_segmentation.py @@ -6,8 +6,8 @@ import torch.optim import torchmetrics +from src.datamodules.utils.misc import _get_argmax from src.tasks.base_task import AbstractTask -from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax from src.utils import utils from src.tasks.utils.outputs import OutputKeys, reduce_dict diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py index 01106966..6621bfe0 100644 --- a/src/tasks/RGB/semantic_segmentation.py +++ b/src/tasks/RGB/semantic_segmentation.py @@ -6,8 +6,8 @@ import torch.optim import torchmetrics +from src.datamodules.utils.misc import _get_argmax from src.tasks.base_task import AbstractTask -from src.datamodules.RGB.utils.output_tools import _get_argmax from src.utils import utils from src.tasks.utils.outputs import OutputKeys, reduce_dict diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index 99ee0fb0..7adce3e9 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -6,9 +6,9 @@ import torch.optim import torchmetrics 
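
Before the task-side imports settle, it is worth pinning down the contract of the now-shared merge_patches: the full-size canvas has to start out as all-NaN so that the first patch simply fills it and overlapping regions resolve through np.maximum. A sketch with illustrative shapes:

    import numpy as np

    from src.datamodules.utils.output_tools import merge_patches

    num_classes, height, width = 4, 256, 256
    canvas = np.full((num_classes, height, width), np.nan)  # NaN marks 'no prediction yet'
    patch = np.random.rand(num_classes, 128, 128)
    canvas = merge_patches(patch, (0, 0), canvas)     # top-left corner, fills NaNs
    canvas = merge_patches(patch, (64, 64), canvas)   # overlap resolved by taking the max
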
-from src.datamodules.RolfFormat.utils.output_tools import save_output_page_image +from src.datamodules.RGB.utils.output_tools import save_output_page_image +from src.datamodules.utils.misc import _get_argmax from src.tasks.base_task import AbstractTask -from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax from src.utils import utils from src.tasks.utils.outputs import OutputKeys, reduce_dict diff --git a/tests/datamodules/DivaHisDB/test_misc.py b/tests/datamodules/DivaHisDB/test_misc.py index 8172ab2a..6a605dd6 100644 --- a/tests/datamodules/DivaHisDB/test_misc.py +++ b/tests/datamodules/DivaHisDB/test_misc.py @@ -1,7 +1,7 @@ import pytest from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir -from src.datamodules.DivaHisDB.utils.misc import validate_path_for_segmentation +from src.datamodules.utils.misc import validate_path_for_segmentation @pytest.fixture diff --git a/tests/tasks/sem_seg/test_output_tools.py b/tests/tasks/sem_seg/test_output_tools.py index f7bca67a..e02e6010 100644 --- a/tests/tasks/sem_seg/test_output_tools.py +++ b/tests/tasks/sem_seg/test_output_tools.py @@ -2,10 +2,12 @@ from PIL import Image from torch import tensor, equal -from src.datamodules.DivaHisDB.utils.output_tools import _get_argmax, merge_patches, output_to_class_encodings, \ +from src.datamodules.DivaHisDB.utils.output_tools import output_to_class_encodings, \ save_output_page_image - +from src.datamodules.utils.output_tools import merge_patches # batchsize (2) x classes (4) x W (2) x H (2) +from src.datamodules.utils.misc import _get_argmax + BATCH = tensor([[[[0., 0.3], [4., 2.]], [[1., 4.1], [-0.2, 1.9]], [[1.1, -0.8], [4.9, 1.3]], diff --git a/tools/merge_cropped_output_HisDB.py b/tools/merge_cropped_output_HisDB.py index 5ffead16..75cebad0 100644 --- a/tools/merge_cropped_output_HisDB.py +++ b/tools/merge_cropped_output_HisDB.py @@ -14,7 +14,8 @@ from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped from src.datamodules.DivaHisDB.datasets.cropped_dataset import CroppedHisDBDataset -from src.datamodules.DivaHisDB.utils.output_tools import merge_patches, save_output_page_image +from src.datamodules.DivaHisDB.utils.output_tools import save_output_page_image +from src.datamodules.utils.output_tools import merge_patches from tools.generate_cropped_dataset import pil_loader from tools.viz import visualize diff --git a/tools/merge_cropped_output_RGB.py b/tools/merge_cropped_output_RGB.py index 04462db7..87bbaaf7 100644 --- a/tools/merge_cropped_output_RGB.py +++ b/tools/merge_cropped_output_RGB.py @@ -12,11 +12,11 @@ from PIL import Image from tqdm import tqdm -from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB -from src.datamodules.RGB.utils.output_tools import merge_patches, save_output_page_image from src.datamodules.RGB.datamodule_cropped import DataModuleCroppedRGB +from src.datamodules.RGB.datasets.cropped_dataset import CroppedDatasetRGB +from src.datamodules.RGB.utils.output_tools import save_output_page_image +from src.datamodules.utils.output_tools import merge_patches from tools.generate_cropped_dataset import pil_loader -from tools.viz import visualize @dataclass From 2078d02985c434fed327def65f9ecec56c6fc4c8 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 22 Nov 2021 17:27:25 +0100 Subject: [PATCH 083/108] :books: log number of samples per dataset after dataset is loaded --- src/datamodules/DivaHisDB/datamodule_cropped.py | 7 +++++-- 
src/datamodules/RGB/datamodule_cropped.py | 7 +++++-- src/datamodules/RGB/datamodule_full_page.py | 7 +++++-- src/datamodules/RolfFormat/datamodule.py | 7 +++++-- src/datamodules/RotNet/datamodule_cropped.py | 7 +++++-- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index ab0979bc..824fc63a 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -65,15 +65,18 @@ def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: self.train = CroppedHisDBDataset(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = CroppedHisDBDataset(**self._create_dataset_parameters('val'), selection=self.selection_val) - + log.info(f'Initialized train dataset with {len(self.train)} samples.') self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + + self.val = CroppedHisDBDataset(**self._create_dataset_parameters('val'), selection=self.selection_val) + log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test' or stage is not None: self.test = CroppedHisDBDataset(**self._create_dataset_parameters('test'), selection=self.selection_test) + log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index de6b7f06..02e23c41 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -65,15 +65,18 @@ def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: self.train = CroppedDatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = CroppedDatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) - + log.info(f'Initialized train dataset with {len(self.train)} samples.') self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + + self.val = CroppedDatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test' or stage is not None: self.test = CroppedDatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index 39d31400..c6c9a40f 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -65,15 +65,18 @@ def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: self.train = DatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = DatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) - + log.info(f'Initialized train dataset with {len(self.train)} 
samples.') self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + + self.val = DatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test' or stage is not None: self.test = DatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 8f09e82d..80ab2ad4 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -138,17 +138,20 @@ def setup(self, stage: Optional[str] = None): self.train = DatasetRolfFormat(dataset_specs=self.train_dataset_specs, is_test=False, **common_kwargs) + log.info(f'Initialized train dataset with {len(self.train)} samples.') + self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + self.val = DatasetRolfFormat(dataset_specs=self.val_dataset_specs, is_test=False, **common_kwargs) - - self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test' or stage is not None: self.test = DatasetRolfFormat(dataset_specs=self.test_dataset_specs, is_test=True, **common_kwargs) + log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', drop_last=False) if stage == 'predict': diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index b91be885..890a34f3 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -56,15 +56,18 @@ def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: self.train = CroppedRotNet(**self._create_dataset_parameters('train'), selection=self.selection_train) - self.val = CroppedRotNet(**self._create_dataset_parameters('val'), selection=self.selection_val) - + log.info(f'Initialized train dataset with {len(self.train)} samples.') self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) + + self.val = CroppedRotNet(**self._create_dataset_parameters('val'), selection=self.selection_val) + log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test' or stage is not None: self.test = CroppedRotNet(**self._create_dataset_parameters('test'), selection=self.selection_test) + log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', # drop_last=False) From af6674c0328ecba82d440f93351daca6dd1fdc46 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 22 Nov 2021 17:37:33 +0100 Subject: [PATCH 084/108] :recycle: :art: moved ImageDimensions dataclass to misc, and check image dimensions in full page datasets --- 
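
Across all five datamodules the reordering matters: previously the minimum-sample check for the train split only ran after the val split had been built, so a too-small train set surfaced late. The pattern now validates each split right after it is instantiated. Condensed into a stand-in sketch (make_split and log are placeholders for the per-datamodule dataset class and module logger, not names from this repository):

    from typing import Optional

    def setup(self, stage: Optional[str] = None):
        super().setup()
        if stage == 'fit' or stage is None:
            for split in ('train', 'val'):
                dataset = make_split(split)  # placeholder factory
                setattr(self, split, dataset)
                log.info(f'Initialized {split} dataset with {len(dataset)} samples.')
                self._check_min_num_samples(num_samples=len(dataset), data_split=split,
                                            drop_last=self.drop_last)
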
src/datamodules/RGB/datasets/full_page_dataset.py | 7 +------ src/datamodules/RolfFormat/datamodule.py | 3 ++- src/datamodules/RolfFormat/datasets/dataset.py | 6 ++---- src/datamodules/RolfFormat/utils/image_analytics.py | 2 +- src/datamodules/utils/dataset_predict.py | 8 +++++++- src/datamodules/utils/misc.py | 9 ++++++++- tests/datamodules/util/test_predict_dataset.py | 3 ++- 7 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index b9f8293b..40210a31 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -13,6 +13,7 @@ from torchvision.datasets.folder import pil_loader, has_file_allowed_extension from torchvision.transforms import ToTensor +from src.datamodules.utils.misc import ImageDimensions from src.utils import utils IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') @@ -20,12 +21,6 @@ log = utils.get_logger(__name__) -@dataclass -class ImageDimensions: - width: int - height: int - - class DatasetRGB(data.Dataset): """A generic data loader where the images are arranged in this way: :: diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 80ab2ad4..571a76f9 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -5,7 +5,8 @@ from torch.utils.data import DataLoader from torchvision import transforms -from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs, ImageDimensions +from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs +from src.datamodules.utils.misc import ImageDimensions from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt, get_image_dims from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget diff --git a/src/datamodules/RolfFormat/datasets/dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py index 8e8866bd..14968d57 100644 --- a/src/datamodules/RolfFormat/datasets/dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -13,12 +13,14 @@ from torchvision.datasets.folder import pil_loader from torchvision.transforms import ToTensor +from src.datamodules.utils.misc import ImageDimensions from src.utils import utils IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') log = utils.get_logger(__name__) + @dataclass class DatasetSpecs: data_root: str @@ -29,10 +31,6 @@ class DatasetSpecs: range_from: int range_to: int -@dataclass -class ImageDimensions: - width: int - height: int class DatasetRolfFormat(data.Dataset): """A generic data loader where the images are arranged in this way: :: diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index 3f96202b..7ded9b98 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -6,7 +6,7 @@ from PIL import Image from src.datamodules.RGB.utils.image_analytics import _get_class_frequencies_weights_segmentation -from src.datamodules.RolfFormat.datasets.dataset import ImageDimensions +from src.datamodules.utils.misc import ImageDimensions from src.datamodules.utils.image_analytics import compute_mean_std diff --git a/src/datamodules/utils/dataset_predict.py 
b/src/datamodules/utils/dataset_predict.py index edf95e00..836ba8b4 100644 --- a/src/datamodules/utils/dataset_predict.py +++ b/src/datamodules/utils/dataset_predict.py @@ -5,6 +5,7 @@ from torchvision.datasets.folder import pil_loader from torchvision.transforms import ToTensor +from src.datamodules.utils.misc import ImageDimensions from src.utils import utils log = utils.get_logger(__name__) @@ -12,7 +13,7 @@ class DatasetPredict(data.Dataset): - def __init__(self, image_path_list: List[str], + def __init__(self, image_path_list: List[str], image_dims: ImageDimensions, image_transform=None, target_transform=None, twin_transform=None, classes=None, **kwargs): """ @@ -26,6 +27,8 @@ def __init__(self, image_path_list: List[str], self.image_path_list = list(image_path_list) + self.image_dims = image_dims + # Init list self.classes = classes # self.crops_per_image = crops_per_image @@ -54,6 +57,9 @@ def __getitem__(self, index): def _load_data_and_gt(self, index): data_img = pil_loader(self.image_path_list[index]) + + assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width + return data_img def _apply_transformation(self, img): diff --git a/src/datamodules/utils/misc.py b/src/datamodules/utils/misc.py index 824a008e..cda99bd1 100644 --- a/src/datamodules/utils/misc.py +++ b/src/datamodules/utils/misc.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from pathlib import Path from typing import Union @@ -66,4 +67,10 @@ def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_na if len(type_folders) != 2: raise PathMissingDirinSplitDir(f'Folder {split.name} does not contain a {data_folder_name} ' f'and {gt_folder_name} folder') - return Path(data_dir) \ No newline at end of file + return Path(data_dir) + + +@dataclass +class ImageDimensions: + width: int + height: int \ No newline at end of file diff --git a/tests/datamodules/util/test_predict_dataset.py b/tests/datamodules/util/test_predict_dataset.py index 8406d350..d2fddde7 100644 --- a/tests/datamodules/util/test_predict_dataset.py +++ b/tests/datamodules/util/test_predict_dataset.py @@ -3,6 +3,7 @@ from torchvision.transforms import ToTensor from src.datamodules.utils.dataset_predict import DatasetPredict +from src.datamodules.utils.misc import ImageDimensions from tests.test_data.dummy_data_hisdb.dummy_data import data_dir @@ -14,7 +15,7 @@ def file_path_list(data_dir): @pytest.fixture def predict_dataset(file_path_list): - return DatasetPredict(image_path_list=file_path_list) + return DatasetPredict(image_path_list=file_path_list, image_dims=ImageDimensions(width=487, height=649)) def test__load_data_and_gt(predict_dataset): From 5541066240e0c438c62a11e78467c28c6361a93d Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 22 Nov 2021 17:53:25 +0100 Subject: [PATCH 085/108] :bug: fixed imports --- src/datamodules/RolfFormat/datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 571a76f9..e295ed54 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -5,13 +5,13 @@ from torch.utils.data import DataLoader from torchvision import transforms +from src.datamodules.RGB.utils.single_transform import IntegerEncoding from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs from src.datamodules.utils.misc import ImageDimensions from src.datamodules.RolfFormat.utils.image_analytics import 
get_analytics_data, get_analytics_gt, get_image_dims -from src.datamodules.RolfFormat.utils.twin_transforms import IntegerEncoding -from src.datamodules.RolfFormat.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.datamodules.base_datamodule import AbstractDatamodule from src.datamodules.utils.dataset_predict import DatasetPredict +from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils log = utils.get_logger(__name__) From d8246e4e196faf7fa6e6f6d25d0901deaefe67ec Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 22 Nov 2021 17:57:04 +0100 Subject: [PATCH 086/108] :art: expose predict_output_path and test_output_path and set default values --- src/tasks/DivaHisDB/semantic_segmentation.py | 4 +++- src/tasks/RGB/semantic_segmentation.py | 4 +++- src/tasks/RGB/semantic_segmentation_full_page.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation.py index feac1b2e..4b0ce5fa 100644 --- a/src/tasks/DivaHisDB/semantic_segmentation.py +++ b/src/tasks/DivaHisDB/semantic_segmentation.py @@ -23,7 +23,8 @@ def __init__(self, metric_train: Optional[torchmetrics.Metric] = None, metric_val: Optional[torchmetrics.Metric] = None, metric_test: Optional[torchmetrics.Metric] = None, - test_output_path: Optional[Union[str, Path]] = 'predictions', + test_output_path: Optional[Union[str, Path]] = 'test_output', + predict_output_path: Optional[Union[str, Path]] = 'predict_output', confusion_matrix_val: Optional[bool] = False, confusion_matrix_test: Optional[bool] = False, confusion_matrix_log_every_n_epoch: Optional[int] = 1, @@ -45,6 +46,7 @@ def __init__(self, metric_val=metric_val, metric_test=metric_test, test_output_path=test_output_path, + predict_output_path=predict_output_path, lr=lr, confusion_matrix_val=confusion_matrix_val, confusion_matrix_test=confusion_matrix_test, diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py index 6621bfe0..99172e1a 100644 --- a/src/tasks/RGB/semantic_segmentation.py +++ b/src/tasks/RGB/semantic_segmentation.py @@ -23,7 +23,8 @@ def __init__(self, metric_train: Optional[torchmetrics.Metric] = None, metric_val: Optional[torchmetrics.Metric] = None, metric_test: Optional[torchmetrics.Metric] = None, - test_output_path: Optional[Union[str, Path]] = 'predictions', + test_output_path: Optional[Union[str, Path]] = 'test_output', + predict_output_path: Optional[Union[str, Path]] = 'predict_output', confusion_matrix_val: Optional[bool] = False, confusion_matrix_test: Optional[bool] = False, confusion_matrix_log_every_n_epoch: Optional[int] = 1, @@ -45,6 +46,7 @@ def __init__(self, metric_val=metric_val, metric_test=metric_test, test_output_path=test_output_path, + predict_output_path=predict_output_path, lr=lr, confusion_matrix_val=confusion_matrix_val, confusion_matrix_test=confusion_matrix_test, diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_full_page.py index 7adce3e9..78302b17 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_full_page.py @@ -24,7 +24,8 @@ def __init__(self, metric_train: Optional[torchmetrics.Metric] = None, metric_val: Optional[torchmetrics.Metric] = None, metric_test: Optional[torchmetrics.Metric] = None, - test_output_path: Optional[Union[str, Path]] = 'predictions', + test_output_path: Optional[Union[str, Path]] = 'test_output', + 
predict_output_path: Optional[Union[str, Path]] = 'predict_output', confusion_matrix_val: Optional[bool] = False, confusion_matrix_test: Optional[bool] = False, confusion_matrix_log_every_n_epoch: Optional[int] = 1, @@ -46,6 +47,7 @@ def __init__(self, metric_val=metric_val, metric_test=metric_test, test_output_path=test_output_path, + predict_output_path=predict_output_path, lr=lr, confusion_matrix_val=confusion_matrix_val, confusion_matrix_test=confusion_matrix_test, From 1769758c27454786eafe33804af9f025a881c048 Mon Sep 17 00:00:00 2001 From: Paul M Date: Mon, 22 Nov 2021 18:02:29 +0100 Subject: [PATCH 087/108] :loud_sound: log predict dataset size --- src/datamodules/RolfFormat/datamodule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index e295ed54..b582b377 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -148,7 +148,7 @@ def setup(self, stage: Optional[str] = None): log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) - if stage == 'test' or stage is not None: + if stage == 'test': self.test = DatasetRolfFormat(dataset_specs=self.test_dataset_specs, is_test=True, **common_kwargs) @@ -158,6 +158,8 @@ def setup(self, stage: Optional[str] = None): if stage == 'predict': self.predict = DatasetPredict(image_path_list=self.pred_file_path_list, **common_kwargs) + log.info(f'Initialized predict dataset with {len(self.predict)} samples.') + # self._check_min_num_samples(num_samples=len(self.test), data_split='test', drop_last=False) def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): num_processes = self.trainer.num_processes From cb9e3f2102d04d78a0a56151c3b152dd0a02df38 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Tue, 23 Nov 2021 10:14:09 +0100 Subject: [PATCH 088/108] :art: check now when we save the conf mat to wandb if we have a wandb logger --- src/callbacks/wandb_callbacks.py | 2 +- src/tasks/base_task.py | 18 +++++++++++------- .../sem_seg/test_semantic_segmentation.py | 3 ++- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/callbacks/wandb_callbacks.py b/src/callbacks/wandb_callbacks.py index 69c243c8..04ff0645 100644 --- a/src/callbacks/wandb_callbacks.py +++ b/src/callbacks/wandb_callbacks.py @@ -12,7 +12,7 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: if isinstance(logger, WandbLogger): return logger - raise Exception( + raise ValueError( "You are using wandb related callback, but WandbLogger was not found for some reason..." 
    )

diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py
index 3173248c..4252b334 100644
--- a/src/tasks/base_task.py
+++ b/src/tasks/base_task.py
@@ -303,8 +303,6 @@ def _create_conf_mat(self, matrix: np.ndarray, stage: str = 'val'):
         # print(f'With all_gather: {str(len(outputs[0][OutputKeys.PREDICTION][0]))}')
         conf_mat_name = f'CM_epoch_{self.trainer.current_epoch}'

-        logger = get_wandb_logger(self.trainer)
-        experiment = logger.experiment

         # set figure size
         plt.figure(figsize=(14, 8))
@@ -325,8 +323,14 @@ def _create_conf_mat(self, matrix: np.ndarray, stage: str = 'val'):
         # save as csv or tsv to disc
         if self.trainer.is_global_zero:
             df.to_csv(path_or_buf=conf_mat_file_path, sep='\t')
-            # save tsv to wandb
-            experiment.save(glob_str=str(conf_mat_file_path), base_path=os.getcwd())
-            # names should be uniqe or else charts from different experiments in wandb will overlap
-            experiment.log({f"confusion_matrix_{stage}_img/ep_{self.trainer.current_epoch}": wandb.Image(plt)},
-                           commit=False)
+
+            try:
+                # save tsv to wandb
+                logger = get_wandb_logger(self.trainer)
+                experiment = logger.experiment
+                experiment.save(glob_str=str(conf_mat_file_path), base_path=os.getcwd())
+                # names should be unique or else charts from different experiments in wandb will overlap
+                experiment.log({f"confusion_matrix_{stage}_img/ep_{self.trainer.current_epoch}": wandb.Image(plt)},
+                               commit=False)
+            except ValueError as e:
+                return

diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/sem_seg/test_semantic_segmentation.py
index 243d580c..1b822fb0 100644
--- a/tests/tasks/sem_seg/test_semantic_segmentation.py
+++ b/tests/tasks/sem_seg/test_semantic_segmentation.py
@@ -29,7 +29,8 @@ def baby_unet():
     segmentation = SemanticSegmentationHisDB(model=model,
                                              optimizer=torch.optim.Adam(params=model.parameters()),
                                              loss_fn=torch.nn.CrossEntropyLoss(),
-                                             test_output_path=tmp_path
+                                             test_output_path=tmp_path,
+                                             confusion_matrix_val=True
                                              )

     # different paths needed later

From 8600cc560f6a959295255a2eb67c52b61d1dea1b Mon Sep 17 00:00:00 2001
From: Paul M
Date: Tue, 23 Nov 2021 10:52:13 +0100
Subject: [PATCH 089/108] :loud_sound: log warn/error when wandb logger is missing

---
 src/callbacks/wandb_callbacks.py | 10 ++++++++--
 src/tasks/base_task.py           |  3 ++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/callbacks/wandb_callbacks.py b/src/callbacks/wandb_callbacks.py
index 04ff0645..df438e15 100644
--- a/src/callbacks/wandb_callbacks.py
+++ b/src/callbacks/wandb_callbacks.py
@@ -2,6 +2,8 @@
 from pytorch_lightning.loggers import LoggerCollection, WandbLogger
 from pytorch_lightning.utilities import rank_zero_only

+from src.utils import utils
+

 def get_wandb_logger(trainer: Trainer) -> WandbLogger:
     if isinstance(trainer.logger, WandbLogger):
@@ -26,5 +28,9 @@ def __init__(self, log: str = "gradients", log_freq: int = 100):

     @rank_zero_only
     def on_train_start(self, trainer, pl_module):
-        logger = get_wandb_logger(trainer=trainer)
-        logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+        try:
+            logger = get_wandb_logger(trainer=trainer)
+            logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+        except ValueError as e:
+            logger = utils.get_logger(__name__)
+            logger.error('No wandb logger found. WatchModelWithWandb callback will not do anything.')
diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py
index 4252b334..02f02cb4 100644
--- a/src/tasks/base_task.py
+++ b/src/tasks/base_task.py
@@ -333,4 +333,5 @@ def _create_conf_mat(self, matrix: np.ndarray, stage: str = 'val'):
                 experiment.log({f"confusion_matrix_{stage}_img/ep_{self.trainer.current_epoch}": wandb.Image(plt)},
                                commit=False)
         except ValueError as e:
-            return
+            log.warn('No wandb logger found. Confusion matrix images are not saved.')
+

From f26d5ccec987dc8f21afaa39cc0f9c8bc8e900a5 Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Tue, 23 Nov 2021 10:54:15 +0100
Subject: [PATCH 090/108] :wrench: removed predict True where it is not working

---
 configs/experiment/development_baby_unet_cb55_10.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml
index b7432b00..e91117c7 100644
--- a/configs/experiment/development_baby_unet_cb55_10.yaml
+++ b/configs/experiment/development_baby_unet_cb55_10.yaml
@@ -27,7 +27,7 @@ seed: 42

 train: True
 test: True
-predict: True
+predict: False

 trainer:
   _target_: pytorch_lightning.Trainer

From e9d65df3f66d86611aec2108f4d864bfd8c73282 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Wed, 24 Nov 2021 09:28:40 +0100
Subject: [PATCH 091/108] :loud_sound: log entire config

---
 run.py             |  2 +-
 src/utils/utils.py | 11 ++++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/run.py b/run.py
index d3d41d2f..27e5f4ed 100644
--- a/run.py
+++ b/run.py
@@ -25,7 +25,7 @@ def main(config: DictConfig):

     # Pretty print config using Rich library
     if config.get("print_config"):
-        utils.print_config(config, resolve=False)
+        utils.print_config(config, resolve=False, add_missing_fields=True)

     # Train model
     return execute(config)

diff --git a/src/utils/utils.py b/src/utils/utils.py
index 8839d81d..9541d249 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -116,13 +116,16 @@ def print_config(
         "optimizer",
         "datamodule",
         "callbacks",
+        "loss",
         "metric",
         "logger",
         "seed",
         "train",
-        "test"
+        "test",
+        "predict"
     ),
     resolve: bool = True,
+    add_missing_fields: bool = True,
 ) -> None:
     """Prints content of DictConfig using Rich library and its tree structure.
@@ -136,6 +139,12 @@ def print_config( style = "dim" tree = Tree(f":gear: CONFIG", style=style, guide_style=style) + if add_missing_fields: + fields = list(fields) + for key in sorted(config.keys()): + if key not in fields: + fields.append(key) + for field in fields: branch = tree.add(field, style=style, guide_style=style) From c73aae0683f2918c18d710b9b94cc7e44bbc8f7c Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 10:33:01 +0100 Subject: [PATCH 092/108] :pushpin: set wandb requirements to >=0.12.6 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e173e5f4..e3bd4d94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ hydra-colorlog==1.1.0 hydra-optuna-sweeper==1.1.0 # --------- loggers --------- # -wandb==0.12.6 +wandb>=0.12.6 # --------- others --------- # rich==10.1.0 From c4dc72af1d0337ae500e110489634d13ef99fa77 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 10:59:25 +0100 Subject: [PATCH 093/108] :wrench: added back header single_layer.yaml --- configs/model/header/single_layer.yaml | 5 +++++ src/models/headers/fully_connected.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 configs/model/header/single_layer.yaml diff --git a/configs/model/header/single_layer.yaml b/configs/model/header/single_layer.yaml new file mode 100644 index 00000000..9ec82edf --- /dev/null +++ b/configs/model/header/single_layer.yaml @@ -0,0 +1,5 @@ +_target_: src.models.headers.fully_connected.SingleLinear + +num_classes: ${datamodule:num_classes} +# needs to be calculated from the output of the last layer of the backbone (do not forget to flatten!) +in_channels: 109512 \ No newline at end of file diff --git a/src/models/headers/fully_connected.py b/src/models/headers/fully_connected.py index c8922a18..0b7604e1 100644 --- a/src/models/headers/fully_connected.py +++ b/src/models/headers/fully_connected.py @@ -15,3 +15,17 @@ def __init__(self, num_classes: int = 4, in_channels: int = 109512): def forward(self, x): x = self.fc(x) return x + + +class SingleLinear(nn.Module): + def __init__(self, num_classes: int = 4, in_channels: int = 109512): + super(SingleLinear, self).__init__() + + self.fc = nn.Sequential( + torch.nn.Flatten(), + nn.Linear(in_channels, num_classes) + ) + + def forward(self, x): + x = self.fc(x) + return x From 202983064cfdeff3757dbe35732ba75241579dfe Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Wed, 24 Nov 2021 11:03:15 +0100 Subject: [PATCH 094/108] :wrench: fixed in_channels in cnn basic experiment --- configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml index 55dc405d..66e40834 100644 --- a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml @@ -43,6 +43,10 @@ task: confusion_matrix_val: False confusion_matrix_test: False +model: + header: + in_channels: 109512 + datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped From d616220feee11fa2aa0f3f005cc2bc7f2b10ee57 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 11:32:17 +0100 Subject: [PATCH 095/108] :sparkles: dataset_predict.py supports glob --- .../datamodule/rolf_format_dev_prediction.yaml | 8 ++++---- .../experiment/dev_rolf_format_prediction.yaml | 2 +- src/datamodules/RolfFormat/datamodule.py | 2 +- 
src/datamodules/utils/dataset_predict.py | 15 ++++++++++++++- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/configs/datamodule/rolf_format_dev_prediction.yaml b/configs/datamodule/rolf_format_dev_prediction.yaml index 24fcf872..1a8baf60 100644 --- a/configs/datamodule/rolf_format_dev_prediction.yaml +++ b/configs/datamodule/rolf_format_dev_prediction.yaml @@ -56,10 +56,10 @@ test_specs: range_to: 1099 pred_file_path_list: - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1061.jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1062.jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1063.jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1064.jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[0-2].jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[7,9].jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-107*.jpg" + - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1085.jpg" image_dims: width: 640 diff --git a/configs/experiment/dev_rolf_format_prediction.yaml b/configs/experiment/dev_rolf_format_prediction.yaml index 37654f42..9e889ebd 100644 --- a/configs/experiment/dev_rolf_format_prediction.yaml +++ b/configs/experiment/dev_rolf_format_prediction.yaml @@ -32,7 +32,7 @@ predict: True model: backbone: - path_to_weights: /netscratch/experiments_lars_paul/lars/2021-11-22/11-52-43/checkpoints/epoch=1/backbone.pth + path_to_weights: /netscratch/experiments_lars_paul/paul/2021-11-24/09-12-01/checkpoints/epoch=1/backbone.pth trainer: _target_: pytorch_lightning.Trainer diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index b582b377..8cee03f5 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -233,4 +233,4 @@ def get_img_name_prediction(self, index): if not hasattr(self, 'predict'): raise Exception('This method can just be called during prediction') - return Path(self.predict.image_path_list[index]).stem + return self.predict.image_path_list[index].stem diff --git a/src/datamodules/utils/dataset_predict.py b/src/datamodules/utils/dataset_predict.py index 836ba8b4..916f2a94 100644 --- a/src/datamodules/utils/dataset_predict.py +++ b/src/datamodules/utils/dataset_predict.py @@ -1,3 +1,5 @@ +from glob import glob +from pathlib import Path from typing import List import torch.utils.data as data @@ -25,7 +27,8 @@ def __init__(self, image_path_list: List[str], image_dims: ImageDimensions, twin_transform : callable """ - self.image_path_list = list(image_path_list) + self._raw_image_path_list = list(image_path_list) + self.image_path_list = self.expend_glob_path_list(glob_path_list=self.raw_image_path_list) self.image_dims = image_dims @@ -84,3 +87,13 @@ def _apply_transformation(self, img): img = ToTensor()(img) return img + + @staticmethod + def expend_glob_path_list(glob_path_list: List[str]) -> List[Path]: + output_list = [] + for glob_path in glob_path_list: + for s in sorted(glob(glob_path)): + path = Path(s) + if path not in output_list: + output_list.append(Path(s)) + return output_list From 7c81a11b52f15e0c0961317b05fba0c5de12b43f Mon Sep 17 00:00:00 
2001
From: Paul M
Date: Wed, 24 Nov 2021 11:34:53 +0100
Subject: [PATCH 096/108] :sparkles: dataset_predict.py supports glob

---
 src/datamodules/utils/dataset_predict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datamodules/utils/dataset_predict.py b/src/datamodules/utils/dataset_predict.py
index 916f2a94..017543ad 100644
--- a/src/datamodules/utils/dataset_predict.py
+++ b/src/datamodules/utils/dataset_predict.py
@@ -28,7 +28,7 @@ def __init__(self, image_path_list: List[str], image_dims: ImageDimensions,
         """
         self._raw_image_path_list = list(image_path_list)
-        self.image_path_list = self.expend_glob_path_list(glob_path_list=self.raw_image_path_list)
+        self.image_path_list = self.expend_glob_path_list(glob_path_list=self._raw_image_path_list)

         self.image_dims = image_dims

From 2d8fd7ec8981e1cb950e45d421adbd14932bc6f0 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Wed, 24 Nov 2021 11:42:00 +0100
Subject: [PATCH 097/108] :white_check_mark: fixed tests

---
 tests/datamodules/util/test_predict_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/datamodules/util/test_predict_dataset.py b/tests/datamodules/util/test_predict_dataset.py
index d2fddde7..41aeff69 100644
--- a/tests/datamodules/util/test_predict_dataset.py
+++ b/tests/datamodules/util/test_predict_dataset.py
@@ -10,7 +10,7 @@

 @pytest.fixture
 def file_path_list(data_dir):
     test_data_path = data_dir / 'test' / 'data'
-    return list(test_data_path.iterdir())
+    return [str(p) for p in test_data_path.iterdir()]

 @pytest.fixture

From 169d26fbfff725b5b8ed879fb83cc9328c959e69 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Wed, 24 Nov 2021 16:48:19 +0100
Subject: [PATCH 098/108] :construction: class weights are optional; however, the loss does not work with class weights yet.
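
This commit makes the per-class "weight" entry in the datamodule configs optional,
with an all-or-nothing rule: either every class defines a weight or none does. A
minimal sketch of that rule follows; resolve_class_weights is an illustrative
helper name for this note, not a function in this repository:

    from typing import List, Optional

    import torch

    def resolve_class_weights(weights: List[Optional[float]]) -> Optional[torch.Tensor]:
        # No class defines a weight -> class weighting is simply disabled.
        if all(w is None for w in weights):
            return None
        # A mix of weighted and unweighted classes is ambiguous -> fail fast.
        if any(w is None for w in weights):
            raise ValueError('If you set class weights, you have to do this for all classes.')
        # Otherwise expose the weights as a tensor, as the datamodules below now do.
        return torch.as_tensor(weights)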
--- configs/datamodule/rolf_format_dev.yaml | 2 - .../rolf_format_dev_no_weights.yaml | 142 ++++++++++++++++++ configs/loss/crossentropyloss.yaml | 2 + .../DivaHisDB/datamodule_cropped.py | 6 +- src/datamodules/RGB/datamodule_cropped.py | 5 +- src/datamodules/RGB/datamodule_full_page.py | 5 +- src/datamodules/RolfFormat/datamodule.py | 20 ++- .../RolfFormat/utils/image_analytics.py | 2 +- src/datamodules/RotNet/datamodule_cropped.py | 6 +- src/datamodules/base_datamodule.py | 8 + .../DivaHisDB/test_hisDBDataModule.py | 5 +- 11 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 configs/datamodule/rolf_format_dev_no_weights.yaml diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml index 9db9f874..05b4c9b6 100644 --- a/configs/datamodule/rolf_format_dev.yaml +++ b/configs/datamodule/rolf_format_dev.yaml @@ -154,5 +154,3 @@ classes: G: 255 B: 0 weight: 0.024771204134236315 - - diff --git a/configs/datamodule/rolf_format_dev_no_weights.yaml b/configs/datamodule/rolf_format_dev_no_weights.yaml new file mode 100644 index 00000000..d65aa25b --- /dev/null +++ b/configs/datamodule/rolf_format_dev_no_weights.yaml @@ -0,0 +1,142 @@ +_target_: src.datamodules.RolfFormat.datamodule.DataModuleRolfFormat + +num_workers: 4 +batch_size: 8 +shuffle: True +drop_last: False + +data_root: /netscratch/datasets/semantic_segmentation/rolf_format + +train_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1000 + range_to: 1059 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1000 + range_to: 1059 + +val_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1060 + range_to: 1079 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1060 + range_to: 1079 + +test_specs: + append1: + doc_dir: "SetA1_sizeM_Rolf/layoutD/data" + doc_names: "A1-MD-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutD/gtruth" + gt_names: "A1-MD-truthD-####.gif" + range_from: 1080 + range_to: 1099 + append2: + doc_dir: "SetA1_sizeM_Rolf/layoutR/data" + doc_names: "A1-MR-page-####.jpg" + gt_dir: "SetA1_sizeM_Rolf/layoutR/gtruth" + gt_names: "A1-MR-truthD-####.gif" + range_from: 1080 + range_to: 1099 + +image_dims: + width: 640 + height: 896 + +image_analytics: + mean: + R: 0.8664800196201524 + G: 0.7408864118075618 + B: 0.6299955083595935 + std: + R: 0.2156624188591712 + G: 0.20890185198454636 + B: 0.1870731300038113 + +classes: + class0: + color: + R: 0 + G: 0 + B: 0 + class1: + color: + R: 0 + G: 102 + B: 0 + class2: + color: + R: 0 + G: 102 + B: 102 + class3: + color: + R: 0 + G: 153 + B: 153 + class4: + color: + R: 0 + G: 255 + B: 0 + class5: + color: + R: 0 + G: 255 + B: 255 + class6: + color: + R: 102 + G: 0 + B: 0 + class7: + color: + R: 102 + G: 0 + B: 102 + class8: + color: + R: 102 + G: 102 + B: 0 + class9: + color: + R: 153 + G: 0 + B: 153 + class10: + color: + R: 153 + G: 153 + B: 0 + class11: + color: + R: 255 + G: 0 + B: 0 + class12: + color: + R: 255 + G: 0 + B: 255 + class13: + color: + R: 255 + G: 255 + B: 0 diff --git 
a/configs/loss/crossentropyloss.yaml b/configs/loss/crossentropyloss.yaml index 0adfec9f..8cb5608b 100644 --- a/configs/loss/crossentropyloss.yaml +++ b/configs/loss/crossentropyloss.yaml @@ -1,2 +1,4 @@ # documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss _target_: torch.nn.CrossEntropyLoss + +#weight: ${datamodule:class_weights} \ No newline at end of file diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index 824fc63a..e8e95da4 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Union, List, Optional +import torch from torch.utils.data import DataLoader from torchvision import transforms @@ -37,7 +38,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.std = analytics_data['std'] self.class_encodings = analytics_gt['class_encodings'] self.num_classes = len(self.class_encodings) - self.class_weights = analytics_gt['class_weights'] + self.class_weights = torch.as_tensor(analytics_gt['class_weights']) self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), @@ -61,6 +62,9 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.dims = (3, crop_size, crop_size) + # Check default attributes using base_datamodule function + self._check_attributes() + def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 02e23c41..25ab3ecd 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -39,7 +39,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.class_encodings = analytics_gt['class_encodings'] self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 self.num_classes = len(self.class_encodings) - self.class_weights = analytics_gt['class_weights'] + self.class_weights = torch.as_tensor(analytics_gt['class_weights']) self.twin_transform = TwinRandomCrop(crop_size=crop_size) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), @@ -61,6 +61,9 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.dims = (3, crop_size, crop_size) + # Check default attributes using base_datamodule function + self._check_attributes() + def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule_full_page.py index c6c9a40f..1595fc97 100644 --- a/src/datamodules/RGB/datamodule_full_page.py +++ b/src/datamodules/RGB/datamodule_full_page.py @@ -41,7 +41,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.class_encodings = analytics_gt['class_encodings'] self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 self.num_classes = len(self.class_encodings) - self.class_weights = analytics_gt['class_weights'] + self.class_weights = torch.as_tensor(analytics_gt['class_weights']) self.twin_transform = None self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), @@ -61,6 +61,9 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, 
self.selection_val = selection_val self.selection_test = selection_test + # Check default attributes using base_datamodule function + self._check_attributes() + def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 8cee03f5..97878a5c 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Union, List, Optional import torch @@ -7,10 +6,10 @@ from src.datamodules.RGB.utils.single_transform import IntegerEncoding from src.datamodules.RolfFormat.datasets.dataset import DatasetRolfFormat, DatasetSpecs -from src.datamodules.utils.misc import ImageDimensions from src.datamodules.RolfFormat.utils.image_analytics import get_analytics_data, get_analytics_gt, get_image_dims from src.datamodules.base_datamodule import AbstractDatamodule from src.datamodules.utils.dataset_predict import DatasetPredict +from src.datamodules.utils.misc import ImageDimensions from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget from src.utils import utils @@ -62,7 +61,17 @@ def __init__(self, data_root: str, analytics_gt['class_encodings'].append([class_specs['color']['R'], class_specs['color']['G'], class_specs['color']['B']]) - analytics_gt['class_weights'].append(class_specs['weight']) + if 'weight' in class_specs: + analytics_gt['class_weights'].append(class_specs['weight']) + else: + analytics_gt['class_weights'].append(None) + + if all(x is None for x in analytics_gt['class_weights']): + analytics_gt['class_weights'] = None + elif any(x is None for x in analytics_gt['class_weights']): + log.error('Some classes have a class weight and others do not. 
' + 'If you set class weights, you have to do this for all classes.') + raise ValueError self.image_dims = image_dims self.dims = (3, self.image_dims.width, self.image_dims.height) @@ -72,7 +81,7 @@ def __init__(self, data_root: str, self.class_encodings = analytics_gt['class_encodings'] self.class_encodings_tensor = torch.tensor(self.class_encodings) / 255 self.num_classes = len(self.class_encodings) - self.class_weights = analytics_gt['class_weights'] + self.class_weights = torch.as_tensor(analytics_gt['class_weights']) self.twin_transform = None self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), @@ -85,6 +94,9 @@ def __init__(self, data_root: str, self.shuffle = shuffle self.drop_last = drop_last + # Check default attributes using base_datamodule function + self._check_attributes() + def _print_analytics_data(self, analytics_data): indent = 4 * ' ' lines = [''] diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index 7ded9b98..bdcc838d 100644 --- a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -6,8 +6,8 @@ from PIL import Image from src.datamodules.RGB.utils.image_analytics import _get_class_frequencies_weights_segmentation -from src.datamodules.utils.misc import ImageDimensions from src.datamodules.utils.image_analytics import compute_mean_std +from src.datamodules.utils.misc import ImageDimensions def get_analytics_data(data_gt_path_list, **kwargs): diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index 890a34f3..188aad7a 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -2,6 +2,7 @@ from typing import Union, List, Optional import numpy as np +import torch from torch.utils.data import DataLoader from torchvision import transforms @@ -32,7 +33,7 @@ def __init__(self, data_dir: str, data_folder_name: str, self.std = analytics_data['std'] self.class_encodings = np.array(ROTATION_ANGLES) self.num_classes = len(self.class_encodings) - self.class_weights = np.array([1 / self.num_classes for _ in range(self.num_classes)]) + self.class_weights = torch.cuda.FloatTensor([1 / self.num_classes for _ in range(self.num_classes)]) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), @@ -52,6 +53,9 @@ def __init__(self, data_dir: str, data_folder_name: str, self.dims = (3, crop_size, crop_size) + # Check default attributes using base_datamodule function + self._check_attributes() + def setup(self, stage: Optional[str] = None): super().setup() if stage == 'fit' or stage is None: diff --git a/src/datamodules/base_datamodule.py b/src/datamodules/base_datamodule.py index ab30a1d8..133292f7 100644 --- a/src/datamodules/base_datamodule.py +++ b/src/datamodules/base_datamodule.py @@ -1,6 +1,7 @@ from typing import Optional import pytorch_lightning as pl +import torch from omegaconf import OmegaConf @@ -8,6 +9,7 @@ class AbstractDatamodule(pl.LightningDataModule): def __init__(self): super().__init__() self.num_classes = -1 + self.class_weights = None resolver_name = 'datamodule' if not OmegaConf.has_resolver(resolver_name): OmegaConf.register_new_resolver( @@ -19,3 +21,9 @@ def __init__(self): def setup(self, stage: Optional[str] = None) -> None: if not self.dims: raise ValueError("the dimensions of the data needs to be set! 
self.dims") + + def _check_attributes(self): + assert self.num_classes > 0 + if self.class_weights is not None: + assert len(self.class_weights) == self.num_classes + assert torch.is_tensor(self.class_weights) \ No newline at end of file diff --git a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py index 3ab11412..9a037618 100644 --- a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py +++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import torch from numpy import uint8 from omegaconf import OmegaConf @@ -24,7 +25,9 @@ def test_init_datamodule(data_module_cropped): assert data_module_cropped.dims == (3, 256, 256) assert data_module_cropped.num_classes == 4 assert data_module_cropped.class_encodings == [1, 2, 4, 8] - assert data_module_cropped.class_weights == [0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002] + assert torch.all( + torch.eq(data_module_cropped.class_weights, + torch.tensor([0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002]))) assert data_module_cropped.mean == [0.7050454974582426, 0.6503181590413943, 0.5567698583877997] assert data_module_cropped.std == [0.3104060859619883, 0.3053311838884032, 0.28919611393432726] with pytest.raises(AttributeError): From dde33c9e4d297e4e8120462a4e3e8ff3fc9008a5 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 16:55:58 +0100 Subject: [PATCH 099/108] :book: changed paths in configs to global variant --- configs/datamodule/cb55_10_cropped_datamodule.yaml | 2 +- configs/datamodule/cb55_cropped_datamodule.yaml | 2 +- configs/datamodule/rolf_format_dev.yaml | 2 +- configs/datamodule/rolf_format_dev_no_weights.yaml | 2 +- configs/datamodule/rolf_format_dev_prediction.yaml | 10 +++++----- configs/datamodule/rolf_format_layoutD_gtD.yaml | 2 +- configs/experiment/cb55_full_run_unet.yaml | 2 +- configs/experiment/cb55_select_train15_unet.yaml | 2 +- configs/experiment/cb55_select_train1_val1_unet.yaml | 2 +- configs/experiment/dev_rgb_full_page.yaml | 2 +- configs/experiment/dev_rolf_format_prediction.yaml | 2 +- configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml | 2 +- .../dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml | 4 ++-- configs/experiment/dev_rotnet_resnet18_cb55_10.yaml | 2 +- configs/experiment/dev_rotnet_resnet50_cb55_10.yaml | 2 +- configs/experiment/development_baby_unet_cb55_10.yaml | 2 +- configs/experiment/development_baby_unet_rgb_data.yaml | 2 +- configs/experiment/rotnet_resnet18_cb55_full.yaml | 2 +- .../experiment/rotnet_resnet18_cb55_train10_last.yaml | 2 +- .../experiment/rotnet_resnet18_cb55_train19_last.yaml | 2 +- .../experiment/synthetic_baby_unet_layoutD_gtD.yaml | 2 +- .../experiment/synthetic_baby_unet_layoutR_gtD.yaml | 2 +- tests/utils/test_utils.py | 2 +- tools/generate_cropped_dataset.py | 4 ++-- 24 files changed, 30 insertions(+), 30 deletions(-) diff --git a/configs/datamodule/cb55_10_cropped_datamodule.yaml b/configs/datamodule/cb55_10_cropped_datamodule.yaml index 44e7d090..1c8f1b32 100644 --- a/configs/datamodule/cb55_10_cropped_datamodule.yaml +++ b/configs/datamodule/cb55_10_cropped_datamodule.yaml @@ -1,7 +1,7 @@ _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped -data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation +data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 
batch_size: 16 diff --git a/configs/datamodule/cb55_cropped_datamodule.yaml b/configs/datamodule/cb55_cropped_datamodule.yaml index fe0f1dcc..8d794c56 100644 --- a/configs/datamodule/cb55_cropped_datamodule.yaml +++ b/configs/datamodule/cb55_cropped_datamodule.yaml @@ -1,6 +1,6 @@ _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped -data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 +data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml index 05b4c9b6..0c6c2c1e 100644 --- a/configs/datamodule/rolf_format_dev.yaml +++ b/configs/datamodule/rolf_format_dev.yaml @@ -5,7 +5,7 @@ batch_size: 8 shuffle: True drop_last: False -data_root: /netscratch/datasets/semantic_segmentation/rolf_format +data_root: /net/research-hisdoc/datasets/semantic_segmentation/rolf_format train_specs: append1: diff --git a/configs/datamodule/rolf_format_dev_no_weights.yaml b/configs/datamodule/rolf_format_dev_no_weights.yaml index d65aa25b..9051acfa 100644 --- a/configs/datamodule/rolf_format_dev_no_weights.yaml +++ b/configs/datamodule/rolf_format_dev_no_weights.yaml @@ -5,7 +5,7 @@ batch_size: 8 shuffle: True drop_last: False -data_root: /netscratch/datasets/semantic_segmentation/rolf_format +data_root: /net/research-hisdoc/datasets/semantic_segmentation/rolf_format train_specs: append1: diff --git a/configs/datamodule/rolf_format_dev_prediction.yaml b/configs/datamodule/rolf_format_dev_prediction.yaml index 1a8baf60..d5ca6151 100644 --- a/configs/datamodule/rolf_format_dev_prediction.yaml +++ b/configs/datamodule/rolf_format_dev_prediction.yaml @@ -5,7 +5,7 @@ batch_size: 8 shuffle: True drop_last: False -data_root: /netscratch/datasets/semantic_segmentation/rolf_format +data_root: /net/research-hisdoc/datasets/semantic_segmentation/rolf_format train_specs: append1: @@ -56,10 +56,10 @@ test_specs: range_to: 1099 pred_file_path_list: - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[0-2].jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[7,9].jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-107*.jpg" - - "/netscratch/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1085.jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[0-2].jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[7,9].jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-107*.jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1085.jpg" image_dims: width: 640 diff --git a/configs/datamodule/rolf_format_layoutD_gtD.yaml b/configs/datamodule/rolf_format_layoutD_gtD.yaml index 60e11bbb..0eabbe00 100644 --- a/configs/datamodule/rolf_format_layoutD_gtD.yaml +++ b/configs/datamodule/rolf_format_layoutD_gtD.yaml @@ -5,7 +5,7 @@ batch_size: 8 shuffle: True drop_last: False -data_root: /netscratch/datasets/semantic_segmentation/rolf_format +data_root: /net/research-hisdoc/datasets/semantic_segmentation/rolf_format train_specs: append1: diff --git 
a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index f42a9817..bfc96b64 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ b/configs/experiment/cb55_full_run_unet.yaml @@ -42,7 +42,7 @@ task: datamodule: _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/cb55_select_train15_unet.yaml b/configs/experiment/cb55_select_train15_unet.yaml index 550931b1..22095666 100644 --- a/configs/experiment/cb55_select_train15_unet.yaml +++ b/configs/experiment/cb55_select_train15_unet.yaml @@ -42,7 +42,7 @@ task: datamodule: _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/cb55_select_train1_val1_unet.yaml b/configs/experiment/cb55_select_train1_val1_unet.yaml index 407f304e..fec8063f 100644 --- a/configs/experiment/cb55_select_train1_val1_unet.yaml +++ b/configs/experiment/cb55_select_train1_val1_unet.yaml @@ -42,7 +42,7 @@ task: datamodule: _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/dev_rgb_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml index 8cf2fd1f..f89fa918 100644 --- a/configs/experiment/dev_rgb_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.RGB.datamodule_full_page.DataModuleRGB - data_dir: /netscratch/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutD/split + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutD/split num_workers: 4 batch_size: 2 shuffle: True diff --git a/configs/experiment/dev_rolf_format_prediction.yaml b/configs/experiment/dev_rolf_format_prediction.yaml index 9e889ebd..fb78a76f 100644 --- a/configs/experiment/dev_rolf_format_prediction.yaml +++ b/configs/experiment/dev_rolf_format_prediction.yaml @@ -32,7 +32,7 @@ predict: True model: backbone: - path_to_weights: /netscratch/experiments_lars_paul/paul/2021-11-24/09-12-01/checkpoints/epoch=1/backbone.pth + path_to_weights: /net/research-hisdoc/experiments_lars_paul/paul/2021-11-24/09-12-01/checkpoints/epoch=1/backbone.pth trainer: _target_: pytorch_lightning.Trainer diff --git a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml index 55dc405d..78f7dc32 100644 --- a/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_cnn_basic_cb55_10.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 batch_size: 16 diff --git 
a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml index 4f565411..ff0f1cdb 100644 --- a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml +++ b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 batch_size: 16 @@ -57,7 +57,7 @@ datamodule: model: backbone: - path_to_weights: /netscratch/experiments_lars_paul/lars/2021-11-15/16-08-51/checkpoints/epoch=1/backbone.pth + path_to_weights: /net/research-hisdoc/experiments_lars_paul/lars/2021-11-15/16-08-51/checkpoints/epoch=1/backbone.pth header: in_channels: 512 diff --git a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml index 7ec3295e..e1c82a5d 100644 --- a/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet18_cb55_10.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml index 5d927551..f136fa1d 100644 --- a/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml +++ b/configs/experiment/dev_rotnet_resnet50_cb55_10.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index e91117c7..afbc3b46 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.DivaHisDB.datamodule_cropped.DivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/development_baby_unet_rgb_data.yaml b/configs/experiment/development_baby_unet_rgb_data.yaml index cf62491c..0b3f2326 100644 --- a/configs/experiment/development_baby_unet_rgb_data.yaml +++ b/configs/experiment/development_baby_unet_rgb_data.yaml @@ -46,7 +46,7 @@ task: datamodule: _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB - data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split crop_size: 256 num_workers: 4 batch_size: 16 diff --git 
a/configs/experiment/rotnet_resnet18_cb55_full.yaml b/configs/experiment/rotnet_resnet18_cb55_full.yaml index fe4ffca3..1b7cf970 100644 --- a/configs/experiment/rotnet_resnet18_cb55_full.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_full.yaml @@ -45,7 +45,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml index 92b988f1..937572d4 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train10_last.yaml @@ -45,7 +45,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 32 diff --git a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml index ac9b1547..e9dd0895 100644 --- a/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml +++ b/configs/experiment/rotnet_resnet18_cb55_train19_last.yaml @@ -45,7 +45,7 @@ task: datamodule: _target_: src.datamodules.RotNet.datamodule_cropped.RotNetDivaHisDBDataModuleCropped - data_dir: /netscratch/datasets/semantic_segmentation/datasets_cropped/CB55 + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55 crop_size: 256 num_workers: 4 batch_size: 256 diff --git a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml index 85fb90aa..0874f9dd 100644 --- a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml @@ -44,7 +44,7 @@ task: datamodule: _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB - data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutD/split crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml index 938e809f..f633cf71 100644 --- a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml @@ -44,7 +44,7 @@ task: datamodule: _target_: src.datamodules.RGB.datamodule_cropped.DataModuleCroppedRGB - data_dir: /netscratch/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutR/split + data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic_cropped/SetA1_sizeM/layoutR/split crop_size: 256 num_workers: 4 batch_size: 16 diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index aba67e62..d5d00ef8 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -36,7 +36,7 @@ def get_dict(): 'min_epochs': 1, 'max_epochs': 3, 'weights_summary': 'full', 'precision': 16}, 'datamodule': { '_target_': 'src.datamodules.hisDBDataModule.DIVAHisDBDataModule.DIVAHisDBDataModuleCropped', - 'data_dir': 
'/netscratch/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation', + 'data_dir': '/net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55-10-segmentation', 'crop_size': 256, 'num_workers': 4, 'batch_size': 16, 'shuffle': True, 'drop_last': True}, 'save_config': True, 'checkpoint_folder_name': '{epoch}/', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'disable_warnings': True}) diff --git a/tools/generate_cropped_dataset.py b/tools/generate_cropped_dataset.py index 737e64cf..8373b04c 100644 --- a/tools/generate_cropped_dataset.py +++ b/tools/generate_cropped_dataset.py @@ -322,7 +322,7 @@ def _convert_crop_id_to_coordinates(self, img_index, hcrop_index, vcrop_index): # -i # /dataset/DIVA-HisDB/segmentation/CB55 # -o - # /netscratch/datasets/semantic_segmentation/datasets_cropped/temp-CB55 + # /net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/temp-CB55 # -tr # 300 # -v @@ -332,7 +332,7 @@ def _convert_crop_id_to_coordinates(self, img_index, hcrop_index, vcrop_index): # dataset_generator = CroppedDatasetGenerator( # input_path=Path('/dataset/DIVA-HisDB/segmentation/CB55'), - # output_path=Path('/netscratch/datasets/semantic_segmentation/datasets_cropped/CB55'), + # output_path=Path('/net/research-hisdoc/datasets/semantic_segmentation/datasets_cropped/CB55'), # crop_size_train=300, # crop_size_val=300, # crop_size_test=256, From 114a3a7693cf89fac8117b3ea8c3172ac39d1d9d Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 17:18:03 +0100 Subject: [PATCH 100/108] :recycle: changed names so full_page is default, and only mention when it is not full_page --- configs/experiment/cb55_full_run_unet.yaml | 2 +- .../experiment/cb55_select_train15_unet.yaml | 2 +- .../cb55_select_train1_val1_unet.yaml | 2 +- configs/experiment/dev_rgb_full_page.yaml | 2 +- configs/experiment/dev_rolf_format.yaml | 2 +- .../dev_rolf_format_prediction.yaml | 2 +- ...tnet_pt_resnet18_cb55_10_segmentation.yaml | 2 +- .../development_baby_unet_cb55_10.yaml | 2 +- .../development_baby_unet_rgb_data.yaml | 2 +- .../synthetic_baby_unet_layoutD_gtD.yaml | 2 +- ...tic_baby_unet_layoutD_gtD_rolf_format.yaml | 2 +- .../synthetic_baby_unet_layoutR_gtD.yaml | 2 +- configs/task/semantic_segmentation_HisDB.yaml | 2 - .../semantic_segmentation_HisDB_cropped.yaml | 2 + .../semantic_segmentation_RGB_cropped.yaml | 2 + .../semantic_segmentation_RGB_full_page.yaml | 2 - ...{datamodule_full_page.py => datamodule.py} | 0 ...on.py => semantic_segmentation_cropped.py} | 2 +- src/tasks/RGB/semantic_segmentation.py | 59 ++++++++++++------ ...ge.py => semantic_segmentation_cropped.py} | 61 ++++++------------- .../sem_seg/test_semantic_segmentation.py | 14 ++--- 21 files changed, 84 insertions(+), 84 deletions(-) delete mode 100644 configs/task/semantic_segmentation_HisDB.yaml create mode 100644 configs/task/semantic_segmentation_HisDB_cropped.yaml create mode 100644 configs/task/semantic_segmentation_RGB_cropped.yaml delete mode 100644 configs/task/semantic_segmentation_RGB_full_page.yaml rename src/datamodules/RGB/{datamodule_full_page.py => datamodule.py} (100%) rename src/tasks/DivaHisDB/{semantic_segmentation.py => semantic_segmentation_cropped.py} (99%) rename src/tasks/RGB/{semantic_segmentation_full_page.py => semantic_segmentation_cropped.py} (61%) diff --git a/configs/experiment/cb55_full_run_unet.yaml b/configs/experiment/cb55_full_run_unet.yaml index bfc96b64..69769d16 100644 --- a/configs/experiment/cb55_full_run_unet.yaml +++ 
b/configs/experiment/cb55_full_run_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_HisDB.yaml + - /task: semantic_segmentation_HisDB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/cb55_select_train15_unet.yaml b/configs/experiment/cb55_select_train15_unet.yaml index 22095666..298ca13e 100644 --- a/configs/experiment/cb55_select_train15_unet.yaml +++ b/configs/experiment/cb55_select_train15_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_HisDB.yaml + - /task: semantic_segmentation_HisDB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/cb55_select_train1_val1_unet.yaml b/configs/experiment/cb55_select_train1_val1_unet.yaml index fec8063f..23fabdb4 100644 --- a/configs/experiment/cb55_select_train1_val1_unet.yaml +++ b/configs/experiment/cb55_select_train1_val1_unet.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_HisDB.yaml + - /task: semantic_segmentation_HisDB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: unet_model.yaml diff --git a/configs/experiment/dev_rgb_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml index f89fa918..28d6a2f7 100644 --- a/configs/experiment/dev_rgb_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB_full_page.yaml + - /task: semantic_segmentation_RGB.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/experiment/dev_rolf_format.yaml b/configs/experiment/dev_rolf_format.yaml index aa25bb26..abdc016d 100644 --- a/configs/experiment/dev_rolf_format.yaml +++ b/configs/experiment/dev_rolf_format.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB_full_page.yaml + - /task: semantic_segmentation_RGB.yaml - /datamodule: rolf_format_dev.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml diff --git a/configs/experiment/dev_rolf_format_prediction.yaml b/configs/experiment/dev_rolf_format_prediction.yaml index fb78a76f..ce6fc3ba 100644 --- a/configs/experiment/dev_rolf_format_prediction.yaml +++ b/configs/experiment/dev_rolf_format_prediction.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB_full_page.yaml + - /task: semantic_segmentation_RGB.yaml - /datamodule: rolf_format_dev_prediction.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml diff --git a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml index ff0f1cdb..fc203501 100644 --- a/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml +++ b/configs/experiment/dev_rotnet_pt_resnet18_cb55_10_segmentation.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_HisDB.yaml + - /task: semantic_segmentation_HisDB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: resnet18.yaml diff --git a/configs/experiment/development_baby_unet_cb55_10.yaml b/configs/experiment/development_baby_unet_cb55_10.yaml index afbc3b46..e03410ce 100644 --- a/configs/experiment/development_baby_unet_cb55_10.yaml +++ b/configs/experiment/development_baby_unet_cb55_10.yaml @@ -5,7 
+5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_HisDB.yaml + - /task: semantic_segmentation_HisDB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: hisdbiou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/experiment/development_baby_unet_rgb_data.yaml b/configs/experiment/development_baby_unet_rgb_data.yaml index 0b3f2326..4fe1f7b1 100644 --- a/configs/experiment/development_baby_unet_rgb_data.yaml +++ b/configs/experiment/development_baby_unet_rgb_data.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB.yaml + - /task: semantic_segmentation_RGB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml index 0874f9dd..0d384e64 100644 --- a/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB.yaml + - /task: semantic_segmentation_RGB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml index d83bd2a1..48246fda 100644 --- a/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutD_gtD_rolf_format.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB_full_page.yaml + - /task: semantic_segmentation_RGB.yaml - /datamodule: rolf_format_layoutD_gtD.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml diff --git a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml index f633cf71..ee2012ec 100644 --- a/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml +++ b/configs/experiment/synthetic_baby_unet_layoutR_gtD.yaml @@ -5,7 +5,7 @@ defaults: - /plugins: default.yaml - - /task: semantic_segmentation_RGB.yaml + - /task: semantic_segmentation_RGB_cropped.yaml - /loss: crossentropyloss.yaml - /metric: iou.yaml - /model/backbone: baby_unet_model.yaml diff --git a/configs/task/semantic_segmentation_HisDB.yaml b/configs/task/semantic_segmentation_HisDB.yaml deleted file mode 100644 index 89dd207f..00000000 --- a/configs/task/semantic_segmentation_HisDB.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: src.tasks.DivaHisDB.semantic_segmentation.SemanticSegmentationHisDB - diff --git a/configs/task/semantic_segmentation_HisDB_cropped.yaml b/configs/task/semantic_segmentation_HisDB_cropped.yaml new file mode 100644 index 00000000..6f12a0e6 --- /dev/null +++ b/configs/task/semantic_segmentation_HisDB_cropped.yaml @@ -0,0 +1,2 @@ +_target_: src.tasks.DivaHisDB.semantic_segmentation_cropped.SemanticSegmentationCroppedHisDB + diff --git a/configs/task/semantic_segmentation_RGB_cropped.yaml b/configs/task/semantic_segmentation_RGB_cropped.yaml new file mode 100644 index 00000000..7a885f75 --- /dev/null +++ b/configs/task/semantic_segmentation_RGB_cropped.yaml @@ -0,0 +1,2 @@ +_target_: src.tasks.RGB.semantic_segmentation_cropped.SemanticSegmentationCroppedRGB + diff --git a/configs/task/semantic_segmentation_RGB_full_page.yaml b/configs/task/semantic_segmentation_RGB_full_page.yaml deleted file mode 100644 index 406872ed..00000000 --- 
a/configs/task/semantic_segmentation_RGB_full_page.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: src.tasks.RGB.semantic_segmentation_full_page.SemanticSegmentationFullPageRGB - diff --git a/src/datamodules/RGB/datamodule_full_page.py b/src/datamodules/RGB/datamodule.py similarity index 100% rename from src/datamodules/RGB/datamodule_full_page.py rename to src/datamodules/RGB/datamodule.py diff --git a/src/tasks/DivaHisDB/semantic_segmentation.py b/src/tasks/DivaHisDB/semantic_segmentation_cropped.py similarity index 99% rename from src/tasks/DivaHisDB/semantic_segmentation.py rename to src/tasks/DivaHisDB/semantic_segmentation_cropped.py index 4b0ce5fa..97674457 100644 --- a/src/tasks/DivaHisDB/semantic_segmentation.py +++ b/src/tasks/DivaHisDB/semantic_segmentation_cropped.py @@ -14,7 +14,7 @@ log = utils.get_logger(__name__) -class SemanticSegmentationHisDB(AbstractTask): +class SemanticSegmentationCroppedHisDB(AbstractTask): def __init__(self, model: nn.Module, diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py index 99172e1a..a7613b5d 100644 --- a/src/tasks/RGB/semantic_segmentation.py +++ b/src/tasks/RGB/semantic_segmentation.py @@ -1,11 +1,12 @@ from pathlib import Path -from typing import Optional, Callable, Union +from typing import Optional, Callable, Union, Any import numpy as np import torch.nn as nn import torch.optim import torchmetrics +from src.datamodules.RGB.utils.output_tools import save_output_page_image from src.datamodules.utils.misc import _get_argmax from src.tasks.base_task import AbstractTask from src.utils import utils @@ -57,8 +58,8 @@ def __init__(self, def setup(self, stage: str) -> None: super().setup(stage) - if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'): - raise NotImplementedError('DataModule needs to implement get_img_name_coordinates function') + if not hasattr(self.trainer.datamodule, 'get_img_name'): + raise NotImplementedError('DataModule needs to implement get_img_name function') log.info("Setup done!") @@ -94,30 +95,50 @@ def test_step(self, batch, batch_idx, **kwargs): input_batch, target_batch, input_idx = batch output = super().test_step(batch=(input_batch, target_batch), batch_idx=batch_idx) - if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'): + if not hasattr(self.trainer.datamodule, 'get_img_name'): raise NotImplementedError('Datamodule does not provide detailed information of the crop') - for patch, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), - input_idx.detach().cpu().numpy()): - patch_info = self.trainer.datamodule.get_img_name_coordinates(idx) + for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), + input_idx.detach().cpu().numpy()): + patch_info = self.trainer.datamodule.get_img_name(idx) img_name = patch_info[0] - patch_name = patch_info[1] - dest_folder = self.test_output_path / 'patches' / img_name + dest_folder = self.test_output_path / 'pred_raw' dest_folder.mkdir(parents=True, exist_ok=True) - dest_filename = dest_folder / f'{patch_name}.npy' + dest_filename = dest_folder / f'{img_name}.npy' + np.save(file=str(dest_filename), arr=pred_raw) - np.save(file=str(dest_filename), arr=patch) + dest_folder = self.test_output_path / 'pred' + dest_folder.mkdir(parents=True, exist_ok=True) + save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, + output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) return reduce_dict(input_dict=output, key_list=[]) def 
on_test_end(self) -> None: - datamodule_path = self.trainer.datamodule.data_dir - prediction_path = (self.test_output_path / 'patches').absolute() - output_path = (self.test_output_path / 'result').absolute() + pass + + ############################################################################################# + ######################################### PREDICT ########################################### + ############################################################################################# + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: + input_batch, input_idx = batch + output = super().predict_step(batch=input_batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) - data_folder_name = self.trainer.datamodule.data_folder_name - gt_folder_name = self.trainer.datamodule.gt_folder_name + if not hasattr(self.trainer.datamodule, 'get_img_name'): + raise NotImplementedError('Datamodule does not provide detailed information of the crop') - log.info(f'To run the merging of patches:') - log.info(f'python tools/merge_cropped_output_RGB.py -d {datamodule_path} -p {prediction_path} -o {output_path} ' - f'-df {data_folder_name} -gf {gt_folder_name}') + for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), + input_idx.detach().cpu().numpy()): + img_name = self.trainer.datamodule.get_img_name_prediction(idx) + dest_folder = self.predict_output_path / 'pred_raw' + dest_folder.mkdir(parents=True, exist_ok=True) + dest_filename = dest_folder / f'{img_name}.npy' + np.save(file=str(dest_filename), arr=pred_raw) + + dest_folder = self.predict_output_path / 'pred' + dest_folder.mkdir(parents=True, exist_ok=True) + save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, + output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) + + return reduce_dict(input_dict=output, key_list=[]) diff --git a/src/tasks/RGB/semantic_segmentation_full_page.py b/src/tasks/RGB/semantic_segmentation_cropped.py similarity index 61% rename from src/tasks/RGB/semantic_segmentation_full_page.py rename to src/tasks/RGB/semantic_segmentation_cropped.py index 78302b17..2505fb2b 100644 --- a/src/tasks/RGB/semantic_segmentation_full_page.py +++ b/src/tasks/RGB/semantic_segmentation_cropped.py @@ -1,12 +1,11 @@ from pathlib import Path -from typing import Optional, Callable, Union, Any +from typing import Optional, Callable, Union import numpy as np import torch.nn as nn import torch.optim import torchmetrics -from src.datamodules.RGB.utils.output_tools import save_output_page_image from src.datamodules.utils.misc import _get_argmax from src.tasks.base_task import AbstractTask from src.utils import utils @@ -15,7 +14,7 @@ log = utils.get_logger(__name__) -class SemanticSegmentationFullPageRGB(AbstractTask): +class SemanticSegmentationCroppedRGB(AbstractTask): def __init__(self, model: nn.Module, @@ -58,8 +57,8 @@ def __init__(self, def setup(self, stage: str) -> None: super().setup(stage) - if not hasattr(self.trainer.datamodule, 'get_img_name'): - raise NotImplementedError('DataModule needs to implement get_img_name function') + if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'): + raise NotImplementedError('DataModule needs to implement get_img_name_coordinates function') log.info("Setup done!") @@ -95,50 +94,30 @@ def test_step(self, batch, batch_idx, **kwargs): input_batch, target_batch, input_idx = batch output = super().test_step(batch=(input_batch, target_batch), 
batch_idx=batch_idx) - if not hasattr(self.trainer.datamodule, 'get_img_name'): + if not hasattr(self.trainer.datamodule, 'get_img_name_coordinates'): raise NotImplementedError('Datamodule does not provide detailed information of the crop') - for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), - input_idx.detach().cpu().numpy()): - patch_info = self.trainer.datamodule.get_img_name(idx) + for patch, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), + input_idx.detach().cpu().numpy()): + patch_info = self.trainer.datamodule.get_img_name_coordinates(idx) img_name = patch_info[0] - dest_folder = self.test_output_path / 'pred_raw' + patch_name = patch_info[1] + dest_folder = self.test_output_path / 'patches' / img_name dest_folder.mkdir(parents=True, exist_ok=True) - dest_filename = dest_folder / f'{img_name}.npy' - np.save(file=str(dest_filename), arr=pred_raw) + dest_filename = dest_folder / f'{patch_name}.npy' - dest_folder = self.test_output_path / 'pred' - dest_folder.mkdir(parents=True, exist_ok=True) - save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, - output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) + np.save(file=str(dest_filename), arr=patch) return reduce_dict(input_dict=output, key_list=[]) def on_test_end(self) -> None: - pass - - ############################################################################################# - ######################################### PREDICT ########################################### - ############################################################################################# - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: - input_batch, input_idx = batch - output = super().predict_step(batch=input_batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) + datamodule_path = self.trainer.datamodule.data_dir + prediction_path = (self.test_output_path / 'patches').absolute() + output_path = (self.test_output_path / 'result').absolute() - if not hasattr(self.trainer.datamodule, 'get_img_name'): - raise NotImplementedError('Datamodule does not provide detailed information of the crop') + data_folder_name = self.trainer.datamodule.data_folder_name + gt_folder_name = self.trainer.datamodule.gt_folder_name - for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), - input_idx.detach().cpu().numpy()): - img_name = self.trainer.datamodule.get_img_name_prediction(idx) - dest_folder = self.predict_output_path / 'pred_raw' - dest_folder.mkdir(parents=True, exist_ok=True) - dest_filename = dest_folder / f'{img_name}.npy' - np.save(file=str(dest_filename), arr=pred_raw) - - dest_folder = self.predict_output_path / 'pred' - dest_folder.mkdir(parents=True, exist_ok=True) - save_output_page_image(image_name=f'{img_name}.gif', output_image=pred_raw, - output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) - - return reduce_dict(input_dict=output, key_list=[]) + log.info(f'To run the merging of patches:') + log.info(f'python tools/merge_cropped_output_RGB.py -d {datamodule_path} -p {prediction_path} -o {output_path} ' + f'-df {data_folder_name} -gf {gt_folder_name}') diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/sem_seg/test_semantic_segmentation.py index 1b822fb0..c33d86b3 100644 --- a/tests/tasks/sem_seg/test_semantic_segmentation.py +++ b/tests/tasks/sem_seg/test_semantic_segmentation.py @@ -8,7 +8,7 @@ from pytorch_lightning 
import seed_everything from src.datamodules.DivaHisDB.datamodule_cropped import DivaHisDBDataModuleCropped -from src.tasks.DivaHisDB.semantic_segmentation import SemanticSegmentationHisDB +from src.tasks.DivaHisDB.semantic_segmentation_cropped import SemanticSegmentationCroppedHisDB from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped @@ -26,12 +26,12 @@ def baby_unet(): return UNet(num_classes=len(data_module.class_encodings), num_layers=2, features_start=32) model = baby_unet() - segmentation = SemanticSegmentationHisDB(model=model, - optimizer=torch.optim.Adam(params=model.parameters()), - loss_fn=torch.nn.CrossEntropyLoss(), - test_output_path=tmp_path, - confusion_matrix_val=True - ) + segmentation = SemanticSegmentationCroppedHisDB(model=model, + optimizer=torch.optim.Adam(params=model.parameters()), + loss_fn=torch.nn.CrossEntropyLoss(), + test_output_path=tmp_path, + confusion_matrix_val=True + ) # different paths needed later patches_path = segmentation.test_output_path / 'patches' From 81e59dc607ab42fb673b129ae6e2c5e2cb3c64b4 Mon Sep 17 00:00:00 2001 From: Paul M Date: Wed, 24 Nov 2021 18:07:09 +0100 Subject: [PATCH 101/108] :bug: :white_check_mark: :wrench: loss with weight fixed --- configs/loss/crossentropyloss.yaml | 2 +- configs/loss/crossentropyloss_no_weight.yaml | 2 ++ src/datamodules/RotNet/datamodule_cropped.py | 2 +- src/tasks/base_task.py | 6 ++++++ tests/datamodules/DivaHisDB/test_hisDBDataModule.py | 6 +++--- tests/datamodules/RotNet/test_datamodule_cropped.py | 3 ++- 6 files changed, 15 insertions(+), 6 deletions(-) create mode 100644 configs/loss/crossentropyloss_no_weight.yaml diff --git a/configs/loss/crossentropyloss.yaml b/configs/loss/crossentropyloss.yaml index 8cb5608b..900e5a28 100644 --- a/configs/loss/crossentropyloss.yaml +++ b/configs/loss/crossentropyloss.yaml @@ -1,4 +1,4 @@ # documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss _target_: torch.nn.CrossEntropyLoss -#weight: ${datamodule:class_weights} \ No newline at end of file +weight: ${datamodule:class_weights} \ No newline at end of file diff --git a/configs/loss/crossentropyloss_no_weight.yaml b/configs/loss/crossentropyloss_no_weight.yaml new file mode 100644 index 00000000..ca626025 --- /dev/null +++ b/configs/loss/crossentropyloss_no_weight.yaml @@ -0,0 +1,2 @@ +# documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss +_target_: torch.nn.CrossEntropyLoss \ No newline at end of file diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index 188aad7a..b7bf0c26 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -33,7 +33,7 @@ def __init__(self, data_dir: str, data_folder_name: str, self.std = analytics_data['std'] self.class_encodings = np.array(ROTATION_ANGLES) self.num_classes = len(self.class_encodings) - self.class_weights = torch.cuda.FloatTensor([1 / self.num_classes for _ in range(self.num_classes)]) + self.class_weights = torch.as_tensor([1 / self.num_classes for _ in range(self.num_classes)]) self.image_transform = OnlyImage(transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), diff --git a/src/tasks/base_task.py b/src/tasks/base_task.py index 02f02cb4..8e1ecb33 100644 --- a/src/tasks/base_task.py +++ b/src/tasks/base_task.py @@ -76,6 +76,7 @@ def __init__( if model is not None: self.model = model 
+        self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
         self.optimizer = optimizer
         self.scheduler = scheduler
@@ -136,6 +137,11 @@ def step(self,
             look like this: {'B': {'x': 'value', 'y': 'value'}}

         """
+        for key in self.loss_fn:
+            if hasattr(self.loss_fn[key], 'weight') and self.loss_fn[key].weight is not None:
+                if torch.is_tensor(self.loss_fn[key].weight):
+                    self.loss_fn[key].weight = self.loss_fn[key].weight.cuda(device=self.device)
+
         if metric_kwargs is None:
             metric_kwargs = {}
         x, y = batch
diff --git a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
index 9a037618..c394187f 100644
--- a/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
+++ b/tests/datamodules/DivaHisDB/test_hisDBDataModule.py
@@ -25,9 +25,9 @@ def test_init_datamodule(data_module_cropped):
     assert data_module_cropped.dims == (3, 256, 256)
     assert data_module_cropped.num_classes == 4
     assert data_module_cropped.class_encodings == [1, 2, 4, 8]
-    assert torch.all(
-        torch.eq(data_module_cropped.class_weights,
-                 torch.tensor([0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002])))
+    assert torch.equal(data_module_cropped.class_weights,
+                       torch.tensor(
+                           [0.004952207651647859, 0.07424270397485577, 0.8964025044572563, 0.02440258391624002]))
     assert data_module_cropped.mean == [0.7050454974582426, 0.6503181590413943, 0.5567698583877997]
     assert data_module_cropped.std == [0.3104060859619883, 0.3053311838884032, 0.28919611393432726]
     with pytest.raises(AttributeError):
diff --git a/tests/datamodules/RotNet/test_datamodule_cropped.py b/tests/datamodules/RotNet/test_datamodule_cropped.py
index 64306e81..78bb38bc 100644
--- a/tests/datamodules/RotNet/test_datamodule_cropped.py
+++ b/tests/datamodules/RotNet/test_datamodule_cropped.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import torch
 from omegaconf import OmegaConf

 from src.datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped
@@ -22,7 +23,7 @@ def test_init_datamodule(data_module_cropped):
     assert data_module_cropped.dims == (3, 256, 256)
     assert data_module_cropped.num_classes == 4
     assert np.array_equal(data_module_cropped.class_encodings, [0, 90, 180, 270])
-    assert np.array_equal(data_module_cropped.class_weights, [.25, .25, .25, .25])
+    assert torch.equal(data_module_cropped.class_weights, torch.tensor([.25, .25, .25, .25]))
     assert data_module_cropped.mean == [0.7050454974582426, 0.6503181590413943, 0.5567698583877997]
     assert data_module_cropped.std == [0.3104060859619883, 0.3053311838884032, 0.28919611393432726]
     with pytest.raises(AttributeError):

From c10852ab97a00e7ada0a67ba93a12dc6327a297a Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Thu, 25 Nov 2021 11:33:27 +0100
Subject: [PATCH 102/108] :sparkles: added a freeze flag for backbone and header which freezes the corresponding part during training.
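
With `freeze: True` set on a model part, `_load_model_part` disables gradients for that part and switches it to eval mode, as the diff below shows. In effect each frozen part behaves like the following minimal sketch (the helper name `freeze_part` is illustrative only; the patch itself inlines this logic):

```
import torch.nn as nn

def freeze_part(part: nn.Module) -> nn.Module:
    # Exclude every parameter of this part from gradient computation ...
    for param in part.parameters():
        param.requires_grad = False
    # ... and keep the part in eval mode so that e.g. BatchNorm running
    # statistics and dropout stay fixed while the other part trains.
    part.eval()
    return part
```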
--- src/execute.py | 14 +++++++++++++- src/utils/utils.py | 5 +++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/execute.py b/src/execute.py index 5bd04fdb..cccfba19 100644 --- a/src/execute.py +++ b/src/execute.py @@ -39,7 +39,7 @@ def execute(config: DictConfig) -> Optional[float]: header: LightningModule = _load_model_part(config=config, part_name='header') # container model - model: LightningModule = BackboneHeaderModel(backbone=backbone, header=header) + model: BackboneHeaderModel = BackboneHeaderModel(backbone=backbone, header=header) # Init optimizer log.info(f"Instantiating optimizer <{config.optimizer._target_}>") @@ -165,12 +165,18 @@ def _load_model_part(config: DictConfig, part_name: str): LightningModule: The loaded network """ + freeze = False strict = True if 'strict' in config.model.get(part_name): log.info(f"The model part {part_name} will be loaded with strict={config.model.get(part_name).strict}") strict = config.model.get(part_name).strict del config.model.get(part_name).strict + if 'freeze' in config.model.get(part_name): + log.info(f"The model part {part_name} is frozen during all stages!") + freeze = True + del config.model.get(part_name).freeze + if "path_to_weights" in config.model.get(part_name): log.info(f"Loading {part_name} weights from <{config.model.get(part_name).path_to_weights}>") path_to_weights = config.model.get(part_name).path_to_weights @@ -190,6 +196,12 @@ def _load_model_part(config: DictConfig, part_name: str): "Use 'path_to_weights' in your model to load a trained model") part: LightningModule = hydra.utils.instantiate(config.model.get(part_name)) + if freeze: + for param in part.parameters(): + param.requires_grad = False + + part.eval() + return part diff --git a/src/utils/utils.py b/src/utils/utils.py index 9541d249..4d7a72b7 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -93,6 +93,11 @@ def check_config(config: DictConfig) -> None: seed = random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max) config['seed'] = seed log.info(f"No seed specified! Seed set to {seed}") + + if 'freeze' in config.model.backbone and 'freeze' in config.model.header and config.train: + if config.model.backbone.freeze and config.model.header.freeze: + log.error(f"Cannot train with no trainable parameters! Both header and backbone are frozen!") + # disable adding new keys to config OmegaConf.set_struct(config, True) From fe43aea81825a356bc466e8d5148f2f6fec06845 Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Thu, 25 Nov 2021 11:34:02 +0100 Subject: [PATCH 103/108] :art: improved the gradient logging by just logging the model and not the whole task --- src/callbacks/wandb_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/callbacks/wandb_callbacks.py b/src/callbacks/wandb_callbacks.py index df438e15..96d7fb9f 100644 --- a/src/callbacks/wandb_callbacks.py +++ b/src/callbacks/wandb_callbacks.py @@ -30,7 +30,7 @@ def __init__(self, log: str = "gradients", log_freq: int = 100): def on_train_start(self, trainer, pl_module): try: logger = get_wandb_logger(trainer=trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + logger.watch(model=pl_module.model, log=self.log, log_freq=self.log_freq) except ValueError as e: logger = utils.get_logger(__name__) logger.error('No wandb logger found. 
WatchModelWithWandb callback will not do anything.')

From 92bd5a7f5cd6feae2d8eb4cbc2c64b48dc161e5c Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Thu, 25 Nov 2021 11:35:20 +0100
Subject: [PATCH 104/108] :memo: added explanation to the docs on how to freeze a model part

---
 README.md | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e7bcd591..02f714ca 100644
--- a/README.md
+++ b/README.md
@@ -111,4 +111,19 @@ e.g.
 ```
 trainer:
     resume_from_checkpoint: /path/to/.ckpt/file
-```
\ No newline at end of file
+```
+
+### Freezing model parts
+You can freeze both parts of the model (backbone or header) with the `freeze` flag in the config.
+E.g. if you want to freeze the backbone:
+In the command line:
+```
+python run.py +model.backbone.freeze=True
+```
+In the config (e.g. model/backbone/baby_unet_model.yaml):
+```
+...
+freeze: True
+...
+```
+CARE: You cannot train a model that has no trainable parameters (e.g. when both backbone and header are frozen).

From f63d4d1e6b839318c7884257e84999a86c0120a1 Mon Sep 17 00:00:00 2001
From: Paul M
Date: Thu, 25 Nov 2021 12:09:15 +0100
Subject: [PATCH 105/108] :sparkles: :recycle: :art: :wrench: predict/test output can handle duplicates and print the file list.

---
 configs/datamodule/rolf_format_dev.yaml       |  8 +++
 .../rolf_format_dev_prediction.yaml           |  5 ++
 configs/experiment/dev_rgb_full_page.yaml     |  2 +-
 .../DivaHisDB/datamodule_cropped.py           |  2 +-
 .../DivaHisDB/datasets/cropped_dataset.py     |  6 +--
 src/datamodules/RGB/datamodule.py             |  8 +--
 src/datamodules/RGB/datamodule_cropped.py     |  4 +-
 .../RGB/datasets/cropped_dataset.py           |  6 +--
 .../RGB/datasets/full_page_dataset.py         | 24 +++++----
 src/datamodules/RGB/utils/image_analytics.py  | 10 ++--
 src/datamodules/RolfFormat/datamodule.py      | 14 ++---
 .../RolfFormat/datasets/dataset.py            | 20 ++++---
 .../RolfFormat/utils/image_analytics.py       |  8 +--
 src/datamodules/RotNet/datamodule_cropped.py  |  2 +-
 src/datamodules/utils/dataset_predict.py      |  5 +-
 src/datamodules/utils/misc.py                 | 47 ++++++++++++++--
 src/tasks/RGB/semantic_segmentation.py        | 54 +++++++++++++++----
 .../datamodules/RGB/test_full_page_dataset.py |  2 +-
 18 files changed, 161 insertions(+), 66 deletions(-)

diff --git a/configs/datamodule/rolf_format_dev.yaml b/configs/datamodule/rolf_format_dev.yaml
index 0c6c2c1e..52633f6e 100644
--- a/configs/datamodule/rolf_format_dev.yaml
+++ b/configs/datamodule/rolf_format_dev.yaml
@@ -54,6 +54,14 @@ test_specs:
     gt_names: "A1-MR-truthD-####.gif"
     range_from: 1080
     range_to: 1099
+  append3:
+    doc_dir: "../synthetic/SetA1_sizeM/layoutR/split/test/data"
+    doc_names: "A1-MR-page-####.jpg"
+    gt_dir: "../synthetic/SetA1_sizeM/layoutR/split/test/gtD"
+    gt_names: "A1-MR-page-####.gif"
+    range_from: 1080
+    range_to: 1099
+

 image_dims:
   width: 640
diff --git a/configs/datamodule/rolf_format_dev_prediction.yaml b/configs/datamodule/rolf_format_dev_prediction.yaml
index d5ca6151..af1e6dca 100644
--- a/configs/datamodule/rolf_format_dev_prediction.yaml
+++ b/configs/datamodule/rolf_format_dev_prediction.yaml
@@ -61,6 +61,11 @@ pred_file_path_list:
   - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-107*.jpg"
   - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1085.jpg"
+  - "/net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutR/split/*/data/A1-MR-page-106[0-2].jpg"
+  -
"/net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutR/split/*/data/A1-MR-page-106[7,9].jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutR/split/*/data/A1-MR-page-107*.jpg" + - "/net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutR/split/*/data/A1-MR-page-1085.jpg" + image_dims: width: 640 height: 896 diff --git a/configs/experiment/dev_rgb_full_page.yaml b/configs/experiment/dev_rgb_full_page.yaml index 28d6a2f7..c16ea60c 100644 --- a/configs/experiment/dev_rgb_full_page.yaml +++ b/configs/experiment/dev_rgb_full_page.yaml @@ -44,7 +44,7 @@ task: confusion_matrix_test: True datamodule: - _target_: src.datamodules.RGB.datamodule_full_page.DataModuleRGB + _target_: src.datamodules.RGB.datamodule.DataModuleRGB data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutD/split num_workers: 4 diff --git a/src/datamodules/DivaHisDB/datamodule_cropped.py b/src/datamodules/DivaHisDB/datamodule_cropped.py index e8e95da4..86513fa9 100644 --- a/src/datamodules/DivaHisDB/datamodule_cropped.py +++ b/src/datamodules/DivaHisDB/datamodule_cropped.py @@ -78,7 +78,7 @@ def setup(self, stage: Optional[str] = None): self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) - if stage == 'test' or stage is not None: + if stage == 'test': self.test = CroppedHisDBDataset(**self._create_dataset_parameters('test'), selection=self.selection_test) log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', diff --git a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py index b5f3f5fb..ceb84a82 100644 --- a/src/datamodules/DivaHisDB/datasets/cropped_dataset.py +++ b/src/datamodules/DivaHisDB/datasets/cropped_dataset.py @@ -75,7 +75,7 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, gt_folder_name=self.gt_folder_name, selection=self.selection) # TODO: make more fanzy stuff here - # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + # self.img_paths = [pair for page in self.img_gt_path_list for pair in page] self.num_samples = len(self.img_paths_per_page) if self.num_samples == 0: @@ -233,12 +233,12 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st sorted(path_gt_subdir.iterdir())): assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned with non-image file' + 'get_img_gt_path_list(): image file aligned with non-image file' if has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) and \ has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS): assert path_data_file.stem == path_gt_file.stem, \ - 'get_gt_data_paths(): mismatch between data filename and gt filename' + 'get_img_gt_path_list(): mismatch between data filename and gt filename' coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') m = coordinates.match(path_data_file.name) if m is None: diff --git a/src/datamodules/RGB/datamodule.py b/src/datamodules/RGB/datamodule.py index 1595fc97..5c27ab86 100644 --- a/src/datamodules/RGB/datamodule.py +++ b/src/datamodules/RGB/datamodule.py @@ -31,7 +31,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, analytics_data, analytics_gt = 
get_analytics(input_path=Path(data_dir), data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name, - get_gt_data_paths_func=DatasetRGB.get_gt_data_paths) + get_img_gt_path_list_func=DatasetRGB.get_img_gt_path_list) self.image_dims = ImageDimensions(width=analytics_data['width'], height=analytics_data['height']) self.dims = (3, self.image_dims.width, self.image_dims.height) @@ -77,7 +77,7 @@ def setup(self, stage: Optional[str] = None): self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) - if stage == 'test' or stage is not None: + if stage == 'test': self.test = DatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', @@ -137,7 +137,7 @@ def _create_dataset_parameters(self, dataset_type: str = 'train'): 'classes': self.class_encodings, 'is_test': is_test} - def get_img_name(self, index): + def get_output_filename_test(self, index): """ Returns the original filename of the doc image. You can just use this during testing! @@ -147,4 +147,4 @@ def get_img_name(self, index): if not hasattr(self, 'test'): raise Exception('This method can just be called during testing') - return self.test.img_paths_per_page[index][2:] + return self.test.output_file_list[index] diff --git a/src/datamodules/RGB/datamodule_cropped.py b/src/datamodules/RGB/datamodule_cropped.py index 25ab3ecd..092f5b91 100644 --- a/src/datamodules/RGB/datamodule_cropped.py +++ b/src/datamodules/RGB/datamodule_cropped.py @@ -32,7 +32,7 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name, - get_gt_data_paths_func=CroppedDatasetRGB.get_gt_data_paths) + get_img_gt_path_list_func=CroppedDatasetRGB.get_gt_data_paths) self.mean = analytics_data['mean'] self.std = analytics_data['std'] @@ -77,7 +77,7 @@ def setup(self, stage: Optional[str] = None): self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) - if stage == 'test' or stage is not None: + if stage == 'test': self.test = CroppedDatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', diff --git a/src/datamodules/RGB/datasets/cropped_dataset.py b/src/datamodules/RGB/datasets/cropped_dataset.py index 525e1851..a5b9122f 100644 --- a/src/datamodules/RGB/datasets/cropped_dataset.py +++ b/src/datamodules/RGB/datasets/cropped_dataset.py @@ -75,7 +75,7 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, gt_folder_name=self.gt_folder_name, selection=self.selection) # TODO: make more fanzy stuff here - # self.img_paths = [pair for page in self.img_paths_per_page for pair in page] + # self.img_paths = [pair for page in self.img_gt_path_list for pair in page] self.num_samples = len(self.img_paths_per_page) if self.num_samples == 0: @@ -233,12 +233,12 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st sorted(path_gt_subdir.iterdir())): assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned 
with non-image file' + 'get_img_gt_path_list(): image file aligned with non-image file' if has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) and \ has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS): assert path_data_file.stem == path_gt_file.stem, \ - 'get_gt_data_paths(): mismatch between data filename and gt filename' + 'get_img_gt_path_list(): mismatch between data filename and gt filename' coordinates = re.compile(r'.+_x(\d+)_y(\d+)\.') m = coordinates.match(path_data_file.name) if m is None: diff --git a/src/datamodules/RGB/datasets/full_page_dataset.py b/src/datamodules/RGB/datasets/full_page_dataset.py index 40210a31..f00195c0 100644 --- a/src/datamodules/RGB/datasets/full_page_dataset.py +++ b/src/datamodules/RGB/datasets/full_page_dataset.py @@ -13,7 +13,7 @@ from torchvision.datasets.folder import pil_loader, has_file_allowed_extension from torchvision.transforms import ToTensor -from src.datamodules.utils.misc import ImageDimensions +from src.datamodules.utils.misc import ImageDimensions, get_output_file_list from src.utils import utils IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') @@ -75,10 +75,14 @@ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, self.is_test = is_test # List of tuples that contain the path to the gt and image that belong together - self.img_paths_per_page = self.get_gt_data_paths(path, data_folder_name=self.data_folder_name, - gt_folder_name=self.gt_folder_name, selection=self.selection) + self.img_gt_path_list = self.get_img_gt_path_list(path, data_folder_name=self.data_folder_name, + gt_folder_name=self.gt_folder_name, selection=self.selection) - self.num_samples = len(self.img_paths_per_page) + if is_test: + self.image_path_list = [img_gt_path[0] for img_gt_path in self.img_gt_path_list] + self.output_file_list = get_output_file_list(image_path_list=self.image_path_list) + + self.num_samples = len(self.img_gt_path_list) if self.num_samples == 0: raise RuntimeError("Found 0 images in: {} \n Supported image extensions are: {}".format( path, ",".join(IMG_EXTENSIONS))) @@ -111,8 +115,8 @@ def _get_test_items(self, index): return img, gt, index def _load_data_and_gt(self, index): - data_img = pil_loader(self.img_paths_per_page[index][0]) - gt_img = pil_loader(self.img_paths_per_page[index][1]) + data_img = pil_loader(self.img_gt_path_list[index][0]) + gt_img = pil_loader(self.img_gt_path_list[index][1]) assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width assert gt_img.height == self.image_dims.height and gt_img.width == self.image_dims.width @@ -155,8 +159,8 @@ def _apply_transformation(self, img, gt): return img, gt @staticmethod - def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str, - selection: Optional[Union[int, List[str]]] = None) \ + def get_img_gt_path_list(directory: Path, data_folder_name: str, gt_folder_name: str, + selection: Optional[Union[int, List[str]]] = None) \ -> List[Tuple[Path, Path, str]]: """ Structure of the folder @@ -230,10 +234,10 @@ def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: st assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \ - 'get_gt_data_paths(): image file aligned with non-image file' + 'get_img_gt_path_list(): image file aligned with non-image file' assert path_data_file.stem == path_gt_file.stem, \ - 'get_gt_data_paths(): mismatch between data filename 
and gt filename' + 'get_img_gt_path_list(): mismatch between data filename and gt filename' # TODO check if we need x/y paths.append((path_data_file, path_gt_file, path_data_file.stem)) diff --git a/src/datamodules/RGB/utils/image_analytics.py b/src/datamodules/RGB/utils/image_analytics.py index 4fe4c515..bfb2a853 100644 --- a/src/datamodules/RGB/utils/image_analytics.py +++ b/src/datamodules/RGB/utils/image_analytics.py @@ -16,7 +16,7 @@ from src.datamodules.utils.image_analytics import compute_mean_std -def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func, **kwargs): +def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_img_gt_path_list_func, **kwargs): """ Parameters ---------- @@ -53,10 +53,10 @@ def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, if missing_analytics_data or missing_analytics_gt: train_path = input_path / 'train' - gt_data_path_list = get_gt_data_paths_func(train_path, data_folder_name=data_folder_name, - gt_folder_name=gt_folder_name) - file_names_data = np.asarray([str(item[0]) for item in gt_data_path_list]) - file_names_gt = np.asarray([str(item[1]) for item in gt_data_path_list]) + img_gt_path_list = get_img_gt_path_list_func(train_path, data_folder_name=data_folder_name, + gt_folder_name=gt_folder_name) + file_names_data = np.asarray([str(item[0]) for item in img_gt_path_list]) + file_names_gt = np.asarray([str(item[1]) for item in img_gt_path_list]) if missing_analytics_data: mean, std = compute_mean_std(file_names=file_names_data, **kwargs) diff --git a/src/datamodules/RolfFormat/datamodule.py b/src/datamodules/RolfFormat/datamodule.py index 97878a5c..9816bd65 100644 --- a/src/datamodules/RolfFormat/datamodule.py +++ b/src/datamodules/RolfFormat/datamodule.py @@ -34,14 +34,14 @@ def __init__(self, data_root: str, self.pred_file_path_list = pred_file_path_list if image_analytics is None or classes is None or image_dims is None: - train_paths_data_gt = DatasetRolfFormat.get_gt_data_paths(list_specs=self.train_dataset_specs) + train_paths_data_gt = DatasetRolfFormat.get_img_gt_path_list(list_specs=self.train_dataset_specs) if image_dims is None: image_dims = get_image_dims(data_gt_path_list=train_paths_data_gt) self._print_image_dims(image_dims=image_dims) if image_analytics is None: - analytics_data = get_analytics_data(data_gt_path_list=train_paths_data_gt) + analytics_data = get_analytics_data(img_gt_path_list=train_paths_data_gt) self._print_analytics_data(analytics_data=analytics_data) else: analytics_data = {'mean': [image_analytics['mean']['R'], @@ -52,7 +52,7 @@ def __init__(self, data_root: str, image_analytics['std']['B']]} if classes is None: - analytics_gt = get_analytics_gt(data_gt_path_list=train_paths_data_gt) + analytics_gt = get_analytics_gt(img_gt_path_list=train_paths_data_gt) self._print_analytics_gt(analytics_gt=analytics_gt) else: analytics_gt = {'class_encodings': [], @@ -223,7 +223,7 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: drop_last=False, pin_memory=True) - def get_img_name(self, index): + def get_output_filename_test(self, index): """ Returns the original filename of the doc image. You can just use this during testing! 
@@ -233,9 +233,9 @@ def get_img_name(self, index): if not hasattr(self, 'test'): raise Exception('This method can just be called during testing') - return self.test.img_paths_per_page[index][2:] + return self.test.output_file_list[index] - def get_img_name_prediction(self, index): + def get_output_filename_predict(self, index): """ Returns the original filename of the doc image. You can just use this during testing! @@ -245,4 +245,4 @@ def get_img_name_prediction(self, index): if not hasattr(self, 'predict'): raise Exception('This method can just be called during prediction') - return self.predict.image_path_list[index].stem + return self.predict.output_file_list[index] diff --git a/src/datamodules/RolfFormat/datasets/dataset.py b/src/datamodules/RolfFormat/datasets/dataset.py index 14968d57..18c9456f 100644 --- a/src/datamodules/RolfFormat/datasets/dataset.py +++ b/src/datamodules/RolfFormat/datasets/dataset.py @@ -13,7 +13,7 @@ from torchvision.datasets.folder import pil_loader from torchvision.transforms import ToTensor -from src.datamodules.utils.misc import ImageDimensions +from src.datamodules.utils.misc import ImageDimensions, get_output_file_list from src.utils import utils IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif') @@ -80,9 +80,13 @@ def __init__(self, dataset_specs: List[DatasetSpecs], image_dims: ImageDimension self.is_test = is_test # List of tuples that contain the path to the gt and image that belong together - self.img_paths_per_page = self.get_gt_data_paths(list_specs=self.dataset_specs) + self.img_gt_path_list = self.get_img_gt_path_list(list_specs=self.dataset_specs) - self.num_samples = len(self.img_paths_per_page) + if is_test: + self.image_path_list = [img_gt_path[0] for img_gt_path in self.img_gt_path_list] + self.output_file_list = get_output_file_list(image_path_list=self.image_path_list) + + self.num_samples = len(self.img_gt_path_list) assert self.num_samples > 0 @@ -111,8 +115,8 @@ def _get_test_items(self, index): return img, gt, index def _load_data_and_gt(self, index): - data_img = pil_loader(self.img_paths_per_page[index][0]) - gt_img = pil_loader(self.img_paths_per_page[index][1]) + data_img = pil_loader(str(self.img_gt_path_list[index][0])) + gt_img = pil_loader(str(self.img_gt_path_list[index][1])) assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width assert gt_img.height == self.image_dims.height and gt_img.width == self.image_dims.width @@ -158,7 +162,7 @@ def _apply_transformation(self, img, gt): def _get_paths_from_specs(data_root: str, doc_dir: str, doc_names: str, gt_dir: str, gt_names: str, - range_from: int, range_to: int): + range_from: int, range_to: int) -> List[Tuple[Path, Path]]: path_root = Path(data_root) path_doc_dir = path_root / doc_dir @@ -197,14 +201,14 @@ def _get_paths_from_specs(data_root: str, assert path_doc_file.exists() == path_gt_file.exists() if path_doc_file.exists() and path_gt_file.exists(): - paths.append((path_doc_file, path_gt_file, path_doc_file.stem)) + paths.append((path_doc_file, path_gt_file)) assert len(paths) > 0 return paths @staticmethod - def get_gt_data_paths(list_specs: List[DatasetSpecs]) -> List[Tuple[Path, Path, str]]: + def get_img_gt_path_list(list_specs: List[DatasetSpecs]) -> List[Tuple[Path, Path]]: paths = [] for specs in list_specs: diff --git a/src/datamodules/RolfFormat/utils/image_analytics.py b/src/datamodules/RolfFormat/utils/image_analytics.py index bdcc838d..1e59acb6 100644 --- 
a/src/datamodules/RolfFormat/utils/image_analytics.py +++ b/src/datamodules/RolfFormat/utils/image_analytics.py @@ -10,8 +10,8 @@ from src.datamodules.utils.misc import ImageDimensions -def get_analytics_data(data_gt_path_list, **kwargs): - file_names_data = np.asarray([str(item[0]) for item in data_gt_path_list]) +def get_analytics_data(img_gt_path_list, **kwargs): + file_names_data = np.asarray([str(item[0]) for item in img_gt_path_list]) mean, std = compute_mean_std(file_names=file_names_data, **kwargs) analytics_data = {'mean': mean.tolist(), @@ -20,8 +20,8 @@ def get_analytics_data(data_gt_path_list, **kwargs): return analytics_data -def get_analytics_gt(data_gt_path_list, **kwargs): - file_names_gt = np.asarray([str(item[1]) for item in data_gt_path_list]) +def get_analytics_gt(img_gt_path_list, **kwargs): + file_names_gt = np.asarray([str(item[1]) for item in img_gt_path_list]) # Measure weights for class balancing logging.info(f'Measuring class weights') diff --git a/src/datamodules/RotNet/datamodule_cropped.py b/src/datamodules/RotNet/datamodule_cropped.py index b7bf0c26..bee2a69a 100644 --- a/src/datamodules/RotNet/datamodule_cropped.py +++ b/src/datamodules/RotNet/datamodule_cropped.py @@ -69,7 +69,7 @@ def setup(self, stage: Optional[str] = None): self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) - if stage == 'test' or stage is not None: + if stage == 'test': self.test = CroppedRotNet(**self._create_dataset_parameters('test'), selection=self.selection_test) log.info(f'Initialized test dataset with {len(self.test)} samples.') # self._check_min_num_samples(num_samples=len(self.test), data_split='test', diff --git a/src/datamodules/utils/dataset_predict.py b/src/datamodules/utils/dataset_predict.py index 017543ad..1015cd91 100644 --- a/src/datamodules/utils/dataset_predict.py +++ b/src/datamodules/utils/dataset_predict.py @@ -7,7 +7,7 @@ from torchvision.datasets.folder import pil_loader from torchvision.transforms import ToTensor -from src.datamodules.utils.misc import ImageDimensions +from src.datamodules.utils.misc import ImageDimensions, get_output_file_list from src.utils import utils log = utils.get_logger(__name__) @@ -29,6 +29,7 @@ def __init__(self, image_path_list: List[str], image_dims: ImageDimensions, self._raw_image_path_list = list(image_path_list) self.image_path_list = self.expend_glob_path_list(glob_path_list=self._raw_image_path_list) + self.output_file_list = get_output_file_list(image_path_list=self.image_path_list) self.image_dims = image_dims @@ -97,3 +98,5 @@ def expend_glob_path_list(glob_path_list: List[str]) -> List[Path]: if path not in output_list: output_list.append(Path(s)) return output_list + + diff --git a/src/datamodules/utils/misc.py b/src/datamodules/utils/misc.py index cda99bd1..c5234747 100644 --- a/src/datamodules/utils/misc.py +++ b/src/datamodules/utils/misc.py @@ -1,11 +1,20 @@ from dataclasses import dataclass from pathlib import Path -from typing import Union +from typing import Union, List import numpy as np import torch from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir, PathMissingDirinSplitDir +from src.utils import utils + +log = utils.get_logger(__name__) + + +@dataclass +class ImageDimensions: + width: int + height: int def _get_argmax(output: Union[torch.Tensor, np.ndarray], dim=1): @@ -70,7 +79,35 @@ def validate_path_for_segmentation(data_dir, data_folder_name: str, gt_folder_na return Path(data_dir) -@dataclass -class ImageDimensions: - 
width: int - height: int \ No newline at end of file +def get_output_file_list(image_path_list: List[Path]) -> List[str]: + duplicate_filenames = [] + output_list = [] + for p in image_path_list: + filename = p.stem + if filename not in output_list: + output_list.append(filename) + else: + duplicate_filenames.append(filename) + new_filename = find_new_filename(filename=filename, current_list=output_list) + assert new_filename is not None and len(new_filename) > 0 + assert new_filename not in output_list + output_list.append(new_filename) + + assert len(image_path_list) == len(output_list) + + if len(duplicate_filenames) > 0: + log.warn(f'Duplicate filenames in output list. ' + f'Output filenames have been changed to be unique. Duplicates:\n' + f'{duplicate_filenames}') + + return output_list + + +def find_new_filename(filename: str, current_list: List[str]) -> str: + for i in range(len(current_list)): + new_filename = f'{filename}_{i}' + if new_filename not in current_list: + return new_filename + else: + log.error('Unexpected error: Did not find new filename that is not a duplicate!') + raise AssertionError \ No newline at end of file diff --git a/src/tasks/RGB/semantic_segmentation.py b/src/tasks/RGB/semantic_segmentation.py index a7613b5d..9e82678c 100644 --- a/src/tasks/RGB/semantic_segmentation.py +++ b/src/tasks/RGB/semantic_segmentation.py @@ -1,10 +1,11 @@ from pathlib import Path -from typing import Optional, Callable, Union, Any +from typing import Optional, Callable, Union, Any, List import numpy as np import torch.nn as nn import torch.optim import torchmetrics +from pytorch_lightning.utilities import rank_zero_only from src.datamodules.RGB.utils.output_tools import save_output_page_image from src.datamodules.utils.misc import _get_argmax @@ -58,8 +59,8 @@ def __init__(self, def setup(self, stage: str) -> None: super().setup(stage) - if not hasattr(self.trainer.datamodule, 'get_img_name'): - raise NotImplementedError('DataModule needs to implement get_img_name function') + if not hasattr(self.trainer.datamodule, 'get_output_filename_test'): + raise NotImplementedError('DataModule needs to implement get_output_filename_test function') log.info("Setup done!") @@ -91,17 +92,28 @@ def validation_step(self, batch, batch_idx, **kwargs): ########################################### TEST ############################################ ############################################################################################# + @rank_zero_only + def on_test_start(self) -> None: + # print output file list + dataset = self.trainer.datamodule.test + output_path = self.test_output_path + info_filename = 'info_file_mapping.txt' + + self.write_file_mapping(output_file_list=dataset.output_file_list, + image_path_list=dataset.image_path_list, + output_path=output_path, + info_filename=info_filename) + def test_step(self, batch, batch_idx, **kwargs): input_batch, target_batch, input_idx = batch output = super().test_step(batch=(input_batch, target_batch), batch_idx=batch_idx) - if not hasattr(self.trainer.datamodule, 'get_img_name'): - raise NotImplementedError('Datamodule does not provide detailed information of the crop') + if not hasattr(self.trainer.datamodule, 'get_output_filename_test'): + raise NotImplementedError('Datamodule does not provide output info for test') for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), input_idx.detach().cpu().numpy()): - patch_info = self.trainer.datamodule.get_img_name(idx) - img_name = patch_info[0] + img_name = 
self.trainer.datamodule.get_output_filename_test(idx) dest_folder = self.test_output_path / 'pred_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' @@ -121,16 +133,28 @@ def on_test_end(self) -> None: ######################################### PREDICT ########################################### ############################################################################################# + @rank_zero_only + def on_predict_start(self) -> None: + # print output file list + dataset = self.trainer.datamodule.predict + output_path = self.predict_output_path + info_filename = 'info_file_mapping.txt' + + self.write_file_mapping(output_file_list=dataset.output_file_list, + image_path_list=dataset.image_path_list, + output_path=output_path, + info_filename=info_filename) + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: input_batch, input_idx = batch output = super().predict_step(batch=input_batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) - if not hasattr(self.trainer.datamodule, 'get_img_name'): - raise NotImplementedError('Datamodule does not provide detailed information of the crop') + if not hasattr(self.trainer.datamodule, 'get_output_filename_predict'): + raise NotImplementedError('Datamodule does not provide output info for predict') for pred_raw, idx in zip(output[OutputKeys.PREDICTION].detach().cpu().numpy(), input_idx.detach().cpu().numpy()): - img_name = self.trainer.datamodule.get_img_name_prediction(idx) + img_name = self.trainer.datamodule.get_output_filename_predict(idx) dest_folder = self.predict_output_path / 'pred_raw' dest_folder.mkdir(parents=True, exist_ok=True) dest_filename = dest_folder / f'{img_name}.npy' @@ -142,3 +166,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] output_folder=dest_folder, class_encoding=self.trainer.datamodule.class_encodings) return reduce_dict(input_dict=output, key_list=[]) + + @staticmethod + def write_file_mapping(output_file_list: List[str], image_path_list: List[Path], + output_path: Path, info_filename: str): + assert len(output_file_list) == len(image_path_list) + output_path.mkdir(parents=True, exist_ok=True) + output_info_file = output_path / info_filename + with output_info_file.open('w') as f: + for output_filename, image_path in zip(output_file_list, image_path_list): + f.write(f'{output_filename}\t{image_path}\n') \ No newline at end of file diff --git a/tests/datamodules/RGB/test_full_page_dataset.py b/tests/datamodules/RGB/test_full_page_dataset.py index bfc97363..5c68c133 100644 --- a/tests/datamodules/RGB/test_full_page_dataset.py +++ b/tests/datamodules/RGB/test_full_page_dataset.py @@ -11,7 +11,7 @@ def dataset_train(data_dir): def test_get_gt_data_paths(data_dir): - file_list = DatasetRGB.get_gt_data_paths(directory=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') + file_list = DatasetRGB.get_img_gt_path_list(directory=data_dir / 'train', data_folder_name='data', gt_folder_name='gt') assert len(file_list) == 1 assert file_list[0] == (data_dir / 'train' / 'data' / 'e-codices_fmb-cb-0055_0098v_max.jpg', data_dir / 'train' / 'gt' / 'e-codices_fmb-cb-0055_0098v_max.png', From 24f25530e4c59165699f1245f6f1ae0d251bab61 Mon Sep 17 00:00:00 2001 From: Paul M Date: Thu, 25 Nov 2021 12:41:38 +0100 Subject: [PATCH 106/108] :sparkles: :art: :wrench: predict implemented for RGB datamodule (full page) --- .../experiment/dev_rgb_full_page_predict.yaml | 80 +++++++++++++++++++ 
 src/datamodules/RGB/datamodule.py             | 71 ++++++++++++----
 2 files changed, 135 insertions(+), 16 deletions(-)
 create mode 100644 configs/experiment/dev_rgb_full_page_predict.yaml

diff --git a/configs/experiment/dev_rgb_full_page_predict.yaml b/configs/experiment/dev_rgb_full_page_predict.yaml
new file mode 100644
index 00000000..64f35df5
--- /dev/null
+++ b/configs/experiment/dev_rgb_full_page_predict.yaml
@@ -0,0 +1,80 @@
+# @package _global_
+
+# to execute this experiment run:
+# python run.py +experiment=exp_example_full
+
+defaults:
+  - /plugins: default.yaml
+  - /task: semantic_segmentation_RGB.yaml
+  - /loss: crossentropyloss.yaml
+  - /metric: iou.yaml
+  - /model/backbone: baby_unet_model.yaml
+  - /model/header: identity.yaml
+  - /optimizer: adam.yaml
+  - /callbacks:
+      - check_compatibility.yaml
+      - model_checkpoint.yaml
+      - watch_model_wandb.yaml
+  - /logger:
+      - wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
+      - csv.yaml
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 42
+
+train: False
+test: False
+predict: True
+
+model:
+  backbone:
+    path_to_weights: /net/research-hisdoc/experiments_lars_paul/paul/2021-11-25/12-32-04/checkpoints/epoch=1/backbone.pth
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: -1
+  accelerator: 'ddp'
+  min_epochs: 1
+  max_epochs: 2
+  weights_summary: full
+  precision: 16
+
+task:
+  confusion_matrix_log_every_n_epoch: 1
+  confusion_matrix_val: True
+  confusion_matrix_test: True
+
+datamodule:
+  _target_: src.datamodules.RGB.datamodule.DataModuleRGB
+
+  data_dir: /net/research-hisdoc/datasets/semantic_segmentation/synthetic/SetA1_sizeM/layoutD/split
+  num_workers: 4
+  batch_size: 2
+  shuffle: True
+  drop_last: True
+  data_folder_name: data
+  gt_folder_name: gtD
+
+  pred_file_path_list:
+    - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[0-2].jpg"
+    - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-106[7,9].jpg"
+    - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-107*.jpg"
+    - "/net/research-hisdoc/datasets/semantic_segmentation/rolf_format/SetA1_sizeM_Rolf/layoutR/data/A1-MR-page-1085.jpg"
+
+callbacks:
+  model_checkpoint:
+    monitor: "val/iou"
+    mode: "max"
+    filename: ${checkpoint_folder_name}dev-baby-unet-rgb-data
+  watch_model:
+    log_freq: 1
+
+logger:
+  wandb:
+    name: 'dev-RGB-full-page'
+    tags: [ "best_model", "synthetic", "RGB", "Rolf", "full_page" ]
+    group: 'dev-runs'
+    notes: "Testing"

diff --git a/src/datamodules/RGB/datamodule.py b/src/datamodules/RGB/datamodule.py
index 5c27ab86..2835322d 100644
--- a/src/datamodules/RGB/datamodule.py
+++ b/src/datamodules/RGB/datamodule.py
@@ -9,6 +9,7 @@
 from src.datamodules.RGB.utils.image_analytics import get_analytics
 from src.datamodules.RGB.utils.single_transform import IntegerEncoding
 from src.datamodules.base_datamodule import AbstractDatamodule
+from src.datamodules.utils.dataset_predict import DatasetPredict
 from src.datamodules.utils.misc import validate_path_for_segmentation
 from src.datamodules.utils.wrapper_transforms import OnlyImage, OnlyTarget
 from src.utils import utils
@@ -18,6 +19,7 @@ class DataModuleRGB(AbstractDatamodule):

     def __init__(self, data_dir: str, data_folder_name: str,
gt_folder_name: str, + pred_file_path_list: List[str] = None, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, @@ -28,6 +30,9 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, self.data_folder_name = data_folder_name self.gt_folder_name = gt_folder_name + if pred_file_path_list is not None: + self.pred_file_path_list = pred_file_path_list + analytics_data, analytics_gt = get_analytics(input_path=Path(data_dir), data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name, @@ -66,22 +71,48 @@ def __init__(self, data_dir: str, data_folder_name: str, gt_folder_name: str, def setup(self, stage: Optional[str] = None): super().setup() + + common_kwargs = {'classes': self.class_encodings, + 'image_dims': self.image_dims, + 'image_transform': self.image_transform, + 'target_transform': self.target_transform, + 'twin_transform': self.twin_transform} + + dataset_kwargs = {'data_folder_name': self.data_folder_name, + 'gt_folder_name': self.gt_folder_name} + if stage == 'fit' or stage is None: - self.train = DatasetRGB(**self._create_dataset_parameters('train'), selection=self.selection_train) + self.train = DatasetRGB(path=self.data_dir / 'train', + selection=self.selection_train, + is_test=False, + **dataset_kwargs, + **common_kwargs) log.info(f'Initialized train dataset with {len(self.train)} samples.') self._check_min_num_samples(num_samples=len(self.train), data_split='train', drop_last=self.drop_last) - self.val = DatasetRGB(**self._create_dataset_parameters('val'), selection=self.selection_val) + self.val = DatasetRGB(path=self.data_dir / 'val', + selection=self.selection_val, + is_test=False, + **dataset_kwargs, + **common_kwargs) log.info(f'Initialized val dataset with {len(self.val)} samples.') self._check_min_num_samples(num_samples=len(self.val), data_split='val', drop_last=self.drop_last) if stage == 'test': - self.test = DatasetRGB(**self._create_dataset_parameters('test'), selection=self.selection_test) + self.test = DatasetRGB(path=self.data_dir / 'test', + selection=self.selection_test, + is_test=True, + **dataset_kwargs, + **common_kwargs) log.info(f'Initialized test dataset with {len(self.test)} samples.') - # self._check_min_num_samples(num_samples=len(self.test), data_split='test', - # drop_last=False) + + if stage == 'predict': + self.predict = DatasetPredict(image_path_list=self.pred_file_path_list, + is_test=False, + **common_kwargs) + log.info(f'Initialized predict dataset with {len(self.predict)} samples.') def _check_min_num_samples(self, num_samples: int, data_split: str, drop_last: bool): num_processes = self.trainer.num_processes @@ -125,17 +156,13 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] drop_last=False, pin_memory=True) - def _create_dataset_parameters(self, dataset_type: str = 'train'): - is_test = dataset_type == 'test' - return {'path': self.data_dir / dataset_type, - 'data_folder_name': self.data_folder_name, - 'gt_folder_name': self.gt_folder_name, - 'image_dims': self.image_dims, - 'image_transform': self.image_transform, - 'target_transform': self.target_transform, - 'twin_transform': self.twin_transform, - 'classes': self.class_encodings, - 'is_test': is_test} + def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + return DataLoader(self.predict, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + drop_last=False, 
+ pin_memory=True) def get_output_filename_test(self, index): """ @@ -148,3 +175,15 @@ def get_output_filename_test(self, index): raise Exception('This method can just be called during testing') return self.test.output_file_list[index] + + def get_output_filename_predict(self, index): + """ + Returns the original filename of the doc image. + You can just use this during testing! + :param index: + :return: + """ + if not hasattr(self, 'predict'): + raise Exception('This method can just be called during prediction') + + return self.predict.output_file_list[index] From 1e1e6689c5bd9bbc4449d3dabcdc9b03814fae9e Mon Sep 17 00:00:00 2001 From: Lars Voegtlin Date: Thu, 25 Nov 2021 12:46:23 +0100 Subject: [PATCH 107/108] :recycle: :white_check_mark: refactored tests and added test for classification task --- .../DivaHisDB/utils}/__init__.py | 0 .../{ => utils}/test_image_analytics.py | 0 .../DivaHisDB/utils}/test_output_tools.py | 0 .../{DivaHisDB => util}/test_misc.py | 0 tests/metrics/__init__.py | 0 .../sem_seg => metrics}/test_accuracy.py | 0 tests/tasks/DivaHisDB/__init__.py | 0 .../test_semantic_segmentation.py | 44 +++++++++----- tests/tasks/classification/__init__.py | 0 .../classification/test_classification.py | 60 +++++++++++++++++++ 10 files changed, 90 insertions(+), 14 deletions(-) rename tests/{tasks/sem_seg => datamodules/DivaHisDB/utils}/__init__.py (100%) rename tests/datamodules/DivaHisDB/{ => utils}/test_image_analytics.py (100%) rename tests/{tasks/sem_seg => datamodules/DivaHisDB/utils}/test_output_tools.py (100%) rename tests/datamodules/{DivaHisDB => util}/test_misc.py (100%) create mode 100644 tests/metrics/__init__.py rename tests/{tasks/sem_seg => metrics}/test_accuracy.py (100%) create mode 100644 tests/tasks/DivaHisDB/__init__.py rename tests/tasks/{sem_seg => DivaHisDB}/test_semantic_segmentation.py (54%) create mode 100644 tests/tasks/classification/__init__.py create mode 100644 tests/tasks/classification/test_classification.py diff --git a/tests/tasks/sem_seg/__init__.py b/tests/datamodules/DivaHisDB/utils/__init__.py similarity index 100% rename from tests/tasks/sem_seg/__init__.py rename to tests/datamodules/DivaHisDB/utils/__init__.py diff --git a/tests/datamodules/DivaHisDB/test_image_analytics.py b/tests/datamodules/DivaHisDB/utils/test_image_analytics.py similarity index 100% rename from tests/datamodules/DivaHisDB/test_image_analytics.py rename to tests/datamodules/DivaHisDB/utils/test_image_analytics.py diff --git a/tests/tasks/sem_seg/test_output_tools.py b/tests/datamodules/DivaHisDB/utils/test_output_tools.py similarity index 100% rename from tests/tasks/sem_seg/test_output_tools.py rename to tests/datamodules/DivaHisDB/utils/test_output_tools.py diff --git a/tests/datamodules/DivaHisDB/test_misc.py b/tests/datamodules/util/test_misc.py similarity index 100% rename from tests/datamodules/DivaHisDB/test_misc.py rename to tests/datamodules/util/test_misc.py diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tasks/sem_seg/test_accuracy.py b/tests/metrics/test_accuracy.py similarity index 100% rename from tests/tasks/sem_seg/test_accuracy.py rename to tests/metrics/test_accuracy.py diff --git a/tests/tasks/DivaHisDB/__init__.py b/tests/tasks/DivaHisDB/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/DivaHisDB/test_semantic_segmentation.py similarity index 54% rename from 
From 1e1e6689c5bd9bbc4449d3dabcdc9b03814fae9e Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Thu, 25 Nov 2021 12:46:23 +0100
Subject: [PATCH 107/108] :recycle: :white_check_mark: refactored tests and
 added test for classification task

---
 .../DivaHisDB/utils}/__init__.py              |  0
 .../{ => utils}/test_image_analytics.py       |  0
 .../DivaHisDB/utils}/test_output_tools.py     |  0
 .../{DivaHisDB => util}/test_misc.py          |  0
 tests/metrics/__init__.py                     |  0
 .../sem_seg => metrics}/test_accuracy.py      |  0
 tests/tasks/DivaHisDB/__init__.py             |  0
 .../test_semantic_segmentation.py             | 44 +++++++++-----
 tests/tasks/classification/__init__.py        |  0
 .../classification/test_classification.py     | 60 +++++++++++++++++++
 10 files changed, 90 insertions(+), 14 deletions(-)
 rename tests/{tasks/sem_seg => datamodules/DivaHisDB/utils}/__init__.py (100%)
 rename tests/datamodules/DivaHisDB/{ => utils}/test_image_analytics.py (100%)
 rename tests/{tasks/sem_seg => datamodules/DivaHisDB/utils}/test_output_tools.py (100%)
 rename tests/datamodules/{DivaHisDB => util}/test_misc.py (100%)
 create mode 100644 tests/metrics/__init__.py
 rename tests/{tasks/sem_seg => metrics}/test_accuracy.py (100%)
 create mode 100644 tests/tasks/DivaHisDB/__init__.py
 rename tests/tasks/{sem_seg => DivaHisDB}/test_semantic_segmentation.py (54%)
 create mode 100644 tests/tasks/classification/__init__.py
 create mode 100644 tests/tasks/classification/test_classification.py

diff --git a/tests/tasks/sem_seg/__init__.py b/tests/datamodules/DivaHisDB/utils/__init__.py
similarity index 100%
rename from tests/tasks/sem_seg/__init__.py
rename to tests/datamodules/DivaHisDB/utils/__init__.py
diff --git a/tests/datamodules/DivaHisDB/test_image_analytics.py b/tests/datamodules/DivaHisDB/utils/test_image_analytics.py
similarity index 100%
rename from tests/datamodules/DivaHisDB/test_image_analytics.py
rename to tests/datamodules/DivaHisDB/utils/test_image_analytics.py
diff --git a/tests/tasks/sem_seg/test_output_tools.py b/tests/datamodules/DivaHisDB/utils/test_output_tools.py
similarity index 100%
rename from tests/tasks/sem_seg/test_output_tools.py
rename to tests/datamodules/DivaHisDB/utils/test_output_tools.py
diff --git a/tests/datamodules/DivaHisDB/test_misc.py b/tests/datamodules/util/test_misc.py
similarity index 100%
rename from tests/datamodules/DivaHisDB/test_misc.py
rename to tests/datamodules/util/test_misc.py
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/tasks/sem_seg/test_accuracy.py b/tests/metrics/test_accuracy.py
similarity index 100%
rename from tests/tasks/sem_seg/test_accuracy.py
rename to tests/metrics/test_accuracy.py
diff --git a/tests/tasks/DivaHisDB/__init__.py b/tests/tasks/DivaHisDB/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/tasks/sem_seg/test_semantic_segmentation.py b/tests/tasks/DivaHisDB/test_semantic_segmentation.py
similarity index 54%
rename from tests/tasks/sem_seg/test_semantic_segmentation.py
rename to tests/tasks/DivaHisDB/test_semantic_segmentation.py
index c33d86b3..808639a2 100644
--- a/tests/tasks/sem_seg/test_semantic_segmentation.py
+++ b/tests/tasks/DivaHisDB/test_semantic_segmentation.py
@@ -1,6 +1,7 @@
 import os

 import numpy as np
+import pytest
 import pytorch_lightning as pl
 import torch.optim.optimizer
 from omegaconf import OmegaConf
@@ -12,35 +13,50 @@
 from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped


-def test_semantic_segmentation(data_dir_cropped, tmp_path):
+@pytest.fixture(autouse=True)
+def clear_resolvers():
     OmegaConf.clear_resolvers()
     seed_everything(42)
+    os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+
+@pytest.fixture()
+def model():
+    return UNet(num_classes=4, num_layers=2, features_start=32)

+@pytest.fixture()
+def datamodule_and_dir(data_dir_cropped):
     # datamodule
     data_module = DivaHisDBDataModuleCropped(
         data_dir=str(data_dir_cropped),
         data_folder_name='data', gt_folder_name='gt',
         batch_size=2, num_workers=2)
+    return data_module, data_dir_cropped

-    def baby_unet():
-        return UNet(num_classes=len(data_module.class_encodings), num_layers=2, features_start=32)

-    model = baby_unet()
-    segmentation = SemanticSegmentationCroppedHisDB(model=model,
-                                                    optimizer=torch.optim.Adam(params=model.parameters()),
-                                                    loss_fn=torch.nn.CrossEntropyLoss(),
-                                                    test_output_path=tmp_path,
-                                                    confusion_matrix_val=True
-                                                    )
+@pytest.fixture()
+def task(model, tmp_path):
+    task = SemanticSegmentationCroppedHisDB(model=model,
+                                            optimizer=torch.optim.Adam(params=model.parameters()),
+                                            loss_fn=torch.nn.CrossEntropyLoss(),
+                                            test_output_path=tmp_path,
+                                            confusion_matrix_val=True
+                                            )
+    return task
+
+
+def test_semantic_segmentation(tmp_path, task, datamodule_and_dir):
+    data_module, data_dir_cropped = datamodule_and_dir

     # different paths needed later
-    patches_path = segmentation.test_output_path / 'patches'
+    patches_path = task.test_output_path / 'patches'
     test_data_patch = data_dir_cropped / 'test' / 'data'

-    os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
-    trainer = pl.Trainer(max_epochs=2, precision=32, default_root_dir=segmentation.test_output_path, accelerator='ddp_cpu')
+    trainer = pl.Trainer(max_epochs=2, precision=32, default_root_dir=task.test_output_path,
+                         accelerator='ddp_cpu')

-    trainer.fit(segmentation, datamodule=data_module)
+    trainer.fit(task, datamodule=data_module)

     results = trainer.test()
     print(results)
diff --git a/tests/tasks/classification/__init__.py b/tests/tasks/classification/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/tasks/classification/test_classification.py b/tests/tasks/classification/test_classification.py
new file mode 100644
index 00000000..6cc13d34
--- /dev/null
+++ b/tests/tasks/classification/test_classification.py
@@ -0,0 +1,60 @@
+import os
+
+import numpy as np
+import pytest
+import pytorch_lightning as pl
+import torch.optim.optimizer
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+
+from src.datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped
+from src.models.backbones.baby_cnn import CNN_basic
+from src.models.headers.fully_connected import SingleLinear
+from src.tasks.classification.classification import Classification
+from tests.test_data.dummy_data_hisdb.dummy_data import data_dir_cropped
+
+
+@pytest.fixture(autouse=True)
+def clear_resolvers():
+    OmegaConf.clear_resolvers()
+    seed_everything(42)
+    os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+
+@pytest.fixture()
+def model():
+    return torch.nn.Sequential(CNN_basic(), SingleLinear())
+
+
+@pytest.fixture()
+def datamodule_and_dir(data_dir_cropped):
+    # datamodule
+    data_module = RotNetDivaHisDBDataModuleCropped(
+        data_dir=str(data_dir_cropped),
+        data_folder_name='data',
+        batch_size=2, num_workers=2)
+    return data_module, data_dir_cropped
+
+
+@pytest.fixture()
+def task(model):
+    task = Classification(model=model,
+                          optimizer=torch.optim.Adam(params=model.parameters()),
+                          loss_fn=torch.nn.CrossEntropyLoss(),
+                          confusion_matrix_val=True
+                          )
+    return task
+
+
+def test_classification(tmp_path, task, datamodule_and_dir):
+    data_module, _ = datamodule_and_dir
+
+    trainer = pl.Trainer(max_epochs=2, precision=32, default_root_dir=tmp_path,
+                         accelerator='ddp_cpu')
+
+    trainer.fit(task, datamodule=data_module)
+
+    results = trainer.test()
+    print(results)
+    assert np.isclose(results[0]['test/crossentropyloss'], 1.5777363777160645, rtol=2e-03)
+    assert np.isclose(results[0]['test/crossentropyloss_epoch'], 1.5777363777160645, rtol=2e-03)
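
[The recurring shape of both refactored test modules is plain pytest fixture composition: an autouse fixture resets global state before every test, and the task fixture pulls in other fixtures by name. A self-contained sketch of the mechanics with dummy objects, nothing repository-specific:

    import pytest


    @pytest.fixture(autouse=True)
    def reset_state():
        # autouse: runs before every test in the module without being
        # requested explicitly, like clear_resolvers() above.
        pass


    @pytest.fixture()
    def model():
        return {'kind': 'dummy'}


    @pytest.fixture()
    def task(model, tmp_path):
        # Fixtures compose: pytest resolves 'model' and the built-in
        # 'tmp_path' before this fixture body runs.
        return {'model': model, 'out_dir': tmp_path}


    def test_task(task):
        assert task['model']['kind'] == 'dummy'
        assert task['out_dir'].exists()

Splitting setup into fixtures this way is what lets the new classification test reuse the same skeleton with a different model and datamodule.]
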
From 660dc4621c04669306b427bed56e68e7206ad66b Mon Sep 17 00:00:00 2001
From: Lars Voegtlin
Date: Thu, 25 Nov 2021 14:22:44 +0100
Subject: [PATCH 108/108] :construction_worker: exclude the test folder from
 the test coverage

---
 .github/workflows/ci-testing.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index 4e790563..354d81f7 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -63,13 +63,13 @@ jobs:
       - name: Statistics
         if: success()
         run: |
-          coverage report -m
+          coverage report -m --omit="*/test*"

       - name: Produce statistics html
         if: success()
         run: |
-          coverage html
-          coverage xml
+          coverage html --omit="*/test*"
+          coverage xml --omit="*/test*"

       - name: Code coverage results upload
         if: success()
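
[Repeating --omit on every coverage invocation works, but coverage.py can also read the exclusion from its config file, which keeps the three workflow steps in sync automatically. A sketch of the equivalent .coveragerc, assuming coverage.py's standard config format rather than a file from this repository:

    # .coveragerc
    [report]
    # Same pattern the workflow passes on the command line; applies to
    # `coverage report`, `coverage html`, and `coverage xml` alike.
    omit =
        */test*

With this in place, the three run lines above could drop their --omit flags entirely.]
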