Unequal augmentations within load_dataloader for 'disnet' training, in train_dis.py #5

Open
Cmonsta6 opened this issue Oct 13, 2024 · 2 comments

@Cmonsta6

I have been using your very helpful script to train GT_Encoder; it's my first time using Lightning and it's really great.

I was about to start training DISNet when I noticed something.

    if args.train_type == 'disnet':
        from utils.isnet_dataset import Dataset
        from utils.augmentation import RandomBlur
        tr_ds = Dataset(image_path=args.tr_im_path, gt_path=args.tr_gt_path,
                        image_transform=image_transform,
                        gt_transform=mask_transform,
                        random_blur=None,
                        load_on_mem=args.load_data_on_mem)

image_transform and gt_transform use different augmentation pipelines.

def load_dataloader(args):    
        
    mask_transform = A.Compose([
        # A.Resize(width=args.input_size, height=args.input_size),
        A.RandomCrop(width=1024, height=1024),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8)
    ])

    image_transform = A.Compose([
        A.CLAHE(p=0.8),
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8)]
    )

This random-crops and rotates/flips the masks but not the images, so the image/mask pairs no longer line up spatially.
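For context, the usual Albumentations pattern keeps the two in sync by running the image and mask through a single Compose call, so each spatial transform is sampled once and applied to both. A minimal sketch (file names are placeholders, not from the repo):

import albumentations as A
import cv2

# One Compose handles both targets: spatial transforms (crop, flip, rotate) are
# applied identically to image and mask, while pixel-level transforms (CLAHE,
# brightness/contrast, gamma) are applied to the image only.
paired_transform = A.Compose([
    A.RandomCrop(width=1024, height=1024),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.8),
    A.RandomRotate90(p=0.8),
    A.CLAHE(p=0.8),
    A.RandomBrightnessContrast(p=0.8),
    A.RandomGamma(p=0.8),
])

image = cv2.imread('some_image.jpg')                      # placeholder paths
mask = cv2.imread('some_mask.png', cv2.IMREAD_GRAYSCALE)
out = paired_transform(image=image, mask=mask)
image_aug, mask_aug = out['image'], out['mask']           # same crop/flip/rotation for both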

@Cmonsta6
Author

Also,

parser.add_argument('--input_size', type=int, default=1280)

The argument value is not used in the augmentations: images and masks are never resized during training, and the validation set is resized, but not to the --input_size value.


    vd_transform = A.Compose([
        A.Resize(width=1024, height=1024)
    ])

@Cmonsta6
Author

Cmonsta6 commented Oct 13, 2024

I don't understand how pull requests work, functionally, and I don't think one would be appropriate anyway given the number of modifications I have made beyond this.

The original code used bilinear for validation images and masks.

    vd_transform = A.Compose([
        A.Resize(width=1024, height=1024)
    ])

I switched to nearest-neighbor resampling for the masks in both the DISNet and GT_Encoder trainers to preserve the hard edges, and to area resampling for the RGB images. Area resampling does a slightly better job when the downsampling step is large; otherwise it looks the same as bilinear.
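A quick illustration of the difference on a synthetic binary mask (just a sketch; the shapes and sizes are arbitrary): bilinear resizing can introduce intermediate grey values along the mask edges, while nearest-neighbor keeps the mask strictly two-valued.

import cv2
import numpy as np

# Synthetic 0/255 binary mask with a filled square.
mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:300, 100:300] = 255

bilinear = cv2.resize(mask, (100, 100), interpolation=cv2.INTER_LINEAR)
nearest = cv2.resize(mask, (100, 100), interpolation=cv2.INTER_NEAREST)

print(np.unique(bilinear))  # intermediate grey values appear along the edge
print(np.unique(nearest))   # still exactly [0, 255]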

That noted, here is my solution to these issues:

import cv2  # needed for the interpolation flags
import albumentations as A

class CustomResizeTransform:
    def __init__(self, input_size, is_mask=False):
        # For images, use INTER_AREA; for masks, use INTER_NEAREST
        self.image_resize = A.Resize(width=input_size, height=input_size, interpolation=cv2.INTER_AREA)
        self.mask_resize = A.Resize(width=input_size, height=input_size, interpolation=cv2.INTER_NEAREST)
        self.is_mask = is_mask

    def __call__(self, image, mask=None):
        if self.is_mask:
            # When the 'image' is a mask (GT_Encoder), apply nearest-neighbor resampling
            image_resized = self.mask_resize(image=image)['image']
            return {'image': image_resized}
        else:
            # For DISNet, resize both image and mask using different interpolations
            image_resized = self.image_resize(image=image)['image']
            if mask is not None:
                mask_resized = self.mask_resize(image=mask)['image']  # Apply nearest-neighbor for mask
                return {'image': image_resized, 'mask': mask_resized}
            return {'image': image_resized}

def load_dataloader(args):    
    mask_transform = A.Compose([
        A.Resize(width=args.input_size, height=args.input_size, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8)
    ])

    duo_transform = A.Compose([
        A.Resize(width=args.input_size, height=args.input_size),  # Albumentations resizes the mask target with nearest-neighbor
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8),
        A.CLAHE(p=0.8),  # Pixel-level transforms only affect the image, not the mask
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8)
    ])  # The mask passed as mask= gets the same spatial transforms automatically

    # Custom resizing for validation with different interpolation methods for images and masks (DISNet)
    vd_transform = CustomResizeTransform(input_size=args.input_size)
    
    if args.train_type == 'disnet':
        # DISNet dataset (image + mask)
        from utils.isnet_dataset import Dataset
        from utils.augmentation import RandomBlur
        tr_ds = Dataset(image_path=args.tr_im_path, gt_path=args.tr_gt_path,
                        transform=duo_transform,
                        random_blur=None,
                        load_on_mem=args.load_data_on_mem)
        vd_ds = Dataset(image_path=args.vd_im_path, gt_path=args.vd_gt_path,
                        transform=vd_transform,  # Use unified custom transform for validation
                        load_on_mem=args.load_data_on_mem)
    else:
        # GT_Encoder dataset (same image and gt, where 'image' is a mask)
        from utils.gt_dataset import Dataset
        # GT_Encoder: Apply augmentations only to the training set (same for image and gt)
        tr_ds = Dataset(image_path=args.tr_gt_path, transform=mask_transform)
        
        # For validation, apply nearest-neighbor resizing (since 'images' are masks)
        vd_transform = CustomResizeTransform(input_size=args.input_size, is_mask=True)
        
        # Use the resized transform for validation in the GT_Encoder case
        vd_ds = Dataset(image_path=args.vd_gt_path, transform=vd_transform)
    
    tr_dl = DataLoader(tr_ds, args.batch_size, shuffle=True, num_workers=8)
    vd_dl = DataLoader(vd_ds, args.batch_size, shuffle=False, num_workers=4)
    
    return tr_dl, vd_dl
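As a quick sanity check (a sketch with dummy data instead of the DIS5K files, assuming args.input_size is 1024 and duo_transform is in scope as defined above), one image/mask pair should come back with matching spatial sizes:

import numpy as np

# Dummy stand-ins for a DIS5K image/mask pair (arbitrary original size).
dummy_image = np.random.randint(0, 256, (1536, 2048, 3), dtype=np.uint8)
dummy_mask = (np.random.rand(1536, 2048) > 0.5).astype(np.uint8) * 255

out = duo_transform(image=dummy_image, mask=dummy_mask)
assert out['image'].shape[:2] == out['mask'].shape[:2] == (1024, 1024)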

My edit requires changes to /utils/isnet_dataset.py too.

import os
from glob import glob

import cv2
import numpy as np
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms
from tqdm import tqdm

class Dataset(TorchDataset):
    def __init__(self, image_path='../../data/DIS5K/DIS-TR/im', gt_path='../../data/DIS5K/DIS-TR/gt',
                 transform=None, load_on_mem=False, random_blur=None):
        self.images = sorted(glob(os.path.join(image_path, '*.jpg')))
        self.gts = sorted(glob(os.path.join(gt_path, '*.png')))
        
        self.transform = transform  # Unified transformation for both image and mask
        self.random_blur = random_blur
        
        print(f'images : {len(self.images)}')
        print(f'gts : {len(self.gts)}')
        
        self.load_on_mem = load_on_mem
        if self.load_on_mem:
            self.load_data()

    def __len__(self):
        return len(self.gts)

    def load_data(self):
        self.im_lst = []
        self.gt_lst = []
        for im, gt in tqdm(zip(self.images, self.gts), total=self.__len__()):
            image, gt = cv2.imread(im), cv2.imread(gt, cv2.IMREAD_GRAYSCALE)
            self.im_lst.append(image)
            self.gt_lst.append(gt)

    def _transform(self, image, gt):
        # Apply spatial transformations (flip, rotate, resize) to both image and mask
        if self.transform:
            transformed = self.transform(image=image, mask=gt)  # Same transform for both
            image, gt = transformed['image'], transformed['mask']

        # Blur the image only if specified (optional and independent of spatial transformations)
        if self.random_blur:
            image = self.random_blur()(image=image)['image']  # Blur only the image, not the mask

        # Threshold and normalize the mask (binary values 0 and 1)
        gt = (gt > 128).astype(np.float32)

        # Normalize image pixel values to [0, 1]
        image = (image / 255.0).astype(np.float32)

        # Convert to tensor
        image = transforms.ToTensor()(image)
        gt = transforms.ToTensor()(gt)
        
        return image, gt

    def __getitem__(self, idx):
        if self.load_on_mem:
            image, gt = self.im_lst[idx], self.gt_lst[idx]
        else:
            image, gt = cv2.imread(self.images[idx]), cv2.imread(self.gts[idx], cv2.IMREAD_GRAYSCALE)
        
        # Apply spatial transformations (flip, rotate, resize) and optional blur
        image, gt = self._transform(image, gt)

        return {'image': image, 'gt': gt}
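A quick smoke test of the edited Dataset (a sketch; the paths are placeholders, and duo_transform, or any Albumentations Compose, is assumed to be in scope):

ds = Dataset(image_path='../../data/DIS5K/DIS-TR/im',
             gt_path='../../data/DIS5K/DIS-TR/gt',
             transform=duo_transform)
sample = ds[0]
print(sample['image'].shape, sample['image'].dtype)                               # e.g. torch.Size([3, 1024, 1024]) torch.float32
print(sample['gt'].shape, sample['gt'].min().item(), sample['gt'].max().item())   # binary 0/1 mask, shape [1, 1024, 1024]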



If that doesn't work, my humble apologies, for I am but a humble glue eater slapping code together.
Oh, and for me at least, --input_size has to be a power-of-2 size, otherwise there's a dimensional mismatch:

x = torch.cat([out, skip_x], dim=1)  # dim 1 is the channel dimension

RuntimeError: Sizes of tensors must match except in dimension 1.
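If I read that error right, it's the usual encoder/decoder constraint: each downsampling stage halves the spatial size (with rounding), so unless the input size is divisible by the total downsampling factor, the upsampled decoder feature no longer matches the encoder feature it is concatenated with. A small, hypothetical guard (the factor of 64 is my assumption; adjust it to the network's actual depth):

def assert_valid_input_size(size: int, downsample_factor: int = 64) -> None:
    # Hypothetical guard: reject sizes the encoder/decoder cannot round-trip cleanly.
    if size % downsample_factor != 0:
        raise ValueError(
            f"--input_size {size} is not divisible by {downsample_factor}; "
            f"decoder skip connections will end up with mismatched spatial sizes."
        )

# e.g. call assert_valid_input_size(args.input_size) right after parsing the arguments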
