diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py
index 4eefdbb..bc97154 100644
--- a/sleap_nn/data/augmentation.py
+++ b/sleap_nn/data/augmentation.py
@@ -1,7 +1,6 @@
 """This module implements data pipeline blocks for augmentation operations."""
 from typing import Any, Dict, Optional, Tuple, Union, Iterator
-
 import kornia as K
 import torch
 from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -11,6 +10,227 @@ from torch.utils.data.datapipes.datapipe import IterDataPipe
 
+
+def apply_intensity_augmentation(
+    image: torch.Tensor,
+    instances: torch.Tensor,
+    uniform_noise_min: Optional[float] = 0.0,
+    uniform_noise_max: Optional[float] = 0.04,
+    uniform_noise_p: float = 0.0,
+    gaussian_noise_mean: Optional[float] = 0.02,
+    gaussian_noise_std: Optional[float] = 0.004,
+    gaussian_noise_p: float = 0.0,
+    contrast_min: Optional[float] = 0.5,
+    contrast_max: Optional[float] = 2.0,
+    contrast_p: float = 0.0,
+    brightness: Optional[float] = 0.0,
+    brightness_p: float = 0.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Apply kornia intensity augmentation on image and instances.
+
+    Args:
+        image: Input image. Shape: (n_samples, C, H, W)
+        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
+        uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >= 0).
+        uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <= 1).
+        uniform_noise_p: Probability of applying random uniform noise.
+        gaussian_noise_mean: The mean of the gaussian distribution.
+        gaussian_noise_std: The standard deviation of the gaussian distribution.
+        gaussian_noise_p: Probability of applying random gaussian noise.
+        contrast_min: Minimum contrast factor to apply. Default: 0.5.
+        contrast_max: Maximum contrast factor to apply. Default: 2.0.
+        contrast_p: Probability of applying random contrast.
+        brightness: The brightness factor to apply. Default: 0.0.
+        brightness_p: Probability of applying random brightness.
+
+    Returns:
+        A tuple (aug_image, aug_instances) with the augmentations applied.
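+
+    Example:
+        A minimal sketch (shapes follow the Args above; values are illustrative)::
+
+            import torch
+            image = torch.rand(1, 1, 384, 384)      # (n_samples, C, H, W)
+            instances = torch.rand(1, 2, 2, 2) * 384
+            aug_img, aug_pts = apply_intensity_augmentation(
+                image, instances, uniform_noise_p=1.0, contrast_p=1.0
+            )
+            # Intensity augmentations do not move keypoints, so
+            # aug_img.shape == image.shape and aug_pts.shape == instances.shape.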
+ """ + aug_stack = [] + if uniform_noise_p > 0: + aug_stack.append( + RandomUniformNoise( + noise=(uniform_noise_min, uniform_noise_max), + p=uniform_noise_p, + keepdim=True, + same_on_batch=True, + ) + ) + if gaussian_noise_p > 0: + aug_stack.append( + K.augmentation.RandomGaussianNoise( + mean=gaussian_noise_mean, + std=gaussian_noise_std, + p=gaussian_noise_p, + keepdim=True, + same_on_batch=True, + ) + ) + if contrast_p > 0: + aug_stack.append( + K.augmentation.RandomContrast( + contrast=(contrast_min, contrast_max), + p=contrast_p, + keepdim=True, + same_on_batch=True, + ) + ) + if brightness_p > 0: + aug_stack.append( + K.augmentation.RandomBrightness( + brightness=brightness, + p=brightness_p, + keepdim=True, + same_on_batch=True, + ) + ) + + augmenter = AugmentationSequential( + *aug_stack, + data_keys=["input", "keypoints"], + keepdim=True, + same_on_batch=True, + ) + + inst_shape = instances.shape + # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2) + # or + # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2) + instances = instances.reshape(inst_shape[0], -1, 2) + # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2) + + aug_image, aug_instances = augmenter(image, instances) + + # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2) + # or + # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2) + return aug_image, aug_instances.reshape(*inst_shape) + + +def apply_geometric_augmentation( + image: torch.Tensor, + instances: torch.Tensor, + rotation: Optional[float] = 15.0, + scale: Union[ + Optional[float], Tuple[float, float], Tuple[float, float, float, float] + ] = None, + translate_width: Optional[float] = 0.02, + translate_height: Optional[float] = 0.02, + affine_p: float = 0.0, + erase_scale_min: Optional[float] = 0.0001, + erase_scale_max: Optional[float] = 0.01, + erase_ratio_min: Optional[float] = 1, + erase_ratio_max: Optional[float] = 1, + erase_p: float = 0.0, + mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None, + mixup_p: float = 0.0, + random_crop_height: int = 0, + random_crop_width: int = 0, + random_crop_p: float = 0.0, +) -> Tuple[torch.Tensor]: + """Apply kornia geometric augmentation on image and instances. + + Args: + image: Input image. Shape: (n_samples, C, H, W) + instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2) + rotation: Angles in degrees as a scalar float of the amount of rotation. A + random angle in `(-rotation, rotation)` will be sampled and applied to both + images and keypoints. Set to 0 to disable rotation augmentation. + scale: scaling factor interval. If (a, b) represents isotropic scaling, the scale + is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale + is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d. + Default: None. + translate_width: Maximum absolute fraction for horizontal translation. For example, + if translate_width=a, then horizontal shift is randomly sampled in the range + -img_width * a < dx < img_width * a. Will not translate by default. + translate_height: Maximum absolute fraction for vertical translation. For example, + if translate_height=a, then vertical shift is randomly sampled in the range + -img_height * a < dy < img_height * a. Will not translate by default. + affine_p: Probability of applying random affine transformations. 
+        erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
+        erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
+        erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
+        erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1.
+        erase_p: Probability of applying random erase.
+        mixup_lambda: Min-max value of mixup strength. If `None`, it is sampled
+            from [0, 1]. Default: `None`.
+        mixup_p: Probability of applying random mixup v2.
+        random_crop_height: Desired output height of the crop. Must be int.
+        random_crop_width: Desired output width of the crop. Must be int.
+        random_crop_p: Probability of applying random crop.
+
+    Returns:
+        A tuple (aug_image, aug_instances) with the augmentations applied.
+    """
+    if isinstance(scale, float):
+        scale = (scale, scale)
+    aug_stack = []
+    if affine_p > 0:
+        aug_stack.append(
+            K.augmentation.RandomAffine(
+                degrees=rotation,
+                translate=(translate_width, translate_height),
+                scale=scale,
+                p=affine_p,
+                keepdim=True,
+                same_on_batch=True,
+            )
+        )
+    if erase_p > 0:
+        aug_stack.append(
+            K.augmentation.RandomErasing(
+                scale=(erase_scale_min, erase_scale_max),
+                ratio=(erase_ratio_min, erase_ratio_max),
+                p=erase_p,
+                keepdim=True,
+                same_on_batch=True,
+            )
+        )
+    if mixup_p > 0:
+        aug_stack.append(
+            K.augmentation.RandomMixUpV2(
+                lambda_val=mixup_lambda,
+                p=mixup_p,
+                keepdim=True,
+                same_on_batch=True,
+            )
+        )
+    if random_crop_p > 0:
+        if random_crop_height > 0 and random_crop_width > 0:
+            aug_stack.append(
+                K.augmentation.RandomCrop(
+                    size=(random_crop_height, random_crop_width),
+                    pad_if_needed=True,
+                    p=random_crop_p,
+                    keepdim=True,
+                    same_on_batch=True,
+                )
+            )
+        else:
+            raise ValueError("random_crop_hw height and width must be greater than 0.")
+
+    augmenter = AugmentationSequential(
+        *aug_stack,
+        data_keys=["input", "keypoints"],
+        keepdim=True,
+        same_on_batch=True,
+    )
+
+    inst_shape = instances.shape
+    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+    # or
+    # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+    instances = instances.reshape(inst_shape[0], -1, 2)
+    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
+
+    aug_image, aug_instances = augmenter(image, instances)
+
+    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+    # or
+    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+    return aug_image, aug_instances.reshape(*inst_shape)
+
+
 class RandomUniformNoise(IntensityAugmentationBase2D):
     """Data transformer for applying random uniform noise to input images.
 
diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py
index 3e5aaa4..fc12152 100644
--- a/sleap_nn/data/confidence_maps.py
+++ b/sleap_nn/data/confidence_maps.py
@@ -1,6 +1,6 @@
 """Generate confidence maps."""
 
-from typing import Dict, Iterator
+from typing import Dict, Iterator, Tuple
 
 import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -8,6 +8,92 @@ from sleap_nn.data.utils import make_grid_vectors
 
+
+def generate_confmaps(
+    instance: torch.Tensor,
+    img_hw: Tuple[int, int],
+    sigma: float = 1.5,
+    output_stride: int = 2,
+) -> torch.Tensor:
+    """Generate confidence maps.
+
+    Args:
+        instance: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
+            (n_samples, n_nodes, 2).
+        img_hw: Image size as tuple (height, width).
+        sigma: The standard deviation of the Gaussian distribution that is used to
+            generate confidence maps. Default: 1.5.
+        output_stride: The relative stride to use when generating confidence maps.
+            A larger stride will generate smaller confidence maps. Default: 2.
+
+    Returns:
+        Confidence maps for the input keypoints.
+    """
+    if instance.ndim != 3:
+        instance = instance.view(instance.shape[0], -1, 2)
+    # instance: (n_samples, n_nodes, 2)
+
+    height, width = img_hw
+
+    xv, yv = make_grid_vectors(height, width, output_stride)
+
+    confidence_maps = make_confmaps(
+        instance,
+        xv,
+        yv,
+        sigma * output_stride,
+    )  # (n_samples, n_nodes, height / output_stride, width / output_stride)
+
+    return confidence_maps
+
+
+def generate_multiconfmaps(
+    instances: torch.Tensor,
+    img_hw: Tuple[int, int],
+    num_instances: int,
+    sigma: float = 1.5,
+    output_stride: int = 2,
+    is_centroids: bool = False,
+) -> torch.Tensor:
+    """Generate multi-instance confidence maps.
+
+    Args:
+        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or,
+            for centroids, (n_samples, n_instances, 2).
+        img_hw: Image size as tuple (height, width).
+        num_instances: Original number of instances in the frame.
+        sigma: The standard deviation of the Gaussian distribution that is used to
+            generate confidence maps. Default: 1.5.
+        output_stride: The relative stride to use when generating confidence maps.
+            A larger stride will generate smaller confidence maps. Default: 2.
+        is_centroids: True if confidence maps should be generated for centroids,
+            else False. Default: False.
+
+    Returns:
+        Confidence maps for the input keypoints.
+    """
+    if is_centroids:
+        points = instances[:, :num_instances, :].unsqueeze(dim=-2)
+        # (n_samples, n_instances, 1, 2)
+    else:
+        points = instances[
+            :, :num_instances, :, :
+        ]  # (n_samples, n_instances, n_nodes, 2)
+
+    height, width = img_hw
+
+    xv, yv = make_grid_vectors(height, width, output_stride)
+
+    confidence_maps = make_multi_confmaps(
+        points,
+        xv,
+        yv,
+        sigma * output_stride,
+    )  # (n_samples, n_nodes, height / output_stride, width / output_stride).
+    # If `is_centroids`, (n_samples, 1, height / output_stride, width / output_stride).
+
+    return confidence_maps
+
+
 def make_confmaps(
     points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
 ) -> torch.Tensor:
diff --git a/sleap_nn/data/edge_maps.py b/sleap_nn/data/edge_maps.py
index 5c571f4..33ac0cb 100644
--- a/sleap_nn/data/edge_maps.py
+++ b/sleap_nn/data/edge_maps.py
@@ -247,6 +247,82 @@ def get_edge_points(
     return edge_sources, edge_destinations
 
+
+def generate_pafs(
+    instances: torch.Tensor,
+    img_hw: Tuple[int, int],
+    sigma: float = 1.5,
+    output_stride: int = 2,
+    edge_inds: Optional[torch.Tensor] = None,
+    flatten_channels: bool = False,
+) -> torch.Tensor:
+    """Generate part-affinity fields.
+
+    Args:
+        instances: Input instances. (n_samples, n_instances, n_nodes, 2)
+        img_hw: Image size as tuple (height, width).
+        sigma: The standard deviation of the Gaussian distribution that is used to
+            generate the part-affinity fields. Default: 1.5.
+        output_stride: The relative stride to use when generating the fields.
+            A larger stride will generate smaller part-affinity fields. Default: 2.
+        edge_inds: `torch.Tensor` to use for looking up the index of the edges.
+        flatten_channels: If False, the generated tensors are of shape
+            [height, width, n_edges, 2]. If True, generated tensors are of shape
+            [height, width, n_edges * 2] by flattening the last 2 axes.
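+
+    Example:
+        A minimal sketch (assumes a 2-node skeleton with a single edge; values
+        are illustrative)::
+
+            import torch
+            instances = torch.rand(1, 2, 2, 2) * 384  # (n_samples, n_instances, n_nodes, 2)
+            pafs = generate_pafs(
+                instances, img_hw=(384, 384), edge_inds=torch.Tensor([[0, 1]])
+            )
+            # At output_stride=2: (grid_h, grid_w, n_edges, 2) == (192, 192, 1, 2)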
+
+    Returns:
+        A tensor of shape (grid_height, grid_width, n_edges, 2) containing the
+        combined part affinity fields of all instances in the frame.
+
+        If `flatten_channels` is set to True, the last 2 axes are flattened to
+        produce a tensor of shape (grid_height, grid_width, n_edges * 2). This is
+        a convenient form when training models, as a rank-4 (batched) tensor will
+        generally be expected.
+    """
+    image_height, image_width = img_hw
+
+    # Generate sampling grid vectors.
+    xv, yv = make_grid_vectors(
+        image_height=image_height,
+        image_width=image_width,
+        output_stride=output_stride,
+    )
+    grid_height = len(yv)
+    grid_width = len(xv)
+    n_edges = len(edge_inds)
+
+    instances = instances[0]  # n_samples=1
+    in_img = (instances > 0) & (instances < torch.stack([xv[-1], yv[-1]]).view(1, 1, 2))
+    in_img = in_img.all(dim=-1).any(dim=1)
+    assert len(in_img.shape) == 1
+    instances = instances[in_img]
+
+    edge_sources, edge_destinations = get_edge_points(instances, edge_inds)
+    assert len(edge_sources.shape) == 3
+    assert edge_sources.shape[1:] == (n_edges, 2)
+
+    assert len(edge_destinations.shape) == 3
+    assert edge_destinations.shape[1:] == (n_edges, 2)
+
+    pafs = make_multi_pafs(
+        xv=xv,
+        yv=yv,
+        edge_sources=edge_sources,
+        edge_destinations=edge_destinations,
+        sigma=sigma,
+    )
+    assert pafs.shape == (grid_height, grid_width, n_edges, 2)
+
+    if flatten_channels:
+        pafs = pafs.reshape(grid_height, grid_width, n_edges * 2)
+        assert pafs.shape == (grid_height, grid_width, n_edges * 2)
+
+    return pafs
+
+
 class PartAffinityFieldsGenerator(IterDataPipe):
     """Transformer to generate part affinity fields.
 
diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py
index 8e344b0..47761ff 100644
--- a/sleap_nn/data/instance_centroids.py
+++ b/sleap_nn/data/instance_centroids.py
@@ -35,13 +35,13 @@ def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor:
     return (pts_max + pts_min) * 0.5
 
 
-def find_centroids(
+def generate_centroids(
     points: torch.Tensor, anchor_ind: Optional[int] = None
 ) -> torch.Tensor:
     """Return centroids, falling back to bounding box midpoints.
 
     Args:
-        points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2),
+        points: A torch.Tensor of dtype torch.float32 and of shape (..., n_nodes, 2),
             i.e., rank >= 2.
         anchor_ind: The index of the node to use as the anchor for the centroid.
            If not provided or if not present in the instance, the midpoint of the bounding box
@@ -60,7 +60,7 @@ def find_centroids(
     if missing_anchors.any():
         centroids[missing_anchors] = find_points_bbox_midpoint(points[missing_anchors])
 
-    return centroids
+    return centroids  # (..., n_instances, 2)
 
 
 class InstanceCentroidFinder(IterDataPipe):
@@ -88,7 +88,7 @@ def __init__(
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Add `"centroids"` key to example."""
         for ex in self.source_dp:
-            ex["centroids"] = find_centroids(
+            ex["centroids"] = generate_centroids(
                 ex["instances"], anchor_ind=self.anchor_ind
             )  # (n_samples, n_instances, 2)
             yield ex
diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
index 6956732..11d419b 100644
--- a/sleap_nn/data/instance_cropping.py
+++ b/sleap_nn/data/instance_cropping.py
@@ -7,7 +7,6 @@ import torch
 from kornia.geometry.transform import crop_and_resize
 from torch.utils.data.datapipes.datapipe import IterDataPipe
 
-from sleap_nn.data.resizing import find_padding_for_stride
 
 def find_instance_crop_size(
@@ -107,6 +106,54 @@ def make_centered_bboxes(
     return corners + offset
 
+
+def generate_crops(
+    image: torch.Tensor,
+    instance: torch.Tensor,
+    centroid: torch.Tensor,
+    crop_size: Tuple[int, int],
+) -> Dict[str, torch.Tensor]:
+    """Generate cropped image for the given centroid.
+
+    Args:
+        image: Input source image. (n_samples, C, H, W)
+        instance: Keypoints for the instance to be cropped. (n_nodes, 2)
+        centroid: Centroid of the instance to be cropped. (2,)
+        crop_size: (height, width) of the crop to be generated.
+
+    Returns:
+        A dictionary with the cropped image, bounding box for the cropped instance,
+        and keypoints and centroid adjusted to the crop.
+    """
+    box_size = crop_size
+
+    # Generate bounding boxes from centroid.
+    instance_bbox = torch.unsqueeze(
+        make_centered_bboxes(centroid, box_size[0], box_size[1]), 0
+    )  # (n_samples=1, 4, 2)
+
+    # Generate cropped image of shape (n_samples, C, crop_H, crop_W)
+    instance_image = crop_and_resize(
+        image,
+        boxes=instance_bbox,
+        size=box_size,
+    )
+
+    # Access top left point (x,y) of bounding box and subtract this offset from
+    # the positions of the nodes.
+    point = instance_bbox[0][0]
+    center_instance = (instance - point).unsqueeze(0)  # (n_samples=1, n_nodes, 2)
+    centered_centroid = (centroid - point).unsqueeze(0)  # (n_samples=1, 2)
+
+    cropped_sample = {
+        "instance_image": instance_image,
+        "instance_bbox": instance_bbox,
+        "instance": center_instance,
+        "centroid": centered_centroid,
+    }
+
+    return cropped_sample
+
+
 class InstanceCropper(IterDataPipe):
     """IterDataPipe for cropping instances.
 
diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py
index 341e710..8376bb4 100644
--- a/sleap_nn/data/normalization.py
+++ b/sleap_nn/data/normalization.py
@@ -41,6 +41,13 @@ def convert_to_rgb(image: torch.Tensor):
     return image
 
+
+def apply_normalization(image: torch.Tensor):
+    """Normalize image tensor to float32 in [0, 1]."""
+    if not torch.is_floating_point(image):
+        image = image.to(torch.float32) / 255.0
+    return image
+
+
 class Normalizer(IterDataPipe):
     """IterDataPipe for applying normalization.
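
The cropping helpers above compose directly. A minimal sketch of `generate_centroids` feeding `generate_crops` (shapes mirror the tests below; values are illustrative):

```python
import torch
from sleap_nn.data.instance_centroids import generate_centroids
from sleap_nn.data.instance_cropping import generate_crops

image = torch.rand(1, 1, 384, 384)        # (n_samples, C, H, W)
instances = torch.rand(1, 2, 2, 2) * 384  # (n_samples, n_instances, n_nodes, 2)

centroids = generate_centroids(instances, anchor_ind=0)  # (1, n_instances, 2)
crop = generate_crops(image, instances[0, 0], centroids[0, 0], crop_size=(100, 100))
# crop["instance_image"]: (1, 1, 100, 100); crop["instance_bbox"]: (1, 4, 2)
# crop["instance"]: (1, 2, 2); crop["centroid"]: (1, 2)
```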
diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py
index b97c2c9..1f5eb93 100644
--- a/sleap_nn/data/providers.py
+++ b/sleap_nn/data/providers.py
@@ -1,6 +1,6 @@
 """This module implements pipeline blocks for reading input data such as labels."""
 
-from typing import Dict, Iterator, Optional, Tuple
+from typing import Any, Dict, Iterator, Optional, Tuple
 
 import numpy as np
 import sleap_io as sio
@@ -28,6 +28,69 @@ def get_max_instances(labels: sio.Labels):
     return max_instances
 
+
+def process_lf(
+    lf: sio.LabeledFrame,
+    video_idx: int,
+    max_instances: int,
+    user_instances_only: bool = True,
+) -> Dict[str, Any]:
+    """Get sample dict from `sio.LabeledFrame`.
+
+    Args:
+        lf: Input `sio.LabeledFrame`.
+        video_idx: Video index of the given lf.
+        max_instances: Maximum number of instances that could occur in a single
+            LabeledFrame.
+        user_instances_only: True if only user-labeled instances should be used;
+            False to fall back to all instances. Default: True.
+
+    Returns:
+        Dict with image, instances, frame index, video index, original image size and
+        number of instances.
+    """
+    # Filter to user instances.
+    if user_instances_only:
+        if lf.user_instances is not None and len(lf.user_instances) > 0:
+            lf.instances = lf.user_instances
+
+    image = np.transpose(lf.image, (2, 0, 1))  # HWC -> CHW
+
+    instances = []
+    for inst in lf:
+        if not inst.is_empty:
+            instances.append(inst.numpy())
+    instances = np.stack(instances, axis=0)
+
+    # Add singleton time dimension for single frames.
+    image = np.expand_dims(image, axis=0)  # (n_samples=1, C, H, W)
+    instances = np.expand_dims(
+        instances, axis=0
+    )  # (n_samples=1, num_instances, num_nodes, 2)
+
+    image = torch.from_numpy(image.astype("float32"))
+    instances = torch.from_numpy(instances.astype("float32"))
+
+    num_instances, nodes = instances.shape[1:3]
+    img_height, img_width = image.shape[-2:]
+
+    # Pad with NaNs so every sample has `max_instances` instances.
+    nans = torch.full((1, np.abs(max_instances - num_instances), nodes, 2), torch.nan)
+    instances = torch.cat(
+        [instances, nans], dim=1
+    )  # (n_samples, max_instances, num_nodes, 2)
+
+    ex = {
+        "image": image,
+        "instances": instances,
+        "video_idx": torch.tensor(video_idx, dtype=torch.int32),
+        "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32),
+        "orig_size": torch.Tensor([img_height, img_width]),
+        "num_instances": num_instances,
+    }
+
+    return ex
+
+
 class LabelsReader(IterDataPipe):
     """IterDataPipe for reading frames from Labels object.
 
diff --git a/sleap_nn/data/resizing.py b/sleap_nn/data/resizing.py
index b39a10b..4193b3c 100644
--- a/sleap_nn/data/resizing.py
+++ b/sleap_nn/data/resizing.py
@@ -33,7 +33,7 @@ def find_padding_for_stride(
     return pad_height, pad_width
 
 
-def pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
+def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
     """Pad an image to meet a max stride constraint.
 
     This is useful for ensuring there is no size mismatch between an image and the
@@ -51,19 +51,20 @@ def pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
     Returns:
         The input image with 0-padding applied to the bottom and/or right such that the
         new shape's height and width are both divisible by `max_stride`.
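+
+    Example:
+        A minimal sketch (shapes are illustrative)::
+
+            import torch
+            img = torch.rand(1, 1, 384, 384)
+            padded = apply_pad_to_stride(img, max_stride=256)
+            # 384 rounds up to the next multiple of 256, i.e. 512.
+            assert padded.shape[-2:] == (512, 512)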
""" - image_height, image_width = image.shape[-2:] - pad_height, pad_width = find_padding_for_stride( - image_height=image_height, - image_width=image_width, - max_stride=max_stride, - ) - - if pad_height > 0 or pad_width > 0: - image = F.pad( - image, - (0, pad_width, 0, pad_height), - mode="constant", - ).to(torch.float32) + if max_stride > 1: + image_height, image_width = image.shape[-2:] + pad_height, pad_width = find_padding_for_stride( + image_height=image_height, + image_width=image_width, + max_stride=max_stride, + ) + + if pad_height > 0 or pad_width > 0: + image = F.pad( + image, + (0, pad_width, 0, pad_height), + mode="constant", + ).to(torch.float32) return image @@ -84,6 +85,55 @@ def resize_image(image: torch.Tensor, scale: float): return image +def apply_resizer(image: torch.Tensor, instances: torch.Tensor, scale: float = 1.0): + """Rescale image and keypoints by a scale factor. + + Args: + image: Image tensor of shape (..., channels, height, width) + instances: Keypoints tensor. + scale: Factor to resize the image dimensions by, specified as a float + scalar. Default: 1.0. + + Returns: + Tuple with resized image and corresponding keypoints. + """ + if scale != 1.0: + image = resize_image(image, scale) + instances = instances * scale + return image, instances + + +def apply_sizematcher( + image: torch.Tensor, + max_height: Optional[int] = None, + max_width: Optional[int] = None, +): + """Apply padding to smaller image to (max_height, max_width) shape.""" + img_height, img_width = image.shape[-2:] + # pad images to max_height and max_width + if max_height is None: + max_height = img_height + if max_width is None: + max_width = img_width + pad_height = max_height - img_height + pad_width = max_width - img_width + if pad_height < 0: + raise ValueError( + f"Max height {max_height} should be greater than the current image height: {img_height}" + ) + if pad_width < 0: + raise ValueError( + f"Max width {max_width} should be greater than the current image width: {img_width}" + ) + image = F.pad( + image, + (0, pad_width, 0, pad_height), + mode="constant", + ).to(torch.float32) + + return image + + class Resizer(IterDataPipe): """IterDataPipe for resizing images. @@ -156,8 +206,9 @@ def __init__( def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary with the resized image and `orig_size` key to represent the original shape of the source image.""" for ex in self.source_datapipe: - if self.max_stride > 1: - ex[self.image_key] = pad_to_stride(ex[self.image_key], self.max_stride) + ex[self.image_key] = apply_pad_to_stride( + ex[self.image_key], self.max_stride + ) yield ex diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 870227c..2f591e1 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -18,7 +18,7 @@ Resizer, PadToStride, resize_image, - pad_to_stride, + apply_pad_to_stride, ) from sleap_nn.data.normalization import Normalizer, convert_to_grayscale, convert_to_rgb from sleap_nn.data.instance_centroids import InstanceCentroidFinder @@ -51,7 +51,7 @@ class Predictor(ABC): Attributes: preprocess: Only for VideoReader provider. True if preprocessing (reszizing and - pad_to_stride) should be applied on the frames read in the video reader. + apply_pad_to_stride) should be applied on the frames read in the video reader. Default: True. video_preprocess_config: Preprocessing config for VideoReader with keys: [`batch_size`, `scale`, `is_rgb`, `max_stride`]. 
Default: {"batch_size": 4, "scale": 1.0, @@ -277,7 +277,7 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: scale = self.video_preprocess_config["scale"] if scale != 1.0: ex["image"] = resize_image(ex["image"], scale) - ex["image"] = pad_to_stride( + ex["image"] = apply_pad_to_stride( ex["image"], self.video_preprocess_config["max_stride"] ) outputs_list = self.inference_model(ex) diff --git a/sleap_nn/inference/topdown.py b/sleap_nn/inference/topdown.py index 33379ed..04a79b2 100644 --- a/sleap_nn/inference/topdown.py +++ b/sleap_nn/inference/topdown.py @@ -6,7 +6,7 @@ import numpy as np from sleap_nn.data.resizing import ( resize_image, - pad_to_stride, + apply_pad_to_stride, ) from sleap_nn.inference.peak_finding import crop_bboxes from sleap_nn.data.instance_cropping import make_centered_bboxes @@ -156,7 +156,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: orig_image = inputs["image"] scaled_image = resize_image(orig_image, self.input_scale) if self.max_stride != 1: - scaled_image = pad_to_stride(scaled_image, self.max_stride) + scaled_image = apply_pad_to_stride(scaled_image, self.max_stride) cms = self.torch_model(scaled_image) @@ -395,7 +395,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: orig_image = inputs["instance_image"] scaled_image = resize_image(orig_image, self.input_scale) if self.max_stride != 1: - scaled_image = pad_to_stride(scaled_image, self.max_stride) + scaled_image = apply_pad_to_stride(scaled_image, self.max_stride) cms = self.torch_model(scaled_image) diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index 7c619c8..e0a711b 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -1,8 +1,14 @@ import pytest import torch from torch.utils.data import DataLoader - +import sleap_io as sio from sleap_nn.data.augmentation import KorniaAugmenter, RandomUniformNoise +from sleap_nn.data.augmentation import ( + apply_intensity_augmentation, + apply_geometric_augmentation, +) +from sleap_nn.data.normalization import apply_normalization +from sleap_nn.data.providers import process_lf from sleap_nn.data.normalization import Normalizer from sleap_nn.data.providers import LabelsReader @@ -35,6 +41,62 @@ def test_uniform_noise(minimal_instance): assert aug_img.shape == (1, 1, 384, 384) +def test_apply_intensity_augmentation(minimal_instance): + """Test `apply_intensity_augmentation` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + ex["image"] = apply_normalization(ex["image"]) + + img, pts = apply_intensity_augmentation( + ex["image"], + ex["instances"], + uniform_noise_p=1.0, + contrast_p=1.0, + brightness_p=1.0, + gaussian_noise_p=1.0, + ) + # Test all augmentations. + assert torch.is_tensor(img) + assert torch.is_tensor(pts) + assert not torch.equal(img, ex["image"]) + assert img.shape == (1, 1, 384, 384) + assert pts.shape == (1, 2, 2, 2) + + +def test_apply_geometric_augmentation(minimal_instance): + """Test `apply_geometric_augmentation` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + ex["image"] = apply_normalization(ex["image"]) + + img, pts = apply_geometric_augmentation( + ex["image"], + ex["instances"], + scale=0.5, + affine_p=1.0, + random_crop_height=100, + random_crop_width=100, + random_crop_p=1.0, + erase_p=1.0, + mixup_p=1.0, + ) + # Test all augmentations. 
+ assert torch.is_tensor(img) + assert torch.is_tensor(pts) + assert not torch.equal(img, ex["image"]) + assert img.shape == (1, 1, 100, 100) + assert pts.shape == (1, 2, 2, 2) + + with pytest.raises( + ValueError, match="crop_hw height and width must be greater than 0." + ): + img, pts = apply_geometric_augmentation( + ex["image"], ex["instances"], random_crop_p=1.0, random_crop_height=0 + ) + + def test_kornia_augmentation(minimal_instance): """Test KorniaAugmenter module.""" p = LabelsReader.from_filename(minimal_instance) diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 5196429..691f08d 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -1,16 +1,18 @@ import torch - +import sleap_io as sio from sleap_nn.data.confidence_maps import ( ConfidenceMapGenerator, MultiConfidenceMapGenerator, make_multi_confmaps, make_confmaps, + generate_confmaps, + generate_multiconfmaps, ) from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper from sleap_nn.data.normalization import Normalizer from sleap_nn.data.resizing import Resizer -from sleap_nn.data.providers import LabelsReader +from sleap_nn.data.providers import LabelsReader, process_lf from sleap_nn.data.utils import make_grid_vectors import numpy as np @@ -121,3 +123,35 @@ def test_multi_confmaps(minimal_instance): assert sample["confidence_maps"].shape == (1, 2, 768, 768) assert torch.sum(sample["confidence_maps"] > 0.93) == 4 + + +def test_generate_confmaps(minimal_instance): + """Test `generate_confmaps` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + confmaps = generate_confmaps( + ex["instances"][:, 0].unsqueeze(dim=1), img_hw=(384, 384) + ) + assert confmaps.shape == (1, 2, 192, 192) + + +def test_generate_multiconfmaps(minimal_instance): + """Test `generate_multiconfmaps` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + confmaps = generate_multiconfmaps( + ex["instances"], img_hw=(384, 384), num_instances=ex["num_instances"] + ) + assert confmaps.shape == (1, 2, 192, 192) + + confmaps = generate_multiconfmaps( + ex["instances"][:, :, 0, :], + img_hw=(384, 384), + num_instances=ex["num_instances"], + is_centroids=True, + ) + assert confmaps.shape == (1, 1, 192, 192) diff --git a/tests/data/test_edge_maps.py b/tests/data/test_edge_maps.py index ec77ea2..d994932 100644 --- a/tests/data/test_edge_maps.py +++ b/tests/data/test_edge_maps.py @@ -1,13 +1,15 @@ import numpy as np import torch +import sleap_io as sio from sleap_nn.data.utils import make_grid_vectors -from sleap_nn.data.providers import LabelsReader +from sleap_nn.data.providers import LabelsReader, process_lf from sleap_nn.data.edge_maps import ( distance_to_edge, make_edge_maps, make_pafs, make_multi_pafs, get_edge_points, + generate_pafs, PartAffinityFieldsGenerator, ) @@ -173,6 +175,20 @@ def test_get_edge_points(): ) +def test_generate_pafs(minimal_instance): + """Test `generate_pafs` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + pafs = generate_pafs( + ex["instances"], + img_hw=(384, 384), + edge_inds=torch.Tensor(labels.skeletons[0].edge_inds), + ) + assert pafs.shape == (192, 192, 1, 2) + + def test_part_affinity_fields_generator(minimal_instance): provider = LabelsReader.from_filename(minimal_instance) paf_generator = PartAffinityFieldsGenerator( diff --git 
a/tests/data/test_instance_centroids.py b/tests/data/test_instance_centroids.py index de890c9..245c0c4 100644 --- a/tests/data/test_instance_centroids.py +++ b/tests/data/test_instance_centroids.py @@ -1,14 +1,42 @@ import torch - +import sleap_io as sio from sleap_nn.data.instance_centroids import ( InstanceCentroidFinder, - find_centroids, + generate_centroids, ) -from sleap_nn.data.providers import LabelsReader +from sleap_nn.data.providers import LabelsReader, process_lf + + +def test_generate_centroids(minimal_instance): + """Test `generate_centroids` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + centroids = generate_centroids(ex["instances"], 1).int() + gt = torch.Tensor([[[152, 158], [278, 203]]]).int() + assert torch.equal(centroids, gt) + + partial_instance = torch.Tensor( + [ + [ + [[92.6522, 202.7260], [152.3419, 158.4236], [97.2618, 53.5834]], + [[205.9301, 187.8896], [torch.nan, torch.nan], [201.4264, 75.2373]], + [ + [torch.nan, torch.nan], + [torch.nan, torch.nan], + [torch.nan, torch.nan], + ], + ] + ] + ) + centroids = generate_centroids(partial_instance, 1).int() + gt = torch.Tensor([[[152, 158], [203, 131], [torch.nan, torch.nan]]]).int() + assert torch.equal(centroids, gt) def test_instance_centroids(minimal_instance): - """Test InstanceCentroidFinder and find_centroids functions.""" + """Test InstanceCentroidFinder and generate_centroids functions.""" # Undefined anchor_ind. datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) @@ -20,7 +48,7 @@ def test_instance_centroids(minimal_instance): assert torch.equal(centroids, gt) # Defined anchor_ind. - centroids = find_centroids(instances, 1).int() + centroids = generate_centroids(instances, 1).int() gt = torch.Tensor([[[152, 158], [278, 203]]]).int() assert torch.equal(centroids, gt) @@ -38,6 +66,6 @@ def test_instance_centroids(minimal_instance): ] ] ) - centroids = find_centroids(partial_instance, 1).int() + centroids = generate_centroids(partial_instance, 1).int() gt = torch.Tensor([[[152, 158], [203, 131], [torch.nan, torch.nan]]]).int() assert torch.equal(centroids, gt) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index cd984f0..8ddbd44 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -1,15 +1,16 @@ import sleap_io as sio import torch -from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_centroids import InstanceCentroidFinder, generate_centroids from sleap_nn.data.instance_cropping import ( InstanceCropper, find_instance_crop_size, + generate_crops, make_centered_bboxes, ) from sleap_nn.data.normalization import Normalizer from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride -from sleap_nn.data.providers import LabelsReader +from sleap_nn.data.providers import LabelsReader, process_lf def test_find_instance_crop_size(minimal_instance): @@ -83,3 +84,20 @@ def test_instance_cropper(minimal_instance): ) centered_instance = sample["instance"] assert torch.equal(centered_instance, gt.unsqueeze(0)) + + +def test_generate_crops(minimal_instance): + """Test `generate_crops` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + centroids = generate_centroids(ex["instances"], 0) + cropped_ex = generate_crops( + ex["image"], ex["instances"][0, 0], centroids[0, 0], crop_size=(100, 100) + ) + + assert cropped_ex["instance"].shape 
== (1, 2, 2) + assert cropped_ex["centroid"].shape == (1, 2) + assert cropped_ex["instance_image"].shape == (1, 1, 100, 100) + assert cropped_ex["instance_bbox"].shape == (1, 4, 2) diff --git a/tests/data/test_normalization.py b/tests/data/test_normalization.py index 8b41b8c..bcf7183 100644 --- a/tests/data/test_normalization.py +++ b/tests/data/test_normalization.py @@ -1,7 +1,11 @@ import torch import numpy as np -from sleap_nn.data.normalization import Normalizer, convert_to_grayscale +from sleap_nn.data.normalization import ( + Normalizer, + convert_to_grayscale, + apply_normalization, +) from sleap_nn.data.providers import LabelsReader @@ -28,3 +32,11 @@ def test_convert_to_grayscale(): img = torch.randint(0, 255, (3, 200, 200)) res = convert_to_grayscale(img) assert res.shape[0] == 1 + + +def test_apply_normalization(): + """Test `apply_normalization` function.""" + img = torch.randint(0, 255, (3, 200, 200)) + res = apply_normalization(img) + assert torch.max(res) <= 1.0 and torch.min(res) == 0.0 + assert res.dtype == torch.float32 diff --git a/tests/data/test_providers.py b/tests/data/test_providers.py index 318aadd..6271c6c 100644 --- a/tests/data/test_providers.py +++ b/tests/data/test_providers.py @@ -1,6 +1,6 @@ import torch -from sleap_nn.data.providers import LabelsReader, VideoReader +from sleap_nn.data.providers import LabelsReader, VideoReader, process_lf from queue import Queue import sleap_io as sio import numpy as np @@ -91,3 +91,13 @@ def test_videoreader_provider(centered_instance_video): finally: reader.join() assert reader.total_len() == 6 + + +def test_process_lf(minimal_instance): + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 4) + + assert ex["image"].shape == torch.Size([1, 1, 384, 384]) + assert ex["instances"].shape == torch.Size([1, 4, 2, 2]) + assert torch.isnan(ex["instances"][:, 2:, :, :]).all() diff --git a/tests/data/test_resizing.py b/tests/data/test_resizing.py index f1c831d..0d926be 100644 --- a/tests/data/test_resizing.py +++ b/tests/data/test_resizing.py @@ -1,7 +1,14 @@ import torch -from sleap_nn.data.providers import LabelsReader -from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride +from sleap_nn.data.providers import LabelsReader, process_lf +from sleap_nn.data.resizing import ( + SizeMatcher, + Resizer, + PadToStride, + apply_resizer, + apply_pad_to_stride, + apply_sizematcher, +) import numpy as np import sleap_io as sio import pytest @@ -69,12 +76,56 @@ def test_padtostride(minimal_instance): image = sample["image"] assert image.shape == torch.Size([1, 1, 384, 384]) - pipe = PadToStride(l, max_stride=2) + pipe = PadToStride(l, max_stride=500) sample = next(iter(pipe)) image = sample["image"] + assert image.shape == torch.Size([1, 1, 500, 500]) + + +def test_apply_resizer(minimal_instance): + """Test `apply_resizer` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + image, instances = apply_resizer(ex["image"], ex["instances"], scale=2.0) + assert image.shape == torch.Size([1, 1, 768, 768]) + assert torch.all(instances == ex["instances"] * 2.0) + + +def test_apply_pad_to_stride(minimal_instance): + """Test `apply_pad_to_stride` function.""" + labels = sio.load_slp(minimal_instance) + lf = labels[0] + ex = process_lf(lf, 0, 2) + + image = apply_pad_to_stride(ex["image"], max_stride=2) assert image.shape == torch.Size([1, 1, 384, 384]) - pipe = PadToStride(l, max_stride=500) - sample = next(iter(pipe)) - image = sample["image"] + image = 
apply_pad_to_stride(ex["image"], max_stride=200)
+    assert image.shape == torch.Size([1, 1, 400, 400])
+
+
+def test_apply_sizematcher(minimal_instance):
+    """Test `apply_sizematcher` function."""
+    labels = sio.load_slp(minimal_instance)
+    lf = labels[0]
+    ex = process_lf(lf, 0, 2)
+
+    image = apply_sizematcher(ex["image"], 500, 500)
+    assert image.shape == torch.Size([1, 1, 500, 500])
+
+    image = apply_sizematcher(ex["image"])
+    assert image.shape == torch.Size([1, 1, 384, 384])
+
+    with pytest.raises(
+        Exception,
+        match=f"Max height {100} should be greater than or equal to the current image height: {384}",
+    ):
+        image = apply_sizematcher(ex["image"], max_height=100, max_width=500)
+
+    with pytest.raises(
+        Exception,
+        match=f"Max width {100} should be greater than or equal to the current image width: {384}",
+    ):
+        image = apply_sizematcher(ex["image"], max_height=500, max_width=100)
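
Taken together, the new functional blocks can be chained without the `IterDataPipe` wrappers. A minimal end-to-end sketch (the `.slp` path is illustrative; shapes assume the 384×384 single-channel `minimal_instance` fixture used in the tests):

```python
import sleap_io as sio
from sleap_nn.data.providers import process_lf
from sleap_nn.data.normalization import apply_normalization
from sleap_nn.data.resizing import apply_resizer, apply_pad_to_stride
from sleap_nn.data.confidence_maps import generate_multiconfmaps

labels = sio.load_slp("minimal_instance.pkg.slp")  # illustrative path
ex = process_lf(labels[0], video_idx=0, max_instances=2)

image = apply_normalization(ex["image"])                             # float32 in [0, 1]
image, instances = apply_resizer(image, ex["instances"], scale=0.5)  # 384 -> 192
image = apply_pad_to_stride(image, max_stride=32)                    # 192 is already divisible
confmaps = generate_multiconfmaps(
    instances, img_hw=image.shape[-2:], num_instances=ex["num_instances"]
)  # (1, n_nodes, 96, 96) at the default output_stride=2
```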