From 2dfcd90bd6764594633b225325aa8dfff2b3e0d1 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Thu, 3 Aug 2023 07:55:25 -0700
Subject: [PATCH 01/55] added make_centered_bboxes & normalize_bboxes

---
 sleap_nn/data/instance_cropping.py | 61 ++++++++++++++++++++++++++++++
 sleap_nn/inference/peak_finding.py |  5 +++
 2 files changed, 66 insertions(+)
 create mode 100644 sleap_nn/data/instance_cropping.py
 create mode 100644 sleap_nn/inference/peak_finding.py

diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
new file mode 100644
index 00000000..1047db58
--- /dev/null
+++ b/sleap_nn/data/instance_cropping.py
@@ -0,0 +1,61 @@
+import torch
+
+def make_centered_bboxes(
+    centroids: torch.Tensor, box_height: int, box_width: int
+) -> torch.Tensor:
+    """Create centered bounding boxes around centroid.
+
+    To be used with `kornia.geometry.transform.crop_and_resize` in the following (clockwise)
+    order: top-left, top-right, bottom-right and bottom-left.
+    """
+    half_h = box_height / 2
+    half_w = box_width / 2
+
+    # Get x and y values from the centroids tensor
+    x = centroids[..., 0]
+    y = centroids[..., 1]
+
+    # Calculate the corner points
+    top_left = torch.stack([x - half_w, y - half_h], dim=-1)
+    top_right = torch.stack([x + half_w, y - half_h], dim=-1)
+    bottom_left = torch.stack([x - half_w, y + half_h], dim=-1)
+    bottom_right = torch.stack([x + half_w, y + half_h], dim=-1)
+
+    # Get bounding box
+    corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2)
+
+    return corners
+
+def normalize_bboxes(
+    bboxes: torch.Tensor, image_height: int, image_width: int
+) -> torch.Tensor:
+    """Normalize bounding box coordinates to the range [0, 1].
+
+    This is useful for transforming points for PyTorch operations that require
+    normalized image coordinates.
+
+    Args:
+        bboxes: Tensor of shape (n_bboxes, 4) and dtype torch.float32, where the last axis
+            corresponds to (y1, x1, y2, x2) coordinates of the bounding boxes.
+        image_height: Scalar integer indicating the height of the image.
+        image_width: Scalar integer indicating the width of the image.
+
+    Returns:
+        Tensor of the normalized points of the same shape as `bboxes`.
+
+        The normalization applied to each point is `x / (image_width - 1)` and
+        `y / (image_height - 1)`.
+
+    See also: unnormalize_bboxes
+    """
+    # Compute normalizing factor of shape (1, 4).
+    factor = (
+        torch.tensor(
+            [[image_height, image_width, image_height, image_width]], dtype=torch.float32
+        )
+        - 1
+    )
+
+    # Normalize and return.
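
# (Editor's aside, not part of the patch: a worked example of the arithmetic
# at this step, with made-up sizes. For image_height=100 and image_width=200,
# the factor is [[99., 199., 99., 199.]], so a corner coordinate of y == 99
# or x == 199 normalizes exactly to 1.0 in the division below.)
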
+ normalized_bboxes = bboxes / factor + return normalized_bboxes diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py new file mode 100644 index 00000000..3753d5f5 --- /dev/null +++ b/sleap_nn/inference/peak_finding.py @@ -0,0 +1,5 @@ +import torch +import numpy as np +from typing import Tuple, Optional +from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes + From 1088e7fea5a6668f25fafc64b1ae334e7b070f43 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 3 Aug 2023 07:55:25 -0700 Subject: [PATCH 02/55] added make_centered_bboxes & normalize_bboxes --- sleap_nn/data/instance_cropping.py | 33 ++++++++++++++++++++++++++++++ sleap_nn/inference/peak_finding.py | 5 +++++ 2 files changed, 38 insertions(+) create mode 100644 sleap_nn/inference/peak_finding.py diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 8a85c371..e0b90715 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -6,6 +6,39 @@ import numpy as np import torch +def normalize_bboxes( + bboxes: torch.Tensor, image_height: int, image_width: int +) -> torch.Tensor: + """Normalize bounding box coordinates to the range [0, 1]. + + This is useful for transforming points for PyTorch operations that require + normalized image coordinates. + + Args: + bboxes: Tensor of shape (n_bboxes, 4) and dtype torch.float32, where the last axis + corresponds to (y1, x1, y2, x2) coordinates of the bounding boxes. + image_height: Scalar integer indicating the height of the image. + image_width: Scalar integer indicating the width of the image. + + Returns: + Tensor of the normalized points of the same shape as `bboxes`. + + The normalization applied to each point is `x / (image_width - 1)` and + `y / (image_width - 1)`. + + See also: unnormalize_bboxes + """ + # Compute normalizing factor of shape (1, 4). + factor = ( + torch.tensor( + [[image_height, image_width, image_height, image_width]], dtype=torch.float32 + ) + - 1 + ) + + # Normalize and return. + normalized_bboxes = bboxes / factor + return normalized_bboxes def make_centered_bboxes( centroids: torch.Tensor, box_height: int, box_width: int diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py new file mode 100644 index 00000000..3753d5f5 --- /dev/null +++ b/sleap_nn/inference/peak_finding.py @@ -0,0 +1,5 @@ +import torch +import numpy as np +from typing import Tuple, Optional +from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes + From 2d0a0098dd046e972a7ac593773223a2ce5c76b6 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 3 Aug 2023 13:26:55 -0700 Subject: [PATCH 03/55] created test_instance_cropping.py --- sleap_nn/data/instance_cropping.py | 55 +++++++++++++++--------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index e0b90715..cd3086d7 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -6,6 +6,32 @@ import numpy as np import torch +def make_centered_bboxes( + centroids: torch.Tensor, box_height: int, box_width: int +) -> torch.Tensor: + """Create centered bounding boxes around centroid. + + To be used with `kornia.geometry.transform.crop_and_resize`in the following (clockwise) + order: top-left, top-right, bottom-right and bottom-left. 
+ """ + half_h = box_height / 2 + half_w = box_width / 2 + + # Get x and y values from the centroids tensor + x = centroids[..., 0] + y = centroids[..., 1] + + # Calculate the corner points + top_left = torch.stack([x - half_w, y - half_h], dim=-1) + top_right = torch.stack([x + half_w, y - half_h], dim=-1) + bottom_left = torch.stack([x - half_w, y + half_h], dim=-1) + bottom_right = torch.stack([x + half_w, y + half_h], dim=-1) + + # Get bounding box + corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2) + + return corners + def normalize_bboxes( bboxes: torch.Tensor, image_height: int, image_width: int ) -> torch.Tensor: @@ -40,33 +66,6 @@ def normalize_bboxes( normalized_bboxes = bboxes / factor return normalized_bboxes -def make_centered_bboxes( - centroids: torch.Tensor, box_height: int, box_width: int -) -> torch.Tensor: - """Create centered bounding boxes around centroid. - - To be used with `kornia.geometry.transform.crop_and_resize`in the following (clockwise) - order: top-left, top-right, bottom-right and bottom-left. - """ - half_h = box_height / 2 - half_w = box_width / 2 - - # Get x and y values from the centroids tensor - x = centroids[..., 0] - y = centroids[..., 1] - - # Calculate the corner points - top_left = torch.stack([x - half_w, y - half_h], dim=-1) - top_right = torch.stack([x + half_w, y - half_h], dim=-1) - bottom_left = torch.stack([x - half_w, y + half_h], dim=-1) - bottom_right = torch.stack([x + half_w, y + half_h], dim=-1) - - # Get bounding box - corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2) - - return corners - - class InstanceCropper(IterDataPipe): """Datapipe for cropping instances. @@ -120,4 +119,4 @@ def __iter__(self): "bbox": bbox, # (frames, 4, 2) "instance": center_instance, # (n_instances, 2) } - yield instance_example + yield instance_example \ No newline at end of file From 02ea629e4aafeb06d1390995ca44f5bc40501d31 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Sat, 5 Aug 2023 23:49:53 -0700 Subject: [PATCH 04/55] added test normalize bboxes; added find_global_peaks_rough --- sleap_nn/inference/peak_finding.py | 38 ++++++++++++++++++++++++++++ tests/data/test_instance_cropping.py | 32 +++++++++++++++++++++-- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 3753d5f5..47b5aaac 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -1,5 +1,43 @@ +"""Peak finding for inference.""" import torch import numpy as np from typing import Tuple, Optional from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes +import torch + +def find_global_peaks_rough( + cms: torch.Tensor, threshold: float = 0.1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Find the global maximum for each sample and channel. + + Args: + cms: Tensor of shape (samples, channels, height, width). + threshold: Scalar float specifying the minimum confidence value for peaks. Peaks + with values below this threshold will be replaced with NaNs. + + Returns: + A tuple of (peak_points, peak_vals). + + peak_points: float32 tensor of shape (samples, channels, 2), where the last axis + indicates peak locations in xy order. + + peak_vals: float32 tensor of shape (samples, channels) containing the values at + the peak points. + """ + # Find the maximum values and their indices along the height and width axes. 
+ max_values, max_indices_y = torch.max(cms, dim=2, keepdim=True) + max_values, max_indices_x = torch.max(max_values, dim=3, keepdim=True) + + max_indices_x = max_indices_x.squeeze(dim=(2, 3)) # (samples, channels) + max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1) + max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels) + peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1) + + # Create masks for values below the threshold. + below_threshold_mask = max_values < threshold + + # Replace values below the threshold with NaN. + max_values[below_threshold_mask] = float('nan') + return peak_points, max_values diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 298cc97f..cc538ad1 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -1,9 +1,37 @@ from sleap_nn.data.providers import LabelsReader import torch -from sleap_nn.data.instance_cropping import make_centered_bboxes, InstanceCropper +from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes, InstanceCropper from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.normalization import Normalizer +def test_normalize_bboxes(minimal_instance): + bboxes = torch.Tensor( + [ + [72.4970, 130.5748, 172.4970, 230.5748], + [3.0000, 5.5748, 100.0000, 220.1235], + ] + ) + + norm_bboxes = normalize_bboxes(bboxes, image_height=200, image_width=300) + + gt = torch.Tensor( + [ + [ + 0.3643065392971039, + 0.4367050230503082, + 0.8668190836906433, + 0.7711531519889832, + ], + [ + 0.015075377188622952, + 0.01864481531083584, + 0.5025125741958618, + 0.7361990213394165, + ], + ] + ) + + assert torch.equal(norm_bboxes, gt) def test_instance_cropper(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) @@ -38,4 +66,4 @@ def test_instance_cropper(minimal_instance): ] ) centered_instance = sample["instance"] - assert torch.equal(centered_instance, gt) + assert torch.equal(centered_instance, gt) \ No newline at end of file From 711b3aa643de4003b3d4bafe79ee776701239d51 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Sun, 6 Aug 2023 00:15:52 -0700 Subject: [PATCH 05/55] black formatted --- sleap_nn/data/instance_cropping.py | 10 +++++++--- sleap_nn/inference/peak_finding.py | 7 ++++--- tests/data/test_instance_cropping.py | 10 ++++++++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index cd3086d7..ab6f3689 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -6,6 +6,7 @@ import numpy as np import torch + def make_centered_bboxes( centroids: torch.Tensor, box_height: int, box_width: int ) -> torch.Tensor: @@ -32,6 +33,7 @@ def make_centered_bboxes( return corners + def normalize_bboxes( bboxes: torch.Tensor, image_height: int, image_width: int ) -> torch.Tensor: @@ -57,8 +59,9 @@ def normalize_bboxes( # Compute normalizing factor of shape (1, 4). factor = ( torch.tensor( - [[image_height, image_width, image_height, image_width]], dtype=torch.float32 - ) + [[image_height, image_width, image_height, image_width]], + dtype=torch.float32, + ) - 1 ) @@ -66,6 +69,7 @@ def normalize_bboxes( normalized_bboxes = bboxes / factor return normalized_bboxes + class InstanceCropper(IterDataPipe): """Datapipe for cropping instances. 
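
# (Editor's aside, not part of the patch: a minimal, hypothetical sketch of how
# make_centered_bboxes feeds kornia's crop_and_resize; the image size and the
# single (x, y) centroid below are made up.)
#
#     import torch
#     from kornia.geometry.transform import crop_and_resize
#     from sleap_nn.data.instance_cropping import make_centered_bboxes
#
#     centroids = torch.tensor([[60.0, 50.0]])  # one centroid, (x, y)
#     boxes = make_centered_bboxes(centroids, box_height=100, box_width=100)
#     # boxes: (1, 4, 2), corners clockwise from top-left, the layout
#     # crop_and_resize expects.
#     image = torch.rand(1, 1, 200, 200)
#     crops = crop_and_resize(image, boxes, (100, 100))  # (1, 1, 100, 100)
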
@@ -119,4 +123,4 @@ def __iter__(self): "bbox": bbox, # (frames, 4, 2) "instance": center_instance, # (n_instances, 2) } - yield instance_example \ No newline at end of file + yield instance_example diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 47b5aaac..cefa4f5c 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -5,6 +5,7 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes import torch + def find_global_peaks_rough( cms: torch.Tensor, threshold: float = 0.1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -30,14 +31,14 @@ def find_global_peaks_rough( max_values, max_indices_x = torch.max(max_values, dim=3, keepdim=True) max_indices_x = max_indices_x.squeeze(dim=(2, 3)) # (samples, channels) - max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1) - max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels) + max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1) + max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels) peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1) # Create masks for values below the threshold. below_threshold_mask = max_values < threshold # Replace values below the threshold with NaN. - max_values[below_threshold_mask] = float('nan') + max_values[below_threshold_mask] = float("nan") return peak_points, max_values diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index cc538ad1..279d7946 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -1,9 +1,14 @@ from sleap_nn.data.providers import LabelsReader import torch -from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes, InstanceCropper +from sleap_nn.data.instance_cropping import ( + make_centered_bboxes, + normalize_bboxes, + InstanceCropper, +) from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.normalization import Normalizer + def test_normalize_bboxes(minimal_instance): bboxes = torch.Tensor( [ @@ -33,6 +38,7 @@ def test_normalize_bboxes(minimal_instance): assert torch.equal(norm_bboxes, gt) + def test_instance_cropper(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) @@ -66,4 +72,4 @@ def test_instance_cropper(minimal_instance): ] ) centered_instance = sample["instance"] - assert torch.equal(centered_instance, gt) \ No newline at end of file + assert torch.equal(centered_instance, gt) From 9a728aa0af8694d954ef1b163b71269c0e136311 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Sun, 6 Aug 2023 00:19:39 -0700 Subject: [PATCH 06/55] black formatted peak_finding --- sleap_nn/inference/peak_finding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 6d22aaba..71879e3a 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -4,6 +4,7 @@ from typing import Tuple, Optional from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes + def find_global_peaks_rough( cms: torch.Tensor, threshold: float = 0.1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -40,4 +41,3 @@ def find_global_peaks_rough( max_values[below_threshold_mask] = float("nan") return peak_points, max_values - From e84535fd05469ad997c53f8498386037aebc8952 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 
9 Aug 2023 22:50:36 -0700 Subject: [PATCH 07/55] added make_grid_vectors, normalize_bboxes, integral_regression, added docstring to make_centered_bboxes, fixed find_global_peaks_rough; added crop_bboxes --- sleap_nn/data/instance_cropping.py | 58 ++++++++++++++-------- sleap_nn/data/utils.py | 34 +++++++++++++ sleap_nn/inference/peak_finding.py | 77 +++++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 22 deletions(-) create mode 100644 sleap_nn/data/utils.py diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index ab6f3689..f872b29a 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -14,6 +14,19 @@ def make_centered_bboxes( To be used with `kornia.geometry.transform.crop_and_resize`in the following (clockwise) order: top-left, top-right, bottom-right and bottom-left. + + Args: + centroids: A tensor of centroids with shape (channels, 2), where channels is the number of centroids, + and the last dimension represents x and y coordinates. + box_height: The desired height of the bounding boxes. + box_width: The desired width of the bounding boxes. + + Returns: + torch.Tensor: A tensor containing bounding box coordinates for each centroid. The output tensor + has shape (channels, 4, 2), where channels is the number of centroids, and the second dimension + represents the four corner points of the bounding boxes, each with x and y coordinates. + The order of the corners follows a clockwise arrangement: top-left, top-right, + bottom-right, and bottom-left. """ half_h = box_height / 2 half_w = box_width / 2 @@ -31,38 +44,43 @@ def make_centered_bboxes( # Get bounding box corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2) - return corners + offset = torch.tensor([ + [+.5, +.5], + [-.5, +.5], + [-.5, -.5], + [+.5, -.5] + ]) + + return corners + offset def normalize_bboxes( bboxes: torch.Tensor, image_height: int, image_width: int ) -> torch.Tensor: - """Normalize bounding box coordinates to the range [0, 1]. + """Normalizes bounding boxes by image width and height. - This is useful for transforming points for PyTorch operations that require - normalized image coordinates. + This function takes a tensor of bounding boxes and normalizes them based on the + provided image width and height. Args: - bboxes: Tensor of shape (n_bboxes, 4) and dtype torch.float32, where the last axis - corresponds to (y1, x1, y2, x2) coordinates of the bounding boxes. - image_height: Scalar integer indicating the height of the image. - image_width: Scalar integer indicating the width of the image. + bboxes: Bounding boxes with shape (samples, 4, 2), where each box + is defined in the order: top-left, top-right, bottom-right, and bottom-left. + The coordinates must be in the (x, y) order. The coordinates compose a + rectangle with a shape of (N1, N2). + image_height: Height of the image. + image_width: Width of the image. Returns: - Tensor of the normalized points of the same shape as `bboxes`. - - The normalization applied to each point is `x / (image_width - 1)` and - `y / (image_width - 1)`. - - See also: unnormalize_bboxes + torch.Tensor: Normalized bounding boxes with shape (samples, 4, 2), where each box + is defined in the order: top-left, top-right, bottom-right, and bottom-left, + and coordinates are normalized to the range [0, 1]. """ - # Compute normalizing factor of shape (1, 4). 
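
# (Editor's aside, not part of the patch: with the (samples, 4, 2) corner
# layout this patch introduces, the replacement factor below is a single
# (x, y) divisor, e.g. image_width=200 and image_height=100 give
# [[199., 99.]], which broadcasts over all four corners of every box.)
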
factor = ( - torch.tensor( - [[image_height, image_width, image_height, image_width]], - dtype=torch.float32, - ) - - 1 + torch.tensor( + [[image_width, image_height]], + dtype=torch.float32, + ) + - 1 ) # Normalize and return. diff --git a/sleap_nn/data/utils.py b/sleap_nn/data/utils.py new file mode 100644 index 00000000..8088a266 --- /dev/null +++ b/sleap_nn/data/utils.py @@ -0,0 +1,34 @@ +from typing import Tuple +import torch + +def make_grid_vectors( + image_height: int, image_width: int, output_stride: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Make sampling grid vectors from image dimensions. + + This is a useful function for creating the x- and y-vectors that define a sampling + grid over an image space. These vectors can be used to generate a full meshgrid or + for equivalent broadcasting operations. + + Args: + image_height: Height of the image grid that will be sampled, specified as a + scalar integer. + image_width: width of the image grid that will be sampled, specified as a + scalar integer. + output_stride: Sampling step size, specified as a scalar integer. This can be + used to specify a sampling grid that has a smaller shape than the image + grid but with values span the same range. This can be thought of as the + reciprocal of the output scale, i.e., it will induce subsampling when set to + values greater than 1. + + Returns: + Tuple of grid vectors (xv, yv). These are tensors of dtype tf.float32 with + shapes (grid_width,) and (grid_height,) respectively. + + The grid dimensions are calculated as: + grid_width = image_width // output_stride + grid_height = image_height // output_stride + """ + xv = torch.arange(0, image_width, step=output_stride, dtype=torch.float32) + yv = torch.arange(0, image_height, step=output_stride, dtype=torch.float32) + return xv, yv \ No newline at end of file diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 71879e3a..d0bf96f7 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -3,8 +3,81 @@ import numpy as np from typing import Tuple, Optional from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes +from kornia.geometry.transform import crop_and_resize +def crop_bboxes( + images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor +) -> torch.Tensor: + """Crop bounding boxes from a batch of images. + + Args: + images: Tensor of shape (samples, channels, height, width) of a batch of images. + bboxes: Tensor of shape (n_bboxes, 4, 2) and dtype torch.float32, where n_bboxes is the number of centroids, and the second dimension + represents the four corner points of the bounding boxes, each with x and y coordinates. + The order of the corners follows a clockwise arrangement: top-left, top-right, + bottom-right, and bottom-left. This can be generated from centroids using `make_centered_bboxes`. + sample_inds: Tensor of shape (n_bboxes,) specifying which samples each bounding + box should be cropped from. + + Returns: + A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same + dtype as the input image. The crop size is inferred from the bounding box + coordinates. + + Notes: + This function expects bounding boxes with coordinates at the centers of the + pixels in the box limits. Technically, the box will span (x1 - 0.5, x2 + 0.5) + and (y1 - 0.5, y2 + 0.5). + + For example, a 3x3 patch centered at (1, 1) would be specified by + (y1, x1, y2, x2) = (0, 0, 2, 2). 
This would be exactly equivalent to indexing + the image with `image[:, :, 0:3, 0:3]`. + + See also: `make_centered_bboxes` + """ + # Compute bounding box size to use for crops. + box_size = ( + bboxes[0, 3, 1] - bboxes[0, 0, 1], # height + bboxes[0, 1, 0] - bboxes[0, 0, 0] # width + ) + + # Crop. + crops = crop_and_resize( + images[sample_inds], # (n_boxes, channels, height, width) + boxes=bboxes, + size=box_size + ) + + # Cast back to original dtype and return. + crops = crops.to(images.dtype) + return crops + +def integral_regression( + cms: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute regression by integrating over the confidence maps on a grid. + + Args: + cms: Confidence maps with shape (samples, channels, height, width). + xv: X grid vector torch.float32 of grid coordinates to sample. + yv: Y grid vector torch.float32 of grid coordinates to sample. + + Returns: + A tuple of (x_hat, y_hat) with the regressed x- and y-coordinates for each + channel of the confidence maps. + + x_hat and y_hat are of shape (samples, channels) + """ + # Compute normalizing factor. + z = torch.sum(cms, dim=[2, 3]) + + # Regress to expectation. + x_hat = torch.sum(xv.view(1, 1, 1, -1) * cms, dim=[2, 3]) / z + y_hat = torch.sum(yv.view(1, 1, -1, 1) * cms, dim=[2, 3]) / z + + return x_hat, y_hat + def find_global_peaks_rough( cms: torch.Tensor, threshold: float = 0.1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -32,12 +105,12 @@ def find_global_peaks_rough( max_indices_x = max_indices_x.squeeze(dim=(2, 3)) # (samples, channels) max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1) max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels) - peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1) + peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1).to(torch.float32) # Create masks for values below the threshold. below_threshold_mask = max_values < threshold # Replace values below the threshold with NaN. - max_values[below_threshold_mask] = float("nan") + peak_points[below_threshold_mask] = float("nan") return peak_points, max_values From 36f65735edf11d20a7afe35904f7f9d83ab8c113 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 09:31:04 -0700 Subject: [PATCH 08/55] finished find_global_peaks with integral regression over centroid crops! --- sleap_nn/inference/peak_finding.py | 83 ++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index d0bf96f7..4a885975 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -37,9 +37,10 @@ def crop_bboxes( See also: `make_centered_bboxes` """ # Compute bounding box size to use for crops. - box_size = ( - bboxes[0, 3, 1] - bboxes[0, 0, 1], # height - bboxes[0, 1, 0] - bboxes[0, 0, 0] # width + height = bboxes[0, 3, 1] - bboxes[0, 0, 1] + width = bboxes[0, 1, 0] - bboxes[0, 0, 0] + box_size = tuple( + np.round((height + 1, width + 1)).astype(np.int32) ) # Crop. @@ -114,3 +115,79 @@ def find_global_peaks_rough( peak_points[below_threshold_mask] = float("nan") return peak_points, max_values + +def find_global_peaks( + cms: torch.Tensor, + threshold: float = 0.2, + refinement: Optional[str] = None, + integral_patch_size: int = 5, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Find global peaks with optional refinement. + + Args: + cms: Confidence maps. 
Tensor of shape (samples, channels, height, width). + threshold: Minimum confidence threshold. Peaks with values below this will + ignored. + refinement: If `None`, returns the grid-aligned peaks with no refinement. If + `"integral"`, peaks will be refined with integral regression. + integral_patch_size: Size of patches to crop around each rough peak as an + integer scalar. + + Returns: + A tuple of (peak_points, peak_vals). + + peak_points: float32 tensor of shape (samples, channels, 2), where the last axis + indicates peak locations in xy order. + + peak_vals: float32 tensor of shape (samples, channels) containing the values at + the peak points. + """ + # Find grid aligned peaks. + rough_peaks, peak_vals = find_global_peaks_rough( + cms, threshold=threshold + ) # (samples, channels, 2) + + # Return early if not refining or no rough peaks found. + if refinement is None or torch.isnan(rough_peaks).all(): + return rough_peaks, peak_vals + + if refinement == "integral": + crop_size = integral_patch_size + else: + return rough_peaks, peak_vals + + # Flatten samples and channels to (n_peaks, 2). + samples = cms.size(0) + channels = cms.size(1) + rough_peaks = rough_peaks.view(samples * channels, 2) + + # Keep only peaks that are not NaNs. + valid_idx = torch.where(~torch.isnan(rough_peaks[:, 0]))[0].squeeze(0) + valid_peaks = rough_peaks[valid_idx] + + # Make bounding boxes for cropping around peaks. + bboxes = make_centered_bboxes( + valid_peaks, box_height=crop_size, box_width=crop_size + ) + + # Crop patch around each grid-aligned peak. + cms = torch.reshape( + cms, + [samples * channels, 1, cms.size(2), cms.size(3)], + ) + cm_crops = crop_bboxes(cms, bboxes, valid_idx) + + # Compute offsets via integral regression on a local patch. + if refinement == "integral": + gv = torch.arange(crop_size, dtype=torch.float32) - ((crop_size - 1) / 2) + dx_hat, dy_hat = integral_regression(cm_crops, xv=gv, yv=gv) + offsets = torch.cat([dx_hat, dy_hat], dim=1) + + # Apply offsets. + refined_peaks = rough_peaks.clone() + refined_peaks[valid_idx] += offsets + + # Reshape to (samples, channels, 2). + refined_peaks = refined_peaks.reshape(samples, channels, 2) + + return refined_peaks, peak_vals \ No newline at end of file From b17af286bea15a8460dfaa94a952d40010ef34b8 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 09:38:08 -0700 Subject: [PATCH 09/55] reformatted with pydocstyle & black --- sleap_nn/data/instance_cropping.py | 21 ++++++++------------- sleap_nn/data/utils.py | 4 +++- sleap_nn/inference/peak_finding.py | 15 +++++++++------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index f872b29a..926f5e01 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -44,12 +44,7 @@ def make_centered_bboxes( # Get bounding box corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2) - offset = torch.tensor([ - [+.5, +.5], - [-.5, +.5], - [-.5, -.5], - [+.5, -.5] - ]) + offset = torch.tensor([[+0.5, +0.5], [-0.5, +0.5], [-0.5, -0.5], [+0.5, -0.5]]) return corners + offset @@ -57,10 +52,10 @@ def make_centered_bboxes( def normalize_bboxes( bboxes: torch.Tensor, image_height: int, image_width: int ) -> torch.Tensor: - """Normalizes bounding boxes by image width and height. + """Normalize bounding boxes by image width and height. This function takes a tensor of bounding boxes and normalizes them based on the - provided image width and height. 
+ provided image width and height. Args: bboxes: Bounding boxes with shape (samples, 4, 2), where each box @@ -76,11 +71,11 @@ def normalize_bboxes( and coordinates are normalized to the range [0, 1]. """ factor = ( - torch.tensor( - [[image_width, image_height]], - dtype=torch.float32, - ) - - 1 + torch.tensor( + [[image_width, image_height]], + dtype=torch.float32, + ) + - 1 ) # Normalize and return. diff --git a/sleap_nn/data/utils.py b/sleap_nn/data/utils.py index 8088a266..8b1b8434 100644 --- a/sleap_nn/data/utils.py +++ b/sleap_nn/data/utils.py @@ -1,6 +1,8 @@ +"""Miscellaneous utility functions for data processing.""" from typing import Tuple import torch + def make_grid_vectors( image_height: int, image_width: int, output_stride: int = 1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -31,4 +33,4 @@ def make_grid_vectors( """ xv = torch.arange(0, image_width, step=output_stride, dtype=torch.float32) yv = torch.arange(0, image_height, step=output_stride, dtype=torch.float32) - return xv, yv \ No newline at end of file + return xv, yv diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 4a885975..51dc5b82 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -39,21 +39,20 @@ def crop_bboxes( # Compute bounding box size to use for crops. height = bboxes[0, 3, 1] - bboxes[0, 0, 1] width = bboxes[0, 1, 0] - bboxes[0, 0, 0] - box_size = tuple( - np.round((height + 1, width + 1)).astype(np.int32) - ) + box_size = tuple(np.round((height + 1, width + 1)).astype(np.int32)) # Crop. crops = crop_and_resize( images[sample_inds], # (n_boxes, channels, height, width) boxes=bboxes, - size=box_size + size=box_size, ) # Cast back to original dtype and return. crops = crops.to(images.dtype) return crops + def integral_regression( cms: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -79,6 +78,7 @@ def integral_regression( return x_hat, y_hat + def find_global_peaks_rough( cms: torch.Tensor, threshold: float = 0.1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -106,7 +106,9 @@ def find_global_peaks_rough( max_indices_x = max_indices_x.squeeze(dim=(2, 3)) # (samples, channels) max_indices_y = max_indices_y.max(dim=3).values # (samples, channels, 1) max_values = max_values.squeeze(-1).squeeze(-1) # (samples, channels) - peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1).to(torch.float32) + peak_points = torch.cat([max_indices_x.unsqueeze(-1), max_indices_y], dim=-1).to( + torch.float32 + ) # Create masks for values below the threshold. below_threshold_mask = max_values < threshold @@ -116,6 +118,7 @@ def find_global_peaks_rough( return peak_points, max_values + def find_global_peaks( cms: torch.Tensor, threshold: float = 0.2, @@ -190,4 +193,4 @@ def find_global_peaks( # Reshape to (samples, channels, 2). 
refined_peaks = refined_peaks.reshape(samples, channels, 2) - return refined_peaks, peak_vals \ No newline at end of file + return refined_peaks, peak_vals From a5065790929ad9822280d1052c8ba4be1f678639 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 12:50:47 -0700 Subject: [PATCH 10/55] moved make_grid_vectors to data/utils --- sleap_nn/data/confidence_maps.py | 28 +--------------------------- tests/data/test_confmaps.py | 2 +- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index e2b924f9..a2aed607 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -1,6 +1,7 @@ """Generate confidence maps.""" from torch.utils.data.datapipes.datapipe import IterDataPipe from typing import Optional +from sleap_nn.data.utils import make_grid_vectors import sleap_io as sio import torch @@ -42,33 +43,6 @@ def make_confmaps( return cm -def make_grid_vectors(image_height: int, image_width: int, output_stride: int): - """Make sampling grid vectors from image dimensions. - - Args: - image_height: Height of the image grid that will be sampled, specified as a - scalar integer. - image_width: width of the image grid that will be sampled, specified as a - scalar integer. - output_stride: Sampling step size, specified as a scalar integer. - - Returns: - Tuple of grid vectors (xv, yv). These are tensors of dtype torch.float32 with - shapes (grid_width,) and (grid_height,) respectively. - - The grid dimensions are calculated as: - grid_width = image_width // output_stride - grid_height = image_height // output_stride - """ - xv = torch.arange(0, image_width, step=output_stride).to( - torch.float32 - ) # (image_width,) - yv = torch.arange(0, image_height, step=output_stride).to( - torch.float32 - ) # (image_height,) - return xv, yv - - class ConfidenceMapGenerator(IterDataPipe): """DataPipe for generating confidence maps. diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index c93ff269..453b407c 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -6,8 +6,8 @@ from sleap_nn.data.confidence_maps import ( ConfidenceMapGenerator, make_confmaps, - make_grid_vectors, ) +from sleap_nn.data.utils import make_grid_vectors def test_confmaps(minimal_instance): From 02babb1b87c51f0ee650d69f2be1c4630760ecbe Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 13:00:08 -0700 Subject: [PATCH 11/55] removed normalize_bboxes --- sleap_nn/data/instance_cropping.py | 34 ---------------------------- sleap_nn/inference/peak_finding.py | 2 +- tests/data/test_augmentation.py | 3 +-- tests/data/test_instance_cropping.py | 31 ------------------------- 4 files changed, 2 insertions(+), 68 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 926f5e01..a2d6ad53 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -49,40 +49,6 @@ def make_centered_bboxes( return corners + offset -def normalize_bboxes( - bboxes: torch.Tensor, image_height: int, image_width: int -) -> torch.Tensor: - """Normalize bounding boxes by image width and height. - - This function takes a tensor of bounding boxes and normalizes them based on the - provided image width and height. - - Args: - bboxes: Bounding boxes with shape (samples, 4, 2), where each box - is defined in the order: top-left, top-right, bottom-right, and bottom-left. - The coordinates must be in the (x, y) order. 
The coordinates compose a - rectangle with a shape of (N1, N2). - image_height: Height of the image. - image_width: Width of the image. - - Returns: - torch.Tensor: Normalized bounding boxes with shape (samples, 4, 2), where each box - is defined in the order: top-left, top-right, bottom-right, and bottom-left, - and coordinates are normalized to the range [0, 1]. - """ - factor = ( - torch.tensor( - [[image_width, image_height]], - dtype=torch.float32, - ) - - 1 - ) - - # Normalize and return. - normalized_bboxes = bboxes / factor - return normalized_bboxes - - class InstanceCropper(IterDataPipe): """Datapipe for cropping instances. diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 51dc5b82..50037a84 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -2,7 +2,7 @@ import torch import numpy as np from typing import Tuple, Optional -from sleap_nn.data.instance_cropping import make_centered_bboxes, normalize_bboxes +from sleap_nn.data.instance_cropping import make_centered_bboxes from kornia.geometry.transform import crop_and_resize diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index 9b57f63b..c262443b 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -1,5 +1,4 @@ -"""Module for testing augmentations with Kornia""" - +"""Module for testing augmentations with Kornia.""" from sleap_nn.data.augmentation import KorniaAugmenter, RandomUniformNoise from sleap_nn.data.providers import LabelsReader from sleap_nn.data.normalization import Normalizer diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 279d7946..e49ded26 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -2,43 +2,12 @@ import torch from sleap_nn.data.instance_cropping import ( make_centered_bboxes, - normalize_bboxes, InstanceCropper, ) from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.normalization import Normalizer -def test_normalize_bboxes(minimal_instance): - bboxes = torch.Tensor( - [ - [72.4970, 130.5748, 172.4970, 230.5748], - [3.0000, 5.5748, 100.0000, 220.1235], - ] - ) - - norm_bboxes = normalize_bboxes(bboxes, image_height=200, image_width=300) - - gt = torch.Tensor( - [ - [ - 0.3643065392971039, - 0.4367050230503082, - 0.8668190836906433, - 0.7711531519889832, - ], - [ - 0.015075377188622952, - 0.01864481531083584, - 0.5025125741958618, - 0.7361990213394165, - ], - ] - ) - - assert torch.equal(norm_bboxes, gt) - - def test_instance_cropper(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) From 373f4b16e3fc2ff42edf7a65cdb590b9e45f624f Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 13:12:45 -0700 Subject: [PATCH 12/55] added tests docstrings --- tests/__init__.py | 1 + tests/fixtures/datasets.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..b26e62c1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the sleap_nn package.""" \ No newline at end of file diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index c818b430..c5087831 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,3 +1,4 @@ +"""Dataset fixtures for unit testing.""" import pytest from pathlib import Path From 63513142ea79bdd2fd130351f274a54c4c4862f5 Mon Sep 17 00:00:00 
2001 From: alckasoc Date: Thu, 10 Aug 2023 13:25:37 -0700 Subject: [PATCH 13/55] sorted imports with isort --- sleap_nn/data/augmentation.py | 13 +++++++------ sleap_nn/data/confidence_maps.py | 6 ++++-- sleap_nn/data/instance_centroids.py | 3 ++- sleap_nn/data/instance_cropping.py | 7 ++++--- sleap_nn/data/providers.py | 6 +++--- sleap_nn/data/utils.py | 1 + sleap_nn/inference/peak_finding.py | 8 +++++--- tests/data/test_augmentation.py | 9 +++++---- tests/data/test_confmaps.py | 10 ++++------ tests/data/test_instance_centroids.py | 5 +++-- tests/data/test_instance_cropping.py | 8 +++----- tests/data/test_normalization.py | 5 +++-- tests/data/test_providers.py | 3 ++- tests/fixtures/datasets.py | 3 ++- 14 files changed, 48 insertions(+), 39 deletions(-) diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index f781143e..fb095405 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -1,15 +1,16 @@ """This module implements data pipeline blocks for augmentation operations.""" -from typing import Tuple, Dict, Any, Optional, Union, Text -import torch -from torch.utils.data.datapipes.datapipe import IterDataPipe +from typing import Any, Dict, Optional, Text, Tuple, Union + import kornia as K -from kornia.core import Tensor -from kornia.augmentation.container import AugmentationSequential +import torch from kornia.augmentation._2d.geometric.base import GeometricAugmentationBase2D from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D +from kornia.augmentation.container import AugmentationSequential +from kornia.augmentation.utils.param_validation import _range_bound from kornia.constants import Resample, SamplePadding +from kornia.core import Tensor from kornia.geometry.transform import warp_affine -from kornia.augmentation.utils.param_validation import _range_bound +from torch.utils.data.datapipes.datapipe import IterDataPipe class RandomUniformNoise(IntensityAugmentationBase2D): diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index a2aed607..42b58310 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -1,9 +1,11 @@ """Generate confidence maps.""" -from torch.utils.data.datapipes.datapipe import IterDataPipe from typing import Optional -from sleap_nn.data.utils import make_grid_vectors + import sleap_io as sio import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe + +from sleap_nn.data.utils import make_grid_vectors def make_confmaps( diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index 81b8e59a..b15deb91 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -1,7 +1,8 @@ """Handle calculation of instance centroids.""" -from torch.utils.data.datapipes.datapipe import IterDataPipe from typing import Optional + import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index a2d6ad53..017c0559 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -1,10 +1,11 @@ """Handle cropping of instances.""" -from torch.utils.data.datapipes.datapipe import IterDataPipe from typing import Optional -import sleap_io as sio -from kornia.geometry.transform import crop_and_resize + import numpy as np +import sleap_io as sio import torch +from kornia.geometry.transform 
import crop_and_resize +from torch.utils.data.datapipes.datapipe import IterDataPipe def make_centered_bboxes( diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index c5503add..d090775c 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -1,8 +1,8 @@ """This module implements pipeline blocks for reading input data such as labels.""" -from torch.utils.data.datapipes.datapipe import IterDataPipe -import torch -import sleap_io as sio import numpy as np +import sleap_io as sio +import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe class LabelsReader(IterDataPipe): diff --git a/sleap_nn/data/utils.py b/sleap_nn/data/utils.py index 8b1b8434..49de335e 100644 --- a/sleap_nn/data/utils.py +++ b/sleap_nn/data/utils.py @@ -1,5 +1,6 @@ """Miscellaneous utility functions for data processing.""" from typing import Tuple + import torch diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 50037a84..48c96e2c 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -1,10 +1,12 @@ """Peak finding for inference.""" -import torch +from typing import Optional, Tuple + import numpy as np -from typing import Tuple, Optional -from sleap_nn.data.instance_cropping import make_centered_bboxes +import torch from kornia.geometry.transform import crop_and_resize +from sleap_nn.data.instance_cropping import make_centered_bboxes + def crop_bboxes( images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index c262443b..88affab9 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -1,10 +1,11 @@ """Module for testing augmentations with Kornia.""" +import pytest +import torch +from torch.utils.data import DataLoader + from sleap_nn.data.augmentation import KorniaAugmenter, RandomUniformNoise -from sleap_nn.data.providers import LabelsReader from sleap_nn.data.normalization import Normalizer -from torch.utils.data import DataLoader -import torch -import pytest +from sleap_nn.data.providers import LabelsReader def test_uniform_noise(minimal_instance): diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 453b407c..9dbab900 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -1,12 +1,10 @@ -from sleap_nn.data.providers import LabelsReader import torch -from sleap_nn.data.instance_cropping import make_centered_bboxes, InstanceCropper + +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator, make_confmaps from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper, make_centered_bboxes from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.confidence_maps import ( - ConfidenceMapGenerator, - make_confmaps, -) +from sleap_nn.data.providers import LabelsReader from sleap_nn.data.utils import make_grid_vectors diff --git a/tests/data/test_instance_centroids.py b/tests/data/test_instance_centroids.py index 2984aacf..5cf1eb42 100644 --- a/tests/data/test_instance_centroids.py +++ b/tests/data/test_instance_centroids.py @@ -1,10 +1,11 @@ -from sleap_nn.data.providers import LabelsReader import torch + from sleap_nn.data.instance_centroids import ( InstanceCentroidFinder, - find_points_bbox_midpoint, find_centroids, + find_points_bbox_midpoint, ) +from sleap_nn.data.providers import LabelsReader def 
test_instance_centroids(minimal_instance): diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index e49ded26..3ded7ee4 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -1,11 +1,9 @@ -from sleap_nn.data.providers import LabelsReader import torch -from sleap_nn.data.instance_cropping import ( - make_centered_bboxes, - InstanceCropper, -) + from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper, make_centered_bboxes from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.providers import LabelsReader def test_instance_cropper(minimal_instance): diff --git a/tests/data/test_normalization.py b/tests/data/test_normalization.py index 77373320..bcea43af 100644 --- a/tests/data/test_normalization.py +++ b/tests/data/test_normalization.py @@ -1,7 +1,8 @@ -from sleap_nn.data.providers import LabelsReader -from sleap_nn.data.normalization import Normalizer import torch +from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.providers import LabelsReader + def test_normalizer(minimal_instance): p = LabelsReader.from_filename(minimal_instance) diff --git a/tests/data/test_providers.py b/tests/data/test_providers.py index d1be712d..9a53e7a9 100644 --- a/tests/data/test_providers.py +++ b/tests/data/test_providers.py @@ -1,6 +1,7 @@ -from sleap_nn.data.providers import LabelsReader import torch +from sleap_nn.data.providers import LabelsReader + def test_providers(minimal_instance): l = LabelsReader.from_filename(minimal_instance) diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index c5087831..2bebfbdc 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,7 +1,8 @@ """Dataset fixtures for unit testing.""" -import pytest from pathlib import Path +import pytest + @pytest.fixture def sleap_data_dir(pytestconfig): From 008a9940984913b3f87f01a4b894bbdfaaa9f5ca Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 13:32:59 -0700 Subject: [PATCH 14/55] remove unused imports --- tests/data/test_confmaps.py | 2 +- tests/data/test_instance_centroids.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 9dbab900..1653da0a 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -2,7 +2,7 @@ from sleap_nn.data.confidence_maps import ConfidenceMapGenerator, make_confmaps from sleap_nn.data.instance_centroids import InstanceCentroidFinder -from sleap_nn.data.instance_cropping import InstanceCropper, make_centered_bboxes +from sleap_nn.data.instance_cropping import InstanceCropper from sleap_nn.data.normalization import Normalizer from sleap_nn.data.providers import LabelsReader from sleap_nn.data.utils import make_grid_vectors diff --git a/tests/data/test_instance_centroids.py b/tests/data/test_instance_centroids.py index 5cf1eb42..4355620d 100644 --- a/tests/data/test_instance_centroids.py +++ b/tests/data/test_instance_centroids.py @@ -3,7 +3,6 @@ from sleap_nn.data.instance_centroids import ( InstanceCentroidFinder, find_centroids, - find_points_bbox_midpoint, ) from sleap_nn.data.providers import LabelsReader From b45619c55671919a3db8e2b0c154b5f00db98083 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 16:25:23 -0700 Subject: [PATCH 15/55] updated test cases for instance cropping --- sleap_nn/data/instance_cropping.py | 6 +++--- tests/data/test_confmaps.py | 4 
++-- tests/data/test_instance_cropping.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 017c0559..80b5305f 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -32,17 +32,17 @@ def make_centered_bboxes( half_h = box_height / 2 half_w = box_width / 2 - # Get x and y values from the centroids tensor + # Get x and y values from the centroids tensor. x = centroids[..., 0] y = centroids[..., 1] - # Calculate the corner points + # Calculate the corner points. top_left = torch.stack([x - half_w, y - half_h], dim=-1) top_right = torch.stack([x + half_w, y - half_h], dim=-1) bottom_left = torch.stack([x - half_w, y + half_h], dim=-1) bottom_right = torch.stack([x + half_w, y + half_h], dim=-1) - # Get bounding box + # Get bounding box. corners = torch.stack([top_left, top_right, bottom_right, bottom_left], dim=-2) offset = torch.tensor([[+0.5, +0.5], [-0.5, +0.5], [-0.5, -0.5], [+0.5, -0.5]]) diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 1653da0a..77562404 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -18,7 +18,7 @@ def test_confmaps(minimal_instance): assert sample["confidence_maps"].shape == (2, 100, 100) assert torch.max(sample["confidence_maps"]) == torch.Tensor( - [0.989626109600067138671875] + [0.9479378461837769] ) datapipe2 = ConfidenceMapGenerator(datapipe, sigma=3.0, output_stride=2) @@ -26,7 +26,7 @@ def test_confmaps(minimal_instance): assert sample["confidence_maps"].shape == (2, 50, 50) assert torch.max(sample["confidence_maps"]) == torch.Tensor( - [0.99739634990692138671875] + [0.9867223501205444] ) xv, yv = make_grid_vectors(2, 2, 1) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 3ded7ee4..e8cbd5e3 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -20,10 +20,10 @@ def test_instance_cropper(minimal_instance): # test bounding box calculation gt = torch.Tensor( [ - [72.49704742431640625, 130.57481384277343750], - [172.49703979492187500, 130.57481384277343750], - [172.49703979492187500, 230.57481384277343750], - [72.49704742431640625, 230.57481384277343750], + [72.9970474243164, 131.07481384277344], + [171.99703979492188, 131.07481384277344], + [171.99703979492188, 230.07481384277344], + [72.9970474243164, 230.07481384277344] ] ) @@ -34,8 +34,8 @@ def test_instance_cropper(minimal_instance): # test samples gt = torch.Tensor( [ - [20.15515899658203125, 72.15116882324218750], - [79.84484100341796875, 27.84883117675781250], + [19.65515899658203, 71.65116882324219], + [79.34484100341797, 27.348831176757812] ] ) centered_instance = sample["instance"] From 381a49f4ecc0f1effeebff4e7933fe9a38b7b628 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 10 Aug 2023 19:28:44 -0700 Subject: [PATCH 16/55] added minimal_cms.pt fixture + unit tests --- tests/__init__.py | 2 +- tests/assets/minimal_cms.pt | Bin 0 -> 333465 bytes tests/data/test_augmentation.py | 3 --- tests/data/test_confmaps.py | 8 ++----- tests/data/test_instance_centroids.py | 6 ++--- tests/data/test_instance_cropping.py | 33 ++++++++++++++------------ tests/data/test_utils.py | 13 ++++++++++ tests/inference/test_peak_finding.py | 0 8 files changed, 37 insertions(+), 28 deletions(-) create mode 100644 tests/assets/minimal_cms.pt create mode 100644 tests/data/test_utils.py create mode 100644 tests/inference/test_peak_finding.py diff 
--git a/tests/__init__.py b/tests/__init__.py
index b26e62c1..746f4c0b 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1 +1 @@
-"""Unit tests for the sleap_nn package."""
\ No newline at end of file
+"""Unit tests for the sleap_nn package."""
diff --git a/tests/assets/minimal_cms.pt b/tests/assets/minimal_cms.pt
new file mode 100644
index 0000000000000000000000000000000000000000..434c57caf25d449d52787ed8aaf63f407049cb46
GIT binary patch
literal 333465
[... base85-encoded binary patch data for tests/assets/minimal_cms.pt omitted ...]

From: alckasoc
Date: Thu, 10 Aug 2023 22:08:23 -0700
Subject: [PATCH 17/55] added minimal_bboxes fixture; added unit tests for
 crop_bboxes & integral_regression

---
 tests/assets/minimal_bboxes.pt       | Bin 0 -> 1152 bytes
 tests/conftest.py                    |   1 +
 tests/fixtures/inference.py          |  22 +++++
 tests/inference/test_peak_finding.py | 141 +++++++++++++++++++++++++++
 4 files changed, 164 insertions(+)
 create mode 100644 tests/assets/minimal_bboxes.pt
 create mode 100644 tests/fixtures/inference.py

diff --git a/tests/assets/minimal_bboxes.pt b/tests/assets/minimal_bboxes.pt
new file mode 100644
index
0000000000000000000000000000000000000000..73ccce9cb2e9ed8341a1fd1c64b4aec9f4c460e0
GIT binary patch
literal 1152
[... base85-encoded binary patch data for tests/assets/minimal_bboxes.pt omitted ...]

From: alckasoc
Date: Thu, 10 Aug 2023 22:26:49 -0700
Subject: [PATCH 18/55] added find_global_peaks unit tests

---
 tests/inference/test_peak_finding.py | 99 +++++++++++++++++++++++++++-
 1 file changed, 96 insertions(+), 3 deletions(-)

diff --git a/tests/inference/test_peak_finding.py b/tests/inference/test_peak_finding.py
index f813a79b..4051a30a 100644
--- a/tests/inference/test_peak_finding.py
+++ b/tests/inference/test_peak_finding.py
@@ -58,7 +58,7 @@ def test_integral_regression(minimal_bboxes, minimal_cms):
             [-0.37117865681648254],
             [0.32524189352989197],
             [-0.18590612709522247],
-            [0.06249351054430008]
+            [0.06249351054430008],
         ]
     )
 
@@ -76,7 +76,7 @@ def test_integral_regression(minimal_bboxes, minimal_cms):
             [-0.28016677498817444],
             [0.32524189352989197],
             [-2.0254956325516105e-06],
-            [-0.37117743492126465]
+            [-0.37117743492126465],
         ]
     )
 
@@ -87,7 +87,54 @@ def test_integral_regression(minimal_bboxes, minimal_cms):
 
 
 def test_find_global_peaks_rough(minimal_cms):
-    pass
+    cms = torch.load(minimal_cms).unsqueeze(0)
+
+    gt_rough_peaks = torch.Tensor(
+        [
+            [
+                [27.0, 23.0],
+                [40.0, 40.0],
+                [49.0, 55.0],
+                [54.0, 63.0],
+                [56.0, 60.0],
+                [18.0, 32.0],
+                [29.0, 12.0],
+                [17.0, 44.0],
+                [44.0, 20.0],
+                [36.0, 70.0],
+                [0.0, 0.0],
+                [25.0, 30.0],
+                [34.0, 24.0],
+            ]
+        ]
+    )
+    gt_peak_vals = torch.Tensor(
+        [
+            [
+                0.9163541793823242,
+                0.9957404136657715,
+                0.929328203201294,
+                0.9020472168922424,
+                0.8870090246200562,
+                0.8547359108924866,
+                0.8420282602310181,
+                0.86271071434021,
+                0.863940954208374,
+                0.8226016163825989,
+                1.0,
+                0.9693551063537598,
+                0.8798434734344482,
+            ]
+        ]
+    )
+
+    rough_peaks, peak_vals = find_global_peaks_rough(cms, threshold=0.1)
+
+    assert rough_peaks.shape == (1, 13, 2)
+    assert peak_vals.shape == (1, 13)
+    assert rough_peaks.dtype == peak_vals.dtype == torch.float32
+    assert torch.equal(gt_rough_peaks, rough_peaks)
+    assert torch.equal(gt_peak_vals, peak_vals)
 
 
 def test_find_global_peaks(minimal_cms):
@@ -139,3 +186,49 @@ def test_find_global_peaks(minimal_cms):
     assert rough_peaks.dtype == peak_vals.dtype == torch.float32
     assert torch.equal(gt_rough_peaks, rough_peaks)
     assert torch.equal(gt_peak_vals, peak_vals)
+
+    gt_refined_peaks = torch.Tensor(
+        [
+            [
+                [27.2498, 22.8141],
+                [39.9390, 40.0320],
+                [48.7837, 54.8141],
+                [53.8752, 63.3142],
+                [56.1249, 60.3423],
+                [18.2802, 31.6910],
+                [29.0320, 12.4346],
+                [17.2178, 43.6591],
+                [44.3712, 19.8446],
+                [35.6288, 69.7198],
+                [0.3252, 0.3252],
+                [24.8141, 30.0000],
+                [34.0625, 23.6288],
+            ]
+        ]
+    )
+    gt_peak_vals = torch.Tensor(
+        [
+            [
+                0.9164,
+                0.9957,
+                0.9293,
+                0.9020,
+                0.8870,
+                0.8547,
+                0.8420,
+                0.8627,
+                0.8639,
+                0.8226,
+                1.0000,
+                0.9694,
+                0.8798,
+            ]
+        ]
+    )
+
+    refined_peaks, peak_vals = find_global_peaks(
+        cms, refinement="integral", threshold=0.2
+    )
+
+    torch.testing.assert_close(gt_refined_peaks, refined_peaks, atol=0.001, rtol=0.0)
+    torch.testing.assert_close(gt_peak_vals, peak_vals, atol=0.001, rtol=0.0)

From 77785121009a3f9d163bd69fb616cfdf61525646 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 16 Aug 2023 17:13:19 -0700
Subject: [PATCH 19/55] finished find_local_peaks_rough!

---
 sleap_nn/inference/peak_finding.py | 59 ++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py
index 48c96e2c..505b67a4 100644
--- a/sleap_nn/inference/peak_finding.py
+++ b/sleap_nn/inference/peak_finding.py
@@ -1,6 +1,7 @@
 """Peak finding for inference."""
 from typing import Optional, Tuple
 
+import kornia as K
 import numpy as np
 import torch
 from kornia.geometry.transform import crop_and_resize
@@ -196,3 +197,61 @@ def find_global_peaks(
         refined_peaks = refined_peaks.reshape(samples, channels, 2)
 
     return refined_peaks, peak_vals
+
+
+def find_local_peaks_rough(
+    cms: torch.Tensor, threshold: float = 0.2
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Find local maxima via non-maximum suppression.
+
+    Args:
+        cms: Tensor of shape (samples, channels, height, width).
+        threshold: Scalar float specifying the minimum confidence value for peaks. Peaks
+            with values below this threshold will not be returned.
+
+    Returns:
+        A tuple of (peak_points, peak_vals, peak_sample_inds, peak_channel_inds).
+
+        peak_points: float32 tensor of shape (n_peaks, 2), where the last axis
+        indicates peak locations in xy order.
+
+        peak_vals: float32 tensor of shape (n_peaks,) containing the values at the peak
+        points.
+
+        peak_sample_inds: int32 tensor of shape (n_peaks,) containing the indices of the
+        sample each peak belongs to.
+
+        peak_channel_inds: int32 tensor of shape (n_peaks,) containing the indices of
+        the channel each peak belongs to.
+    """
+    # Build custom local NMS kernel.
+    kernel = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.float32)
+
+    # Reshape to have singleton channels.
+    height = cms.size(2)
+    width = cms.size(3)
+    channels = cms.size(1)
+    flat_img = cms.reshape(-1, 1, height, width)
+
+    # Perform dilation filtering to find local maxima per channel and reshape back.
+    max_img = K.morphology.dilation(flat_img, kernel)
+    max_img = max_img.permute(1, 0, 2, 3)
+
+    # Filter for maxima and threshold.
+    argmax_and_thresh_img = (cms > max_img) & (cms > threshold)
+
+    # Convert to subscripts.
+    peak_subs = torch.stack(
+        torch.where(argmax_and_thresh_img.permute(0, 2, 3, 1)), axis=-1
+    )
+
+    # Get peak values.
+    peak_vals = cms[peak_subs[:, 0], peak_subs[:, 3], peak_subs[:, 1], peak_subs[:, 2]]
+
+    # Convert to points format.
+    peak_points = peak_subs[:, [2, 1]].to(torch.float32)
+
+    # Pull out indexing vectors.
+    peak_sample_inds = peak_subs[:, 0].to(torch.int32)
+    peak_channel_inds = peak_subs[:, 3].to(torch.int32)
+
+    return peak_points, peak_vals, peak_sample_inds, peak_channel_inds

From 9f7ac3fa23efa93d80e20770435691a66c6714b5 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 16 Aug 2023 22:46:39 -0700
Subject: [PATCH 20/55] finished find_local_peaks!
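find_local_peaks wraps find_local_peaks_rough and optionally refines the
grid-aligned peaks with integral regression. A minimal usage sketch (the tensor
shape and argument values here are illustrative, borrowed from the unit tests,
and are not part of this diff):

    import torch
    from sleap_nn.inference.peak_finding import find_local_peaks

    # One sample with 13 confidence map channels on an 80x80 grid.
    cms = torch.rand(1, 13, 80, 80)

    # Subpixel refinement via integral regression over 5x5 local patches.
    peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks(
        cms, threshold=0.2, refinement="integral", integral_patch_size=5
    )
    # peak_points: (n_peaks, 2) xy coordinates; the two index tensors map each
    # peak back to the sample and channel it came from.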
---
 sleap_nn/inference/peak_finding.py | 78 ++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py
index 505b67a4..418d569f 100644
--- a/sleap_nn/inference/peak_finding.py
+++ b/sleap_nn/inference/peak_finding.py
@@ -255,3 +255,81 @@ def find_local_peaks_rough(
     peak_channel_inds = peak_subs[:, 3].to(torch.int32)
 
     return peak_points, peak_vals, peak_sample_inds, peak_channel_inds
+
+
+def find_local_peaks(
+    cms: torch.Tensor,
+    threshold: float = 0.2,
+    refinement: Optional[str] = None,
+    integral_patch_size: int = 5,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Find local peaks with optional refinement.
+
+    Args:
+        cms: Confidence maps. Tensor of shape (samples, channels, height, width).
+        threshold: Minimum confidence threshold. Peaks with values below this will be
+            ignored.
+        refinement: If `None`, returns the grid-aligned peaks with no refinement. If
+            `"integral"`, peaks will be refined with integral regression.
+        integral_patch_size: Size of patches to crop around each rough peak as an
+            integer scalar.
+
+    Returns:
+        A tuple of (peak_points, peak_vals, peak_sample_inds, peak_channel_inds).
+
+        peak_points: float32 tensor of shape (n_peaks, 2), where the last axis
+        indicates peak locations in xy order.
+
+        peak_vals: float32 tensor of shape (n_peaks,) containing the values at the peak
+        points.
+
+        peak_sample_inds: int32 tensor of shape (n_peaks,) containing the indices of the
+        sample each peak belongs to.
+
+        peak_channel_inds: int32 tensor of shape (n_peaks,) containing the indices of
+        the channel each peak belongs to.
+    """
+    # Find grid aligned peaks.
+    (
+        rough_peaks,
+        peak_vals,
+        peak_sample_inds,
+        peak_channel_inds,
+    ) = find_local_peaks_rough(cms, threshold=threshold)
+
+    # Return early if no rough peaks found.
+    if rough_peaks.size(0) == 0 or refinement is None:
+        return rough_peaks, peak_vals, peak_sample_inds, peak_channel_inds
+
+    if refinement == "integral":
+        crop_size = integral_patch_size
+    else:
+        return rough_peaks, peak_vals, peak_sample_inds, peak_channel_inds
+
+    # Make bounding boxes for cropping around peaks.
+    bboxes = make_centered_bboxes(
+        rough_peaks, box_height=crop_size, box_width=crop_size
+    )
+
+    # Reshape to (samples * channels, 1, height, width).
+    samples = cms.size(0)
+    channels = cms.size(1)
+    cms = torch.reshape(
+        cms,
+        [samples * channels, 1, cms.size(2), cms.size(3)],
+    )
+    box_sample_inds = (peak_sample_inds * channels) + peak_channel_inds
+
+    # Crop patch around each grid-aligned peak.
+    cm_crops = crop_bboxes(cms, bboxes, sample_inds=box_sample_inds)
+
+    # Compute offsets via integral regression on a local patch.
+    if refinement == "integral":
+        gv = torch.arange(crop_size, dtype=torch.float32) - ((crop_size - 1) / 2)
+        dx_hat, dy_hat = integral_regression(cm_crops, xv=gv, yv=gv)
+        offsets = torch.cat([dx_hat, dy_hat], dim=1)
+
+    # Apply offsets.
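+    # The offsets from integral_regression are relative to each patch center,
+    # so adding them to the grid-aligned rough peaks yields subpixel coordinates.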
+ refined_peaks = rough_peaks + offsets + + return refined_peaks, peak_vals, peak_sample_inds, peak_channel_inds From b9869d6e1bb0b2acd615b00bcdcd87b6d9170907 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 17 Aug 2023 00:23:09 -0700 Subject: [PATCH 21/55] added unit tests for find_local_peaks and find_local_peaks_rough --- sleap_nn/inference/peak_finding.py | 3 +- tests/inference/test_peak_finding.py | 153 ++++++++++++++++++++------- 2 files changed, 116 insertions(+), 40 deletions(-) diff --git a/sleap_nn/inference/peak_finding.py b/sleap_nn/inference/peak_finding.py index 418d569f..7f1b0969 100644 --- a/sleap_nn/inference/peak_finding.py +++ b/sleap_nn/inference/peak_finding.py @@ -85,8 +85,7 @@ def integral_regression( def find_global_peaks_rough( cms: torch.Tensor, threshold: float = 0.1 ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Find the global maximum for each sample and channel. + """Find the global maximum for each sample and channel. Args: cms: Tensor of shape (samples, channels, height, width). diff --git a/tests/inference/test_peak_finding.py b/tests/inference/test_peak_finding.py index 4051a30a..1dbae3cc 100644 --- a/tests/inference/test_peak_finding.py +++ b/tests/inference/test_peak_finding.py @@ -5,6 +5,8 @@ integral_regression, find_global_peaks, find_global_peaks_rough, + find_local_peaks, + find_local_peaks_rough, ) @@ -44,46 +46,8 @@ def test_integral_regression(minimal_bboxes, minimal_cms): gv = torch.arange(crop_size, dtype=torch.float32) - ((crop_size - 1) / 2) dx_hat, dy_hat = integral_regression(cm_crops, xv=gv, yv=gv) - gt_dx_hat = torch.Tensor( - [ - [0.24976766109466553], - [-0.06099589914083481], - [-0.216335266828537], - [-0.12479443103075027], - [0.12494532763957977], - [0.28015944361686707], - [0.03200167417526245], - [0.21784470975399017], - [0.3711766004562378], - [-0.37117865681648254], - [0.32524189352989197], - [-0.18590612709522247], - [0.06249351054430008], - ] - ) - - gt_dy_hat = torch.Tensor( - [ - [-0.1858985275030136], - [0.031994160264730453], - [-0.18588940799236298], - [0.3141670227050781], - [0.3423368036746979], - [-0.3090454936027527], - [0.43461763858795166], - [-0.3408771753311157], - [-0.155443474650383], - [-0.28016677498817444], - [0.32524189352989197], - [-2.0254956325516105e-06], - [-0.37117743492126465], - ] - ) - assert dx_hat.shape == dy_hat.shape == (13, 1) assert dx_hat.dtype == dy_hat.dtype == torch.float32 - assert torch.equal(gt_dx_hat, dx_hat) - assert torch.equal(gt_dy_hat, dy_hat) def test_find_global_peaks_rough(minimal_cms): @@ -232,3 +196,116 @@ def test_find_global_peaks(minimal_cms): torch.testing.assert_close(gt_refined_peaks, refined_peaks, atol=0.001, rtol=0.0) torch.testing.assert_close(gt_peak_vals, peak_vals, atol=0.001, rtol=0.0) + + +def test_find_local_peaks_rough(minimal_cms): + cms = torch.load(minimal_cms).unsqueeze(0) # (1, 13, 80, 80) + + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks_rough(cms) + + gt_peak_points = torch.Tensor( + [ + [0.0, 0.0], + [29.0, 12.0], + [44.0, 20.0], + [27.0, 23.0], + [34.0, 24.0], + [25.0, 30.0], + [18.0, 32.0], + [40.0, 40.0], + [17.0, 44.0], + [49.0, 55.0], + [56.0, 60.0], + [54.0, 63.0], + [36.0, 70.0], + ] + ) + + gt_peak_vals = torch.Tensor( + [ + 1.0, + 0.8420282602310181, + 0.863940954208374, + 0.9163541793823242, + 0.8798434734344482, + 0.9693551063537598, + 0.8547359108924866, + 0.9957404136657715, + 0.86271071434021, + 0.929328203201294, + 0.8870090246200562, + 0.9020472168922424, + 0.8226016163825989, + ] + ) + 
+ gt_peak_sample_inds = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + gt_peak_channel_inds = torch.Tensor([10, 6, 8, 0, 12, 11, 5, 1, 7, 2, 4, 3, 9]) + + assert peak_points.shape == (13, 2) + assert peak_vals.shape == peak_sample_inds.shape == peak_channel_inds.shape == (13,) + assert torch.equal(gt_peak_vals, peak_vals) + assert torch.equal(gt_peak_points, peak_points) + assert torch.equal(gt_peak_sample_inds, peak_sample_inds) + assert torch.equal(gt_peak_channel_inds, peak_channel_inds) + + +def test_find_local_peaks(minimal_cms): + cms = torch.load(minimal_cms).unsqueeze(0) # (1, 13, 80, 80) + + (peak_points, peak_vals, peak_sample_inds, peak_channel_inds) = find_local_peaks( + cms, refinement="integral" + ) + + gt_peak_points = torch.Tensor( + [ + [0.32524189352989197, 0.32524189352989197], + [29.032001495361328, 12.43461799621582], + [44.371177673339844, 19.84455680847168], + [27.249767303466797, 22.814102172851562], + [34.06249237060547, 23.628822326660156], + [24.81409454345703, 29.999998092651367], + [18.28015899658203, 31.690954208374023], + [39.939002990722656, 40.0319938659668], + [17.217844009399414, 43.659122467041016], + [48.78366470336914, 54.814109802246094], + [56.12494659423828, 60.34233856201172], + [53.875205993652344, 63.31416702270508], + [35.628822326660156, 69.71983337402344], + ] + ) + + gt_peak_vals = torch.Tensor( + [ + 1.0, + 0.8420282602310181, + 0.863940954208374, + 0.9163541793823242, + 0.8798434734344482, + 0.9693551063537598, + 0.8547359108924866, + 0.9957404136657715, + 0.86271071434021, + 0.929328203201294, + 0.8870090246200562, + 0.9020472168922424, + 0.8226016163825989, + ] + ) + + gt_peak_sample_inds = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + gt_peak_channel_inds = torch.Tensor([10, 6, 8, 0, 12, 11, 5, 1, 7, 2, 4, 3, 9]) + + assert peak_points.shape == (13, 2) + assert peak_vals.shape == peak_sample_inds.shape == peak_channel_inds.shape == (13,) + assert torch.equal(gt_peak_vals, peak_vals) + assert torch.equal(gt_peak_points, peak_points) + assert torch.equal(gt_peak_sample_inds, peak_sample_inds) + assert torch.equal(gt_peak_channel_inds, peak_channel_inds) From bfd1caccb29c02241a7b15d8b755dbef61afd68c Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 17 Aug 2023 00:49:19 -0700 Subject: [PATCH 22/55] updated test cases --- tests/inference/test_peak_finding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/inference/test_peak_finding.py b/tests/inference/test_peak_finding.py index 1dbae3cc..a6ed836e 100644 --- a/tests/inference/test_peak_finding.py +++ b/tests/inference/test_peak_finding.py @@ -305,7 +305,7 @@ def test_find_local_peaks(minimal_cms): assert peak_points.shape == (13, 2) assert peak_vals.shape == peak_sample_inds.shape == peak_channel_inds.shape == (13,) - assert torch.equal(gt_peak_vals, peak_vals) - assert torch.equal(gt_peak_points, peak_points) + torch.testing.assert_close(gt_peak_points, peak_points, atol=0.001, rtol=0.0) + torch.testing.assert_close(gt_peak_vals, peak_vals, atol=0.001, rtol=0.0) assert torch.equal(gt_peak_sample_inds, peak_sample_inds) assert torch.equal(gt_peak_channel_inds, peak_channel_inds) From a8b3c311505c40addd6b291c20bf221e5b3bbb77 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 17 Aug 2023 01:00:15 -0700 Subject: [PATCH 23/55] added more test cases for find_local_peaks --- tests/inference/test_peak_finding.py | 70 ++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/inference/test_peak_finding.py 
b/tests/inference/test_peak_finding.py index a6ed836e..7a76a228 100644 --- a/tests/inference/test_peak_finding.py +++ b/tests/inference/test_peak_finding.py @@ -151,6 +151,14 @@ def test_find_global_peaks(minimal_cms): assert torch.equal(gt_rough_peaks, rough_peaks) assert torch.equal(gt_peak_vals, peak_vals) + rough_peaks, peak_vals = find_global_peaks(cms, refinement="invalid_input", threshold=0.2) + + assert rough_peaks.shape == (1, 13, 2) + assert peak_vals.shape == (1, 13) + assert rough_peaks.dtype == peak_vals.dtype == torch.float32 + assert torch.equal(gt_rough_peaks, rough_peaks) + assert torch.equal(gt_peak_vals, peak_vals) + gt_refined_peaks = torch.Tensor( [ [ @@ -259,6 +267,68 @@ def test_find_local_peaks_rough(minimal_cms): def test_find_local_peaks(minimal_cms): cms = torch.load(minimal_cms).unsqueeze(0) # (1, 13, 80, 80) + (peak_points, peak_vals, peak_sample_inds, peak_channel_inds) = find_local_peaks( + cms + ) + + gt_peak_points = torch.Tensor( + [ + [0.0, 0.0], + [29.0, 12.0], + [44.0, 20.0], + [27.0, 23.0], + [34.0, 24.0], + [25.0, 30.0], + [18.0, 32.0], + [40.0, 40.0], + [17.0, 44.0], + [49.0, 55.0], + [56.0, 60.0], + [54.0, 63.0], + [36.0, 70.0], + ] + ) + + gt_peak_vals = torch.Tensor( + [ + 1.0, + 0.8420282602310181, + 0.863940954208374, + 0.9163541793823242, + 0.8798434734344482, + 0.9693551063537598, + 0.8547359108924866, + 0.9957404136657715, + 0.86271071434021, + 0.929328203201294, + 0.8870090246200562, + 0.9020472168922424, + 0.8226016163825989, + ] + ) + + gt_peak_sample_inds = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + gt_peak_channel_inds = torch.Tensor([10, 6, 8, 0, 12, 11, 5, 1, 7, 2, 4, 3, 9]) + + assert peak_points.shape == (13, 2) + assert peak_vals.shape == peak_sample_inds.shape == peak_channel_inds.shape == (13,) + assert torch.equal(gt_peak_vals, peak_vals) + assert torch.equal(gt_peak_points, peak_points) + assert torch.equal(gt_peak_sample_inds, peak_sample_inds) + assert torch.equal(gt_peak_channel_inds, peak_channel_inds) + + (peak_points, peak_vals, peak_sample_inds, peak_channel_inds) = find_local_peaks( + cms, refinement="invalid_input" + ) + + assert peak_points.shape == (13, 2) + assert peak_vals.shape == peak_sample_inds.shape == peak_channel_inds.shape == (13,) + assert torch.equal(gt_peak_vals, peak_vals) + assert torch.equal(gt_peak_points, peak_points) + assert torch.equal(gt_peak_sample_inds, peak_sample_inds) + assert torch.equal(gt_peak_channel_inds, peak_channel_inds) + (peak_points, peak_vals, peak_sample_inds, peak_channel_inds) = find_local_peaks( cms, refinement="integral" ) From 125625db9ebad9ccc5457a1c0c6f6fc45601603c Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 17 Aug 2023 01:00:39 -0700 Subject: [PATCH 24/55] updated test cases --- tests/inference/test_peak_finding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/inference/test_peak_finding.py b/tests/inference/test_peak_finding.py index 7a76a228..1b257e73 100644 --- a/tests/inference/test_peak_finding.py +++ b/tests/inference/test_peak_finding.py @@ -151,7 +151,9 @@ def test_find_global_peaks(minimal_cms): assert torch.equal(gt_rough_peaks, rough_peaks) assert torch.equal(gt_peak_vals, peak_vals) - rough_peaks, peak_vals = find_global_peaks(cms, refinement="invalid_input", threshold=0.2) + rough_peaks, peak_vals = find_global_peaks( + cms, refinement="invalid_input", threshold=0.2 + ) assert rough_peaks.shape == (1, 13, 2) assert peak_vals.shape == (1, 13) From a25d92023b3882d27b5647ed456b311898ce8110 Mon Sep 17 00:00:00 
2001 From: alckasoc Date: Thu, 17 Aug 2023 10:01:48 -0700 Subject: [PATCH 25/55] added architectures folder --- sleap_nn/architectures/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 sleap_nn/architectures/__init__.py diff --git a/sleap_nn/architectures/__init__.py b/sleap_nn/architectures/__init__.py new file mode 100644 index 00000000..06dbf151 --- /dev/null +++ b/sleap_nn/architectures/__init__.py @@ -0,0 +1 @@ +"""Modules related to model architectures.""" From 3ba92b6f952230ed8bf022dcb7ec9418d788dddf Mon Sep 17 00:00:00 2001 From: alckasoc Date: Thu, 17 Aug 2023 13:41:20 -0700 Subject: [PATCH 26/55] added maxpool2d same padding, get_act_fn; added simpleconvblock, simpleupsamplingblock, encoder, decoder; added unet --- sleap_nn/architectures/common.py | 131 ++++++ sleap_nn/architectures/encoder_decoder.py | 503 ++++++++++++++++++++++ sleap_nn/architectures/unet.py | 101 +++++ 3 files changed, 735 insertions(+) create mode 100644 sleap_nn/architectures/common.py create mode 100644 sleap_nn/architectures/encoder_decoder.py create mode 100644 sleap_nn/architectures/unet.py diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py new file mode 100644 index 00000000..70368c14 --- /dev/null +++ b/sleap_nn/architectures/common.py @@ -0,0 +1,131 @@ +"""Common utilities for architecture and model building.""" +import torch +from torch import nn +from torch.nn import functional as F + + +class MaxPool2dWithSamePadding(nn.MaxPool2d): + """A MaxPool2d module with support for same padding. + + This class extends the torch.nn.MaxPool2d module and adds the ability + to perform 'same' padding, similar to 'same' padding in convolutional + layers. When 'same' padding is specified, the input tensor is padded + with zeros to ensure that the output spatial dimensions match the input + spatial dimensions as closely as possible. + + Args: + nn.MaxPool2d arguments: Arguments that are passed to the parent + torch.nn.MaxPool2d class. + + Attributes: + Inherits all attributes from torch.nn.MaxPool2d. + + Methods: + forward(x: torch.Tensor) -> torch.Tensor: + Forward pass through the MaxPool2dWithSamePadding module. + + Note: + The 'same' padding is applied only when self.padding is set to "same". + + Example: + # Create an instance of MaxPool2dWithSamePadding + maxpool_layer = MaxPool2dWithSamePadding(kernel_size=3, stride=2, padding="same") + + # Perform a forward pass on an input tensor + input_tensor = torch.rand(1, 3, 32, 32) # Example input tensor + output = maxpool_layer(input_tensor) # Apply the MaxPool2d operation with same padding. + """ + + def _calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + """Calculate the required padding to achieve 'same' padding. + + Args: + i (int): Input dimension (height or width). + k (int): Kernel size. + s (int): Stride. + d (int): Dilation. + + Returns: + int: The calculated padding value. + """ + return max( + (torch.ceil(torch.tensor(i / s)).item() - 1) * s + (k - 1) * d + 1 - i, 0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the MaxPool2dWithSamePadding module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the MaxPool2d operation. 
+ """ + if self.padding == "same": + ih, iw = x.size()[-2:] + + pad_h = self._calc_same_pad( + i=ih, + k=self.kernel_size + if type(self.kernel_size) is int + else self.kernel_size[0], + s=self.stride if type(self.stride) is int else self.stride[0], + d=self.dilation if type(self.dilation) is int else self.dilation[0], + ) + pad_w = self._calc_same_pad( + i=iw, + k=self.kernel_size + if type(self.kernel_size) is int + else self.kernel_size[1], + s=self.stride if type(self.stride) is int else self.stride[1], + d=self.dilation if type(self.dilation) is int else self.dilation[1], + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + self.padding = 0 + + return F.max_pool2d( + x, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +def get_act_fn(activation: str) -> nn.Module: + """Get an instance of an activation function module based on the provided name. + + This function returns an instance of a PyTorch activation function module + corresponding to the given activation function name. + + Args: + activation (str): Name of the activation function. Supported values are 'relu', 'sigmoid', and 'tanh'. + + Returns: + nn.Module: An instance of the requested activation function module. + + Raises: + KeyError: If the provided activation function name is not one of the supported values. + + Example: + # Get an instance of the ReLU activation function + relu_fn = get_act_fn('relu') + + # Apply the activation function to an input tensor + input_tensor = torch.randn(1, 64, 64) + output = relu_fn(input_tensor) + """ + activations = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()} + + if activation not in activations: + raise KeyError( + f"Unsupported activation function: {activation}. Supported activations are: {', '.join(activations.keys())}" + ) + + return activations[activation] diff --git a/sleap_nn/architectures/encoder_decoder.py b/sleap_nn/architectures/encoder_decoder.py new file mode 100644 index 00000000..30288251 --- /dev/null +++ b/sleap_nn/architectures/encoder_decoder.py @@ -0,0 +1,503 @@ +"""Generic encoder-decoder fully convolutional backbones. + +This module contains building blocks for creating encoder-decoder architectures of +general form. + +The encoder branch of the network forms the initial multi-scale feature extraction via +repeated blocks of convolutions and pooling steps. + +The decoder branch is then responsible for upsampling the low resolution feature maps +to achieve the target output stride. + +This pattern is generalizable and describes most fully convolutional architectures. For +example: + - simple convolutions with pooling form the structure in `LEAP CNN +`_; + - adding skip connections forms `U-Net `_; + - using residual blocks with skip connections forms the base module in `stacked + hourglass `_; + - using dense blocks with skip connections forms `FC-DenseNet +`_. + +This module implements blocks used in all of these variants on top of a generic base +classes. + +See the `EncoderDecoder` base class for requirements for creating new architectures. +""" + +from typing import List, Text, Tuple, Union + +import torch +from common import MaxPool2dWithSamePadding, get_act_fn +from torch import nn + + +class SimpleConvBlock(nn.Module): + """A simple convolutional block module. 
+ + This class defines a convolutional block that consists of convolutional layers, + optional pooling layers, batch normalization, and activation functions. + + The layers within the SimpleConvBlock are organized as follows: + + 1. Optional max pooling (with same padding) layer (before convolutional layers). + 2. Convolutional layers with specified number of filters, kernel size, and activation. + 3. Optional batch normalization layer after each convolutional layer (if batch_norm is True). + 4. Activation function after each convolutional layer (ReLU, Sigmoid, Tanh, etc.). + 5. Optional max pooling (with same padding) layer (after convolutional layers). + + Args: + in_channels: Number of input channels. + pool: Whether to include pooling layers. Default is True. + pooling_stride: Stride for pooling layers. Default is 2. + pool_before_convs: Whether to apply pooling before convolutional layers. Default is False. + num_convs: Number of convolutional layers. Default is 2. + filters: Number of filters for convolutional layers. Default is 32. + kernel_size: Size of the convolutional kernels. Default is 3. + use_bias: Whether to use bias in convolutional layers. Default is True. + batch_norm: Whether to apply batch normalization. Default is False. + activation: Activation function name. Default is "relu". + + Attributes: + Inherits all attributes from torch.nn.Module. + + Note: + The 'same' padding is applied using custom MaxPool2dWithSamePadding layers. + """ + + def __init__( + self, + in_channels: int, + pool: bool = True, + pooling_stride: int = 2, + pool_before_convs: bool = False, + num_convs: int = 2, + filters: int = 32, + kernel_size: int = 3, + use_bias: bool = True, + batch_norm: bool = False, + activation: Text = "relu", + ) -> None: + """Initialize the class.""" + super().__init__() + + self.in_channels = in_channels + self.pool = pool + self.pooling_stride = pooling_stride + self.pool_before_convs = pool_before_convs + self.num_convs = num_convs + self.filters = filters + self.kernel_size = kernel_size + self.use_bias = use_bias + self.batch_norm = batch_norm + self.activation = activation + + self.blocks = [] + if pool and pool_before_convs: + self.blocks.append( + MaxPool2dWithSamePadding( + kernel_size=2, stride=pooling_stride, padding="same" + ) + ) + + for i in range(num_convs): + self.blocks.append( + nn.Conv2d( + in_channels=in_channels if i == 0 else filters, + out_channels=filters, + kernel_size=kernel_size, + stride=1, + padding="same", + bias=use_bias, + ) + ) + + if batch_norm: + self.blocks.append(nn.BatchNorm2d(filters)) + + self.blocks.append(get_act_fn(activation)) + + if pool and not pool_before_convs: + self.blocks.append( + MaxPool2dWithSamePadding( + kernel_size=2, stride=pooling_stride, padding="same" + ) + ) + + self.blocks = nn.Sequential(*self.blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the SimpleConvBlock module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the convolutional block operations. + """ + return self.blocks(x) + + +class Encoder(nn.Module): + """Encoder module for a neural network architecture. + + This class defines the encoder part of a neural network architecture, + which consists of a stack of convolutional blocks for feature extraction. + + The Encoder consists of a stack of SimpleConvBlocks designed for feature extraction. + + Args: + in_channels: Number of input channels. Default is 3. + filters: Number of filters for the initial block. 
Default is 64. + down_blocks: Number of downsampling blocks. Default is 4. + filters_rate: Factor to increase the number of filters per block. Default is 2. + current_stride: Initial stride for pooling operations. Default is 2. + stem_blocks: Number of initial stem blocks. Default is 0. + convs_per_block: Number of convolutional layers per block. Default is 2. + kernel_size: Size of the convolutional kernels. Default is 3. + middle_block: Whether to include a middle block. Default is True. + block_contraction: Whether to contract the channels in the middle block. Default is False. + + Attributes: + Inherits all attributes from torch.nn.Module. + """ + + def __init__( + self, + in_channels: int = 3, + filters: int = 64, + down_blocks: int = 4, + filters_rate: Union[float, int] = 2, + current_stride: int = 2, + stem_blocks: int = 0, + convs_per_block: int = 2, + kernel_size: Union[int, Tuple[int, int]] = 3, + middle_block: bool = True, + block_contraction: bool = False, + ) -> None: + """Initialize the class.""" + super().__init__() + + self.in_channels = in_channels + self.filters = filters + self.down_blocks = down_blocks + self.filters_rate = filters_rate + self.current_stride = current_stride + self.stem_blocks = stem_blocks + self.convs_per_block = convs_per_block + self.kernel_size = kernel_size + self.middle_block = middle_block + self.block_contraction = block_contraction + + self.encoder_stack = nn.ModuleList([]) + for block in range(down_blocks): + prev_block_filters = -1 if block == 0 else block_filters + block_filters = int(filters * (filters_rate ** (block + stem_blocks))) + + self.encoder_stack.append( + SimpleConvBlock( + in_channels=in_channels if block == 0 else prev_block_filters, + pool=(block > 0), + pool_before_convs=True, + pooling_stride=2, + num_convs=convs_per_block, + filters=block_filters, + kernel_size=kernel_size, + use_bias=True, + batch_norm=False, + activation="relu", + ) + ) + after_block_filters = block_filters + + self.encoder_stack.append( + MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding="same") + ) + + # Create a middle block (like the CARE implementation). + if middle_block: + if convs_per_block > 1: + # First convs are one exponent higher than the last encoder block. + block_filters = int( + filters * (filters_rate ** (down_blocks + stem_blocks)) + ) + self.encoder_stack.append( + SimpleConvBlock( + in_channels=after_block_filters, + pool=False, + pool_before_convs=False, + pooling_stride=2, + num_convs=convs_per_block - 1, + filters=block_filters, + kernel_size=kernel_size, + use_bias=True, + batch_norm=False, + activation="relu", + ) + ) + + if block_contraction: + # Contract the channels with an exponent lower than the last encoder block. + block_filters = int( + filters * (filters_rate ** (down_blocks + stem_blocks - 1)) + ) + else: + # Keep the block output filters the same. 
+ block_filters = int( + filters * (filters_rate ** (down_blocks + stem_blocks)) + ) + + self.encoder_stack.append( + SimpleConvBlock( + in_channels=block_filters, + pool=False, + pool_before_convs=False, + pooling_stride=2, + num_convs=1, + filters=block_filters, + kernel_size=kernel_size, + use_bias=True, + batch_norm=False, + activation="relu", + ) + ) + + self.intermediate_features = {} + for i, block in enumerate(self.encoder_stack): + if isinstance(block, SimpleConvBlock) and block.pool: + current_stride *= block.pooling_stride + + if current_stride not in self.intermediate_features.values(): + self.intermediate_features[i] = current_stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the Encoder module. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the encoder operations. + list: List of intermediate feature tensors from different levels of the encoder. + """ + features = [] + for i in range(len(self.encoder_stack)): + x = self.encoder_stack[i](x) + + if i in self.intermediate_features.keys(): + features.append(x) + + return x, features[1:][::-1] + + +class SimpleUpsamplingBlock(nn.Module): + """A simple upsampling and refining block module. + + This class defines an upsampling and refining block that consists of upsampling layers, + convolutional layers for refinement, batch normalization, and activation functions. + + The block includes: + 1. Upsampling layers with adjustable stride and interpolation method. + 2. Refinement convolutional layers with customizable parameters. + 3. BatchNormalization layers (if specified; can be before or after activation function). + 4. Activation functions (default is ReLU) applied before or after BatchNormalization. + + Args: + x_in_shape: Number of input channels for the feature map. + current_stride: Current stride value to adjust during upsampling. + upsampling_stride: Stride for upsampling. Default is 2. + interp_method: Interpolation method for upsampling. Default is "bilinear". + refine_convs: Number of convolutional layers for refinement. Default is 2. + refine_convs_filters: Number of filters for refinement convolutional layers. Default is 64. + refine_convs_kernel_size: Size of the refinement convolutional kernels. Default is 3. + refine_convs_use_bias: Whether to use bias in refinement convolutional layers. Default is True. + refine_convs_batch_norm: Whether to apply batch normalization. Default is True. + refine_convs_batch_norm_before_activation: Whether to apply batch normalization before activation. + refine_convs_activation: Activation function name. Default is "relu". + + Attributes: + Inherits all attributes from torch.nn.Module. 
+ """ + + def __init__( + self, + x_in_shape: int, + current_stride: int, + upsampling_stride: int = 2, + interp_method: Text = "bilinear", + refine_convs: int = 2, + refine_convs_filters: int = 64, + refine_convs_kernel_size: int = 3, + refine_convs_use_bias: bool = True, + refine_convs_batch_norm: bool = True, + refine_convs_batch_norm_before_activation: bool = True, + refine_convs_activation: Text = "relu", + ) -> None: + """Initialize the class.""" + super().__init__() + + self.x_in_shape = x_in_shape + self.current_stride = current_stride + self.upsampling_stride = upsampling_stride + self.interp_method = interp_method + self.refine_convs = refine_convs + self.refine_convs_filters = refine_convs_filters + self.refine_convs_kernel_size = refine_convs_kernel_size + self.refine_convs_use_bias = refine_convs_use_bias + self.refine_convs_batch_norm = refine_convs_batch_norm + self.refine_convs_batch_norm_before_activation = ( + refine_convs_batch_norm_before_activation + ) + self.refine_convs_activation = refine_convs_activation + + self.blocks = nn.ModuleList([]) + if current_stride is not None: + # Append the strides to the block prefix. + new_stride = current_stride // upsampling_stride + + # Upsample via interpolation. + self.blocks.append( + nn.Upsample( + scale_factor=upsampling_stride, + mode=interp_method, + ) + ) + + # Add further convolutions to refine after upsampling and/or skip. + for i in range(refine_convs): + filters = refine_convs_filters + self.blocks.append( + nn.Conv2d( + in_channels=x_in_shape if i == 0 else filters, + out_channels=filters, + kernel_size=refine_convs_kernel_size, + stride=1, + padding="same", + bias=refine_convs_use_bias, + ) + ) + + if refine_convs_batch_norm and refine_convs_batch_norm_before_activation: + self.blocks.append(nn.BatchNorm2d(num_features=filters)) + + self.blocks.append(get_act_fn(refine_convs_activation)) + + if ( + refine_convs_batch_norm + and not refine_convs_batch_norm_before_activation + ): + self.blocks.append(nn.BatchNorm2d(num_features=filters)) + + def forward(self, x: torch.Tensor, feature: torch.Tensor) -> torch.Tensor: + """Forward pass through the SimpleUpsamplingBlock module. + + Args: + x: Input tensor. + feature: Feature tensor to be concatenated with the upsampled tensor. + + Returns: + torch.Tensor: Output tensor after applying the upsampling and refining operations. + """ + for idx, b in enumerate(self.blocks): + if idx == 1: # Right after upsampling or convtranspose2d. + x = torch.concat((x, feature), dim=1) + x = b(x) + + return x + + +class Decoder(nn.Module): + """Decoder module for the UNet architecture. + + This class defines the decoder part of the UNet, + which consists of a stack of upsampling and refining blocks for feature reconstruction. + + Args: + x_in_shape: Number of input channels for the decoder's input. + current_stride: Current stride value to adjust during upsampling. + filters: Number of filters for the initial block. Default is 64. + up_blocks: Number of upsampling blocks. Default is 4. + down_blocks: Number of downsampling blocks. Default is 3. + filters_rate: Factor to adjust the number of filters per block. Default is 2. + stem_blocks: Number of initial stem blocks. Default is 0. + convs_per_block: Number of convolutional layers per block. Default is 2. + kernel_size: Size of the convolutional kernels. Default is 3. + block_contraction: Whether to contract the channels in the upsampling blocks. Default is False. + + Attributes: + Inherits all attributes from torch.nn.Module. 
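Aside (a minimal usage sketch, not part of the patch): the upsampling block above doubles the spatial size of `x`, concatenates the skip `feature` along channels, then refines with convolutions. The shapes below are assumptions chosen so the skip matches the upsampled map; `x_in_shape` must equal the concatenated channel count.

import torch

block = SimpleUpsamplingBlock(x_in_shape=128 + 64, current_stride=8)
_ = block.eval()

x = torch.rand(1, 128, 24, 24)    # low-resolution feature map
skip = torch.rand(1, 64, 48, 48)  # encoder skip connection
y = block(x, feature=skip)        # upsample -> concat -> refine convs
print(y.shape)                    # torch.Size([1, 64, 48, 48]) with the default 64 refine filters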
+ """ + + def __init__( + self, + x_in_shape: int, + current_stride: int, + filters: int = 64, + up_blocks: int = 4, + down_blocks: int = 3, + filters_rate: int = 2, + stem_blocks: int = 0, + convs_per_block: int = 2, + kernel_size: int = 3, + block_contraction: bool = False, + ) -> None: + """Initialize the class.""" + super().__init__() + + self.x_in_shape = x_in_shape + self.current_stride = current_stride + self.filters = filters + self.up_blocks = up_blocks + self.down_blocks = down_blocks + self.filters_rate = filters_rate + self.stem_blocks = stem_blocks + self.convs_per_block = convs_per_block + self.kernel_size = kernel_size + self.block_contraction = block_contraction + + self.decoder_stack = nn.ModuleList([]) + for block in range(up_blocks): + prev_block_filters_in = -1 if block == 0 else block_filters_in + block_filters_in = int( + filters * (filters_rate ** (down_blocks + stem_blocks - 1 - block)) + ) + if block_contraction: + block_filters_out = int( + filters * (filters_rate ** (down_blocks + stem_blocks - 2 - block)) + ) + else: + block_filters_out = block_filters_in + + next_stride = current_stride // 2 + + self.decoder_stack.append( + SimpleUpsamplingBlock( + x_in_shape=(x_in_shape + block_filters_in) + if block == 0 + else (prev_block_filters_in + block_filters_in), + current_stride=current_stride, + upsampling_stride=2, + interp_method="bilinear", + refine_convs=self.convs_per_block, + refine_convs_filters=block_filters_out, + refine_convs_kernel_size=self.kernel_size, + refine_convs_batch_norm=False, + ) + ) + + current_stride = next_stride + + def forward(self, x: torch.Tensor, features: List[torch.Tensor]) -> torch.Tensor: + """Forward pass through the Decoder module. + + Args: + x: Input tensor for the decoder. + features: List of feature tensors from different encoder levels. + + Returns: + torch.Tensor: Output tensor after applying the decoder operations. + """ + for i in range(len(self.decoder_stack)): + x = self.decoder_stack[i](x, features[i]) + + return x diff --git a/sleap_nn/architectures/unet.py b/sleap_nn/architectures/unet.py new file mode 100644 index 00000000..5cfce65f --- /dev/null +++ b/sleap_nn/architectures/unet.py @@ -0,0 +1,101 @@ +"""This module provides a generalized implementation of UNet. + +See the `UNet` class docstring for more information. +""" + +import math +from typing import List, Optional, Text, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from sleap_nn.architectures.encoder_decoder import Decoder, Encoder + + +class UNet(nn.Module): + """U-Net architecture for pose estimation. + + This class defines the U-Net architecture for pose estimation, combining an + encoder and a decoder. The encoder extracts features from the input, while the + decoder generates confidence maps based on the features. + + Args: + in_channels: Number of input channels. Default is 1. + kernel_size: Size of the convolutional kernels. Default is 3. + filters: Number of filters for the initial block. Default is 32. + filters_rate: Factor to adjust the number of filters per block. Default is 1.5. + stem_blocks: Number of initial stem blocks. Default is 0. + down_blocks: Number of downsampling blocks. Default is 4. + up_blocks: Number of upsampling blocks in the decoder. Default is 3. + convs_per_block: Number of convolutional layers per block. Default is 2. + middle_block: Whether to include a middle block in the encoder. Default is True. 
+ block_contraction: Whether to contract the channels in the decoder blocks. Default is False. + + Attributes: + Inherits all attributes from torch.nn.Module. + """ + + def __init__( + self, + in_channels: int = 1, + kernel_size: int = 3, + filters: int = 32, + filters_rate: int = 1.5, + stem_blocks: int = 0, + down_blocks: int = 4, + up_blocks: int = 3, + convs_per_block: int = 2, + middle_block: bool = True, + block_contraction: bool = False, + ) -> None: + """Initialize the class.""" + super().__init__() + + self.enc = Encoder( + in_channels=in_channels, + filters=filters, + down_blocks=down_blocks, + filters_rate=filters_rate, + stem_blocks=stem_blocks, + convs_per_block=convs_per_block, + kernel_size=kernel_size, + middle_block=middle_block, + block_contraction=block_contraction, + ) + + current_stride = int( + np.prod( + [ + block.pooling_stride + for block in self.enc.encoder_stack + if hasattr(block, "pool") and block.pool + ] + + [1] + ) + ) + + x_in_shape = int(filters * (filters_rate ** (down_blocks + stem_blocks))) + + self.dec = Decoder( + x_in_shape=x_in_shape, + current_stride=current_stride, + filters=filters, + up_blocks=up_blocks, + down_blocks=down_blocks, + filters_rate=filters_rate, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the U-Net architecture. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the U-Net operations. + """ + x, features = self.enc(x) + x = self.dec(x, features) + return x From f9558f273643455adec534e7e9c90c37197e3886 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 15:25:40 -0700 Subject: [PATCH 27/55] added test_unet_reference --- sleap_nn/architectures/common.py | 22 ++++++ sleap_nn/architectures/encoder_decoder.py | 4 +- tests/architectures/test_unet.py | 95 +++++++++++++++++++++++ tests/architectures/tmp.ipynb | 76 ++++++++++++++++++ 4 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 tests/architectures/test_unet.py create mode 100644 tests/architectures/tmp.ipynb diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 70368c14..3017d8df 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -129,3 +129,25 @@ def get_act_fn(activation: str) -> nn.Module: ) return activations[activation] + + +def get_children_layers(model: torch.nn.Module): + """Recursively retrieves a flattened list of all children modules and submodules within the given model. + + Args: + model: The PyTorch model to extract children from. + + Returns: + list of nn.Module: A flattened list containing all children modules and submodules. 
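Aside (illustrative, not part of the patch): the cumulative-stride computation in `UNet.__init__` above multiplies the pooling strides of the pooled encoder blocks. With four pooled blocks of stride 2 each, the bottleneck is 16x smaller than the input, which is why a 192x192 input yields 12x12 maps in the tests added below.

import numpy as np

pooling_strides = [2, 2, 2, 2]  # one per pooled SimpleConvBlock
current_stride = int(np.prod(pooling_strides + [1]))
print(current_stride)         # 16
print(192 // current_stride)  # 12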
+ """ + children = list(model.children()) + flattened_children = [] + if children == []: + return model + else: + for child in children: + try: + flattened_children.extend(get_children_layers(child)) + except TypeError: + flattened_children.append(get_children_layers(child)) + return flattened_children \ No newline at end of file diff --git a/sleap_nn/architectures/encoder_decoder.py b/sleap_nn/architectures/encoder_decoder.py index 30288251..ce2cdb8f 100644 --- a/sleap_nn/architectures/encoder_decoder.py +++ b/sleap_nn/architectures/encoder_decoder.py @@ -28,7 +28,7 @@ from typing import List, Text, Tuple, Union import torch -from common import MaxPool2dWithSamePadding, get_act_fn +from sleap_nn.architectures.common import MaxPool2dWithSamePadding, get_act_fn from torch import nn @@ -287,7 +287,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if i in self.intermediate_features.keys(): features.append(x) - return x, features[1:][::-1] + return x, features[::-1] class SimpleUpsamplingBlock(nn.Module): diff --git a/tests/architectures/test_unet.py b/tests/architectures/test_unet.py new file mode 100644 index 00000000..8460e85d --- /dev/null +++ b/tests/architectures/test_unet.py @@ -0,0 +1,95 @@ +import torch +from torch import nn + +from sleap_nn.architectures.common import get_children_layers +from sleap_nn.architectures.encoder_decoder import Encoder +from sleap_nn.architectures.unet import UNet + + +def test_unet_reference(): + device = "cuda" if torch.cuda.is_available() else "cpu" + + in_channels = 1 + filters = 64 + filters_rate = 2 + kernel_size = 3 + down_blocks = 4 + stem_blocks = 0 + up_blocks = 4 + convs_per_block = 2 + middle_block = True + block_contraction = False + + unet = UNet( + in_channels=in_channels, + filters=filters, + filters_rate=filters_rate, + down_blocks=down_blocks, + stem_blocks=stem_blocks, + up_blocks=up_blocks + ) + + in_channels = int( + filters + * ( + filters_rate + ** (down_blocks + stem_blocks - 1 - up_blocks + 1) + ) + ) + model = nn.Sequential(*[ + unet, + nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same") + ]) + + # Test number of layers. + flattened_layers = get_children_layers(model) + assert len(flattened_layers) == 45 + + # Test number of trainable weights. + trainable_weights_count = sum([1 if p.requires_grad else 0 for p in model.parameters()]) + assert trainable_weights_count == 38 + + # Test trainable parameter count. + pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert pytorch_trainable_params == 31378573 + + # Test total parameter count. + pytorch_total_params = sum(p.numel() for p in model.parameters()) + assert pytorch_total_params == 31378573 + + # Test final output shape. + model = model.to(device) + _ = model.eval() + + x = torch.rand(1, 1, 192, 192).to(device) + with torch.no_grad(): + y = model(x) + assert y.shape == (1, 13, 192, 192) + + # Test number of intermediate features outputted from encoder. 
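Aside (hypothetical toy model, not part of the patch): `get_children_layers` appends leaf modules and flattens nested containers, so a nested `Sequential` comes back as a flat list of leaves.

import torch.nn as nn

tiny = nn.Sequential(
    nn.Conv2d(1, 8, 3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)),
)
layers = get_children_layers(tiny)
print([type(m).__name__ for m in layers])  # ['Conv2d', 'ReLU', 'Conv2d']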
+ enc = Encoder( + in_channels = 1, + filters = filters, + down_blocks = down_blocks, + filters_rate = filters_rate, + current_stride = 2, + stem_blocks = stem_blocks, + convs_per_block = convs_per_block, + kernel_size = kernel_size, + middle_block = middle_block, + block_contraction = block_contraction, + ) + + enc = enc.to(device) + _ = enc.eval() + + x = torch.rand(1, 1, 192, 192).to(device) + with torch.no_grad(): + y, features = enc(x) + + assert y.shape == (1, 1024, 12, 12) + assert len(features) == 4 + assert features[0].shape == (1, 512, 24, 24) + assert features[1].shape == (1, 256, 48, 48) + assert features[2].shape == (1, 128, 96, 96) + assert features[3].shape == (1, 64, 192, 192) \ No newline at end of file diff --git a/tests/architectures/tmp.ipynb b/tests/architectures/tmp.ipynb new file mode 100644 index 00000000..706b2bf8 --- /dev/null +++ b/tests/architectures/tmp.ipynb @@ -0,0 +1,76 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from sleap_nn.architectures.unet import UNet\n", + "import torch\n", + "from torch import nn\n", + "\n", + "in_channels = 1\n", + "filters = 64\n", + "filters_rate = 2\n", + "kernel_size = 3\n", + "down_blocks = 4\n", + "stem_blocks = 0\n", + "up_blocks = 4\n", + "convs_per_block = 2\n", + "middle_block = True\n", + "block_contraction = False\n", + "\n", + "unet = UNet(\n", + " in_channels=in_channels,\n", + " filters=filters, \n", + " filters_rate=filters_rate, \n", + " down_blocks=down_blocks, \n", + " stem_blocks=stem_blocks, \n", + " up_blocks=up_blocks\n", + ")\n", + "\n", + "in_channels = int(\n", + " filters\n", + " * (\n", + " filters_rate\n", + " ** (down_blocks + stem_blocks - 1 - up_blocks + 1)\n", + " )\n", + ")\n", + "model = nn.Sequential(*[\n", + " unet,\n", + " nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding=\"same\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 28d57ca130a6ace0b6314e683a0b530bcf24741c Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 15:26:56 -0700 Subject: [PATCH 28/55] black formatted common.py & test_unet.py --- sleap_nn/architectures/common.py | 4 +- tests/architectures/test_unet.py | 174 ++++++++++++++++--------------- 2 files changed, 91 insertions(+), 87 deletions(-) diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 3017d8df..9065c3cd 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -145,9 +145,9 @@ def get_children_layers(model: torch.nn.Module): if children == []: return model else: - for child in children: + for child in children: try: flattened_children.extend(get_children_layers(child)) except TypeError: flattened_children.append(get_children_layers(child)) - return flattened_children \ No newline at end of file + return flattened_children diff --git a/tests/architectures/test_unet.py b/tests/architectures/test_unet.py index 8460e85d..2f61bde8 100644 --- a/tests/architectures/test_unet.py +++ 
b/tests/architectures/test_unet.py @@ -7,89 +7,93 @@ def test_unet_reference(): - device = "cuda" if torch.cuda.is_available() else "cpu" - - in_channels = 1 - filters = 64 - filters_rate = 2 - kernel_size = 3 - down_blocks = 4 - stem_blocks = 0 - up_blocks = 4 - convs_per_block = 2 - middle_block = True - block_contraction = False - - unet = UNet( - in_channels=in_channels, - filters=filters, - filters_rate=filters_rate, - down_blocks=down_blocks, - stem_blocks=stem_blocks, - up_blocks=up_blocks - ) - - in_channels = int( - filters - * ( - filters_rate - ** (down_blocks + stem_blocks - 1 - up_blocks + 1) - ) - ) - model = nn.Sequential(*[ + device = "cuda" if torch.cuda.is_available() else "cpu" + + in_channels = 1 + filters = 64 + filters_rate = 2 + kernel_size = 3 + down_blocks = 4 + stem_blocks = 0 + up_blocks = 4 + convs_per_block = 2 + middle_block = True + block_contraction = False + + unet = UNet( + in_channels=in_channels, + filters=filters, + filters_rate=filters_rate, + down_blocks=down_blocks, + stem_blocks=stem_blocks, + up_blocks=up_blocks, + ) + + in_channels = int( + filters * (filters_rate ** (down_blocks + stem_blocks - 1 - up_blocks + 1)) + ) + model = nn.Sequential( + *[ unet, - nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same") - ]) - - # Test number of layers. - flattened_layers = get_children_layers(model) - assert len(flattened_layers) == 45 - - # Test number of trainable weights. - trainable_weights_count = sum([1 if p.requires_grad else 0 for p in model.parameters()]) - assert trainable_weights_count == 38 - - # Test trainable parameter count. - pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - assert pytorch_trainable_params == 31378573 - - # Test total parameter count. - pytorch_total_params = sum(p.numel() for p in model.parameters()) - assert pytorch_total_params == 31378573 - - # Test final output shape. - model = model.to(device) - _ = model.eval() - - x = torch.rand(1, 1, 192, 192).to(device) - with torch.no_grad(): - y = model(x) - assert y.shape == (1, 13, 192, 192) - - # Test number of intermediate features outputted from encoder. - enc = Encoder( - in_channels = 1, - filters = filters, - down_blocks = down_blocks, - filters_rate = filters_rate, - current_stride = 2, - stem_blocks = stem_blocks, - convs_per_block = convs_per_block, - kernel_size = kernel_size, - middle_block = middle_block, - block_contraction = block_contraction, - ) - - enc = enc.to(device) - _ = enc.eval() - - x = torch.rand(1, 1, 192, 192).to(device) - with torch.no_grad(): - y, features = enc(x) - - assert y.shape == (1, 1024, 12, 12) - assert len(features) == 4 - assert features[0].shape == (1, 512, 24, 24) - assert features[1].shape == (1, 256, 48, 48) - assert features[2].shape == (1, 128, 96, 96) - assert features[3].shape == (1, 64, 192, 192) \ No newline at end of file + nn.Conv2d( + in_channels=in_channels, out_channels=13, kernel_size=1, padding="same" + ), + ] + ) + + # Test number of layers. + flattened_layers = get_children_layers(model) + assert len(flattened_layers) == 45 + + # Test number of trainable weights. + trainable_weights_count = sum( + [1 if p.requires_grad else 0 for p in model.parameters()] + ) + assert trainable_weights_count == 38 + + # Test trainable parameter count. + pytorch_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + assert pytorch_trainable_params == 31378573 + + # Test total parameter count. 
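Aside (illustrative, not part of the patch): the counting idioms in this test distinguish the number of trainable tensors from the number of trainable scalars. On a toy module:

import torch.nn as nn

m = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
n_tensors = sum(1 for p in m.parameters() if p.requires_grad)  # weight + bias
n_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(n_tensors, n_params)  # 2 80  (8*1*3*3 weights + 8 biases)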
+ pytorch_total_params = sum(p.numel() for p in model.parameters()) + assert pytorch_total_params == 31378573 + + # Test final output shape. + model = model.to(device) + _ = model.eval() + + x = torch.rand(1, 1, 192, 192).to(device) + with torch.no_grad(): + y = model(x) + assert y.shape == (1, 13, 192, 192) + + # Test number of intermediate features outputted from encoder. + enc = Encoder( + in_channels=1, + filters=filters, + down_blocks=down_blocks, + filters_rate=filters_rate, + current_stride=2, + stem_blocks=stem_blocks, + convs_per_block=convs_per_block, + kernel_size=kernel_size, + middle_block=middle_block, + block_contraction=block_contraction, + ) + + enc = enc.to(device) + _ = enc.eval() + + x = torch.rand(1, 1, 192, 192).to(device) + with torch.no_grad(): + y, features = enc(x) + + assert y.shape == (1, 1024, 12, 12) + assert len(features) == 4 + assert features[0].shape == (1, 512, 24, 24) + assert features[1].shape == (1, 256, 48, 48) + assert features[2].shape == (1, 128, 96, 96) + assert features[3].shape == (1, 64, 192, 192) From 87cd034747545e218b7e7ed9daa9932547ad552d Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 16:07:48 -0700 Subject: [PATCH 29/55] deleted tmp nb --- tests/architectures/tmp.ipynb | 76 ----------------------------------- 1 file changed, 76 deletions(-) delete mode 100644 tests/architectures/tmp.ipynb diff --git a/tests/architectures/tmp.ipynb b/tests/architectures/tmp.ipynb deleted file mode 100644 index 706b2bf8..00000000 --- a/tests/architectures/tmp.ipynb +++ /dev/null @@ -1,76 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from sleap_nn.architectures.unet import UNet\n", - "import torch\n", - "from torch import nn\n", - "\n", - "in_channels = 1\n", - "filters = 64\n", - "filters_rate = 2\n", - "kernel_size = 3\n", - "down_blocks = 4\n", - "stem_blocks = 0\n", - "up_blocks = 4\n", - "convs_per_block = 2\n", - "middle_block = True\n", - "block_contraction = False\n", - "\n", - "unet = UNet(\n", - " in_channels=in_channels,\n", - " filters=filters, \n", - " filters_rate=filters_rate, \n", - " down_blocks=down_blocks, \n", - " stem_blocks=stem_blocks, \n", - " up_blocks=up_blocks\n", - ")\n", - "\n", - "in_channels = int(\n", - " filters\n", - " * (\n", - " filters_rate\n", - " ** (down_blocks + stem_blocks - 1 - up_blocks + 1)\n", - " )\n", - ")\n", - "model = nn.Sequential(*[\n", - " unet,\n", - " nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding=\"same\")\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 7004869284a59216b341a17a86a2bee21e9de3af Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 17:06:12 -0700 Subject: [PATCH 30/55] _calc_same_pad returns int --- sleap_nn/architectures/common.py | 6 +++--- tests/architectures/test_common.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 tests/architectures/test_common.py diff --git 
a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 9065c3cd..0833f174 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -48,9 +48,9 @@ def _calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: Returns: int: The calculated padding value. """ - return max( + return int(max( (torch.ceil(torch.tensor(i / s)).item() - 1) * s + (k - 1) * d + 1 - i, 0 - ) + )) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the MaxPool2dWithSamePadding module. @@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if pad_h > 0 or pad_w > 0: x = F.pad( - x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) ) self.padding = 0 diff --git a/tests/architectures/test_common.py b/tests/architectures/test_common.py new file mode 100644 index 00000000..2cd21e8a --- /dev/null +++ b/tests/architectures/test_common.py @@ -0,0 +1,14 @@ +import pytest +import torch +from sleap_nn.architectures.common import get_act_fn, MaxPool2dWithSamePadding + +def test_maxpool2d_with_same_padding(): + pooling = MaxPool2dWithSamePadding(kernel_size=3, stride=2, dilation=2, padding="same") + + x = torch.rand(1, 10, 100, 100) + z = pooling(x) + assert z.shape == (1, 10, 48, 48) + +def test_get_act_fn(): + with pytest.raises(KeyError): + get_act_fn("invalid_input") \ No newline at end of file From 680778d7a581cc6e07a26fb345c951daf7b662d1 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 17:07:55 -0700 Subject: [PATCH 31/55] fixed test case --- sleap_nn/architectures/common.py | 9 ++++++--- tests/architectures/test_common.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 0833f174..1865e471 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -48,9 +48,12 @@ def _calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: Returns: int: The calculated padding value. """ - return int(max( - (torch.ceil(torch.tensor(i / s)).item() - 1) * s + (k - 1) * d + 1 - i, 0 - )) + return int( + max( + (torch.ceil(torch.tensor(i / s)).item() - 1) * s + (k - 1) * d + 1 - i, + 0, + ) + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the MaxPool2dWithSamePadding module. 
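Aside (a worked example, not part of the patch): plugging the dimensions from `test_maxpool2d_with_same_padding` into the formula just fixed shows why the expected output size is 50, not 48; the corrected assertion lands in the next file's hunk.

import math

i, k, s, d = 100, 3, 2, 2  # input size, kernel, stride, dilation
pad = max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
out = (i + pad - d * (k - 1) - 1) // s + 1
print(pad, out)  # 3 50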
diff --git a/tests/architectures/test_common.py b/tests/architectures/test_common.py index 2cd21e8a..ddbbd8f7 100644 --- a/tests/architectures/test_common.py +++ b/tests/architectures/test_common.py @@ -1,14 +1,19 @@ import pytest import torch -from sleap_nn.architectures.common import get_act_fn, MaxPool2dWithSamePadding + +from sleap_nn.architectures.common import MaxPool2dWithSamePadding, get_act_fn + def test_maxpool2d_with_same_padding(): - pooling = MaxPool2dWithSamePadding(kernel_size=3, stride=2, dilation=2, padding="same") + pooling = MaxPool2dWithSamePadding( + kernel_size=3, stride=2, dilation=2, padding="same" + ) x = torch.rand(1, 10, 100, 100) z = pooling(x) - assert z.shape == (1, 10, 48, 48) + assert z.shape == (1, 10, 50, 50) + def test_get_act_fn(): with pytest.raises(KeyError): - get_act_fn("invalid_input") \ No newline at end of file + get_act_fn("invalid_input") From 7cd75dcf61a13296aa5ed0e2af94f809e20c3149 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 17:16:52 -0700 Subject: [PATCH 32/55] added simpleconvblock tests --- tests/architectures/test_encoder_decoder.py | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/architectures/test_encoder_decoder.py diff --git a/tests/architectures/test_encoder_decoder.py b/tests/architectures/test_encoder_decoder.py new file mode 100644 index 00000000..fb080bb5 --- /dev/null +++ b/tests/architectures/test_encoder_decoder.py @@ -0,0 +1,27 @@ +import torch + +import sleap_nn +from sleap_nn.architectures.encoder_decoder import SimpleConvBlock + + +def test_simple_conv_block(): + block = SimpleConvBlock( + in_channels=1, + pool=True, + pooling_stride=2, + pool_before_convs=False, + num_convs=2, + filters=32, + kernel_size=3, + use_bias=True, + batch_norm=True, + activation="relu", + ) + + block.blocks[0].__class__ == torch.nn.modules.conv.Conv2d + block.blocks[1].__class__ == torch.nn.modules.batchnorm.BatchNorm2d + block.blocks[2].__class__ == torch.nn.modules.activation.ReLU + block.blocks[3].__class__ == torch.nn.modules.conv.Conv2d + block.blocks[4].__class__ == torch.nn.modules.batchnorm.BatchNorm2d + block.blocks[5].__class__ == torch.nn.modules.activation.ReLU + block.blocks[6].__class__ == sleap_nn.architectures.common.MaxPool2dWithSamePadding From 79b535d563cdac2e5ab7cd1d0112e5434406fd6b Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 18:17:21 -0700 Subject: [PATCH 33/55] added tests --- sleap_nn/architectures/encoder_decoder.py | 26 ++++------------------- sleap_nn/architectures/unet.py | 3 --- tests/architectures/test_unet.py | 1 - 3 files changed, 4 insertions(+), 26 deletions(-) diff --git a/sleap_nn/architectures/encoder_decoder.py b/sleap_nn/architectures/encoder_decoder.py index ce2cdb8f..c0f5bff8 100644 --- a/sleap_nn/architectures/encoder_decoder.py +++ b/sleap_nn/architectures/encoder_decoder.py @@ -156,7 +156,6 @@ class Encoder(nn.Module): convs_per_block: Number of convolutional layers per block. Default is 2. kernel_size: Size of the convolutional kernels. Default is 3. middle_block: Whether to include a middle block. Default is True. - block_contraction: Whether to contract the channels in the middle block. Default is False. Attributes: Inherits all attributes from torch.nn.Module. 
@@ -173,7 +172,6 @@ def __init__( convs_per_block: int = 2, kernel_size: Union[int, Tuple[int, int]] = 3, middle_block: bool = True, - block_contraction: bool = False, ) -> None: """Initialize the class.""" super().__init__() @@ -187,7 +185,6 @@ def __init__( self.convs_per_block = convs_per_block self.kernel_size = kernel_size self.middle_block = middle_block - self.block_contraction = block_contraction self.encoder_stack = nn.ModuleList([]) for block in range(down_blocks): @@ -236,16 +233,8 @@ def __init__( ) ) - if block_contraction: - # Contract the channels with an exponent lower than the last encoder block. - block_filters = int( - filters * (filters_rate ** (down_blocks + stem_blocks - 1)) - ) - else: - # Keep the block output filters the same. - block_filters = int( - filters * (filters_rate ** (down_blocks + stem_blocks)) - ) + # Keep the block output filters the same. + block_filters = int(filters * (filters_rate ** (down_blocks + stem_blocks))) self.encoder_stack.append( SimpleConvBlock( @@ -422,7 +411,6 @@ class Decoder(nn.Module): stem_blocks: Number of initial stem blocks. Default is 0. convs_per_block: Number of convolutional layers per block. Default is 2. kernel_size: Size of the convolutional kernels. Default is 3. - block_contraction: Whether to contract the channels in the upsampling blocks. Default is False. Attributes: Inherits all attributes from torch.nn.Module. @@ -439,7 +427,6 @@ def __init__( stem_blocks: int = 0, convs_per_block: int = 2, kernel_size: int = 3, - block_contraction: bool = False, ) -> None: """Initialize the class.""" super().__init__() @@ -453,7 +440,6 @@ def __init__( self.stem_blocks = stem_blocks self.convs_per_block = convs_per_block self.kernel_size = kernel_size - self.block_contraction = block_contraction self.decoder_stack = nn.ModuleList([]) for block in range(up_blocks): @@ -461,12 +447,8 @@ def __init__( block_filters_in = int( filters * (filters_rate ** (down_blocks + stem_blocks - 1 - block)) ) - if block_contraction: - block_filters_out = int( - filters * (filters_rate ** (down_blocks + stem_blocks - 2 - block)) - ) - else: - block_filters_out = block_filters_in + + block_filters_out = block_filters_in next_stride = current_stride // 2 diff --git a/sleap_nn/architectures/unet.py b/sleap_nn/architectures/unet.py index 5cfce65f..65a2d761 100644 --- a/sleap_nn/architectures/unet.py +++ b/sleap_nn/architectures/unet.py @@ -31,7 +31,6 @@ class UNet(nn.Module): up_blocks: Number of upsampling blocks in the decoder. Default is 3. convs_per_block: Number of convolutional layers per block. Default is 2. middle_block: Whether to include a middle block in the encoder. Default is True. - block_contraction: Whether to contract the channels in the decoder blocks. Default is False. Attributes: Inherits all attributes from torch.nn.Module. 
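Aside (illustrative, not part of the patch): with `block_contraction` removed, each decoder block's output width simply equals its input width. For filters=64, filters_rate=2, down_blocks=4, stem_blocks=0:

filters, filters_rate, down_blocks, stem_blocks = 64, 2, 4, 0
for block in range(4):
    block_filters_in = int(
        filters * filters_rate ** (down_blocks + stem_blocks - 1 - block)
    )
    block_filters_out = block_filters_in  # no contraction
    print(block, block_filters_in, block_filters_out)
# 0 512 512 / 1 256 256 / 2 128 128 / 3 64 64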
@@ -48,7 +47,6 @@ def __init__( up_blocks: int = 3, convs_per_block: int = 2, middle_block: bool = True, - block_contraction: bool = False, ) -> None: """Initialize the class.""" super().__init__() @@ -62,7 +60,6 @@ def __init__( convs_per_block=convs_per_block, kernel_size=kernel_size, middle_block=middle_block, - block_contraction=block_contraction, ) current_stride = int( diff --git a/tests/architectures/test_unet.py b/tests/architectures/test_unet.py index 2f61bde8..38b0e42f 100644 --- a/tests/architectures/test_unet.py +++ b/tests/architectures/test_unet.py @@ -81,7 +81,6 @@ def test_unet_reference(): convs_per_block=convs_per_block, kernel_size=kernel_size, middle_block=middle_block, - block_contraction=block_contraction, ) enc = enc.to(device) From 691af4548f37fbc24d6ed00d8e21e176b317bee1 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Fri, 18 Aug 2023 18:30:34 -0700 Subject: [PATCH 34/55] added tests for simple upsampling block --- tests/architectures/test_encoder_decoder.py | 81 ++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/tests/architectures/test_encoder_decoder.py b/tests/architectures/test_encoder_decoder.py index fb080bb5..3b74e520 100644 --- a/tests/architectures/test_encoder_decoder.py +++ b/tests/architectures/test_encoder_decoder.py @@ -1,7 +1,10 @@ import torch import sleap_nn -from sleap_nn.architectures.encoder_decoder import SimpleConvBlock +from sleap_nn.architectures.encoder_decoder import ( + SimpleConvBlock, + SimpleUpsamplingBlock, +) def test_simple_conv_block(): @@ -25,3 +28,79 @@ def test_simple_conv_block(): block.blocks[4].__class__ == torch.nn.modules.batchnorm.BatchNorm2d block.blocks[5].__class__ == torch.nn.modules.activation.ReLU block.blocks[6].__class__ == sleap_nn.architectures.common.MaxPool2dWithSamePadding + + +def test_simple_upsampling_block(): + device = "cuda" if torch.cuda.is_available() else "cpu" + + block = SimpleUpsamplingBlock( + x_in_shape=10, + current_stride=1, + upsampling_stride=2, + interp_method="bilinear", + refine_convs=2, + refine_convs_filters=64, + refine_convs_kernel_size=3, + refine_convs_use_bias=True, + refine_convs_batch_norm=True, + refine_convs_batch_norm_before_activation=True, + refine_convs_activation="relu", + ) + + block = block.to(device) + _ = block.eval() + + x = torch.rand(5, 5, 100, 100).to(device) + feature = torch.rand(5, 5, 200, 200).to(device) + + z = block(x, feature=feature) + + assert z.shape == (5, 64, 200, 200) + + block = SimpleUpsamplingBlock( + x_in_shape=10, + current_stride=1, + upsampling_stride=2, + interp_method="bilinear", + refine_convs=2, + refine_convs_filters=64, + refine_convs_kernel_size=3, + refine_convs_use_bias=True, + refine_convs_batch_norm=True, + refine_convs_batch_norm_before_activation=True, + refine_convs_activation="relu", + ) + + block = block.to(device) + _ = block.eval() + + x = torch.rand(5, 5, 100, 100).to(device) + feature = torch.rand(5, 5, 200, 200).to(device) + + z = block(x, feature=feature) + + assert z.shape == (5, 64, 200, 200) + + block = SimpleUpsamplingBlock( + x_in_shape=10, + current_stride=1, + upsampling_stride=2, + interp_method="bilinear", + refine_convs=2, + refine_convs_filters=64, + refine_convs_kernel_size=3, + refine_convs_use_bias=True, + refine_convs_batch_norm=True, + refine_convs_batch_norm_before_activation=False, + refine_convs_activation="relu", + ) + + block = block.to(device) + _ = block.eval() + + x = torch.rand(5, 5, 100, 100).to(device) + feature = torch.rand(5, 5, 200, 200).to(device) + + z = block(x, 
feature=feature) + + assert z.shape == (5, 64, 200, 200) From 2520fa2bd68bb360517fd29ab5a325bf375bb2ee Mon Sep 17 00:00:00 2001 From: alckasoc Date: Mon, 28 Aug 2023 14:18:34 -0700 Subject: [PATCH 35/55] updated test_unet --- tests/architectures/test_unet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/architectures/test_unet.py b/tests/architectures/test_unet.py index 38b0e42f..3e3d565e 100644 --- a/tests/architectures/test_unet.py +++ b/tests/architectures/test_unet.py @@ -18,7 +18,6 @@ def test_unet_reference(): up_blocks = 4 convs_per_block = 2 middle_block = True - block_contraction = False unet = UNet( in_channels=in_channels, From bcf4069cf056fd75cb29e448ccc3d748821fc418 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Tue, 29 Aug 2023 19:49:29 -0700 Subject: [PATCH 36/55] removed unnecessary variables --- sleap_nn/architectures/encoder_decoder.py | 63 +++++++++-------------- sleap_nn/architectures/unet.py | 8 +-- tests/architectures/test_unet.py | 9 +--- 3 files changed, 27 insertions(+), 53 deletions(-) diff --git a/sleap_nn/architectures/encoder_decoder.py b/sleap_nn/architectures/encoder_decoder.py index c0f5bff8..a2386f51 100644 --- a/sleap_nn/architectures/encoder_decoder.py +++ b/sleap_nn/architectures/encoder_decoder.py @@ -152,10 +152,8 @@ class Encoder(nn.Module): down_blocks: Number of downsampling blocks. Default is 4. filters_rate: Factor to increase the number of filters per block. Default is 2. current_stride: Initial stride for pooling operations. Default is 2. - stem_blocks: Number of initial stem blocks. Default is 0. convs_per_block: Number of convolutional layers per block. Default is 2. kernel_size: Size of the convolutional kernels. Default is 3. - middle_block: Whether to include a middle block. Default is True. Attributes: Inherits all attributes from torch.nn.Module. @@ -168,10 +166,8 @@ def __init__( down_blocks: int = 4, filters_rate: Union[float, int] = 2, current_stride: int = 2, - stem_blocks: int = 0, convs_per_block: int = 2, kernel_size: Union[int, Tuple[int, int]] = 3, - middle_block: bool = True, ) -> None: """Initialize the class.""" super().__init__() @@ -181,15 +177,13 @@ def __init__( self.down_blocks = down_blocks self.filters_rate = filters_rate self.current_stride = current_stride - self.stem_blocks = stem_blocks self.convs_per_block = convs_per_block self.kernel_size = kernel_size - self.middle_block = middle_block self.encoder_stack = nn.ModuleList([]) for block in range(down_blocks): prev_block_filters = -1 if block == 0 else block_filters - block_filters = int(filters * (filters_rate ** (block + stem_blocks))) + block_filters = int(filters * (filters_rate ** (block + 0))) self.encoder_stack.append( SimpleConvBlock( @@ -211,38 +205,16 @@ def __init__( MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding="same") ) - # Create a middle block (like the CARE implementation). - if middle_block: - if convs_per_block > 1: - # First convs are one exponent higher than the last encoder block. - block_filters = int( - filters * (filters_rate ** (down_blocks + stem_blocks)) - ) - self.encoder_stack.append( - SimpleConvBlock( - in_channels=after_block_filters, - pool=False, - pool_before_convs=False, - pooling_stride=2, - num_convs=convs_per_block - 1, - filters=block_filters, - kernel_size=kernel_size, - use_bias=True, - batch_norm=False, - activation="relu", - ) - ) - - # Keep the block output filters the same. 
- block_filters = int(filters * (filters_rate ** (down_blocks + stem_blocks))) - + if convs_per_block > 1: + # First convs are one exponent higher than the last encoder block. + block_filters = int(filters * (filters_rate ** (down_blocks + 0))) self.encoder_stack.append( SimpleConvBlock( - in_channels=block_filters, + in_channels=after_block_filters, pool=False, pool_before_convs=False, pooling_stride=2, - num_convs=1, + num_convs=convs_per_block - 1, filters=block_filters, kernel_size=kernel_size, use_bias=True, @@ -251,6 +223,24 @@ def __init__( ) ) + # Keep the block output filters the same. + block_filters = int(filters * (filters_rate ** (down_blocks + 0))) + + self.encoder_stack.append( + SimpleConvBlock( + in_channels=block_filters, + pool=False, + pool_before_convs=False, + pooling_stride=2, + num_convs=1, + filters=block_filters, + kernel_size=kernel_size, + use_bias=True, + batch_norm=False, + activation="relu", + ) + ) + self.intermediate_features = {} for i, block in enumerate(self.encoder_stack): if isinstance(block, SimpleConvBlock) and block.pool: @@ -408,7 +398,6 @@ class Decoder(nn.Module): up_blocks: Number of upsampling blocks. Default is 4. down_blocks: Number of downsampling blocks. Default is 3. filters_rate: Factor to adjust the number of filters per block. Default is 2. - stem_blocks: Number of initial stem blocks. Default is 0. convs_per_block: Number of convolutional layers per block. Default is 2. kernel_size: Size of the convolutional kernels. Default is 3. @@ -424,7 +413,6 @@ def __init__( up_blocks: int = 4, down_blocks: int = 3, filters_rate: int = 2, - stem_blocks: int = 0, convs_per_block: int = 2, kernel_size: int = 3, ) -> None: @@ -437,7 +425,6 @@ def __init__( self.up_blocks = up_blocks self.down_blocks = down_blocks self.filters_rate = filters_rate - self.stem_blocks = stem_blocks self.convs_per_block = convs_per_block self.kernel_size = kernel_size @@ -445,7 +432,7 @@ def __init__( for block in range(up_blocks): prev_block_filters_in = -1 if block == 0 else block_filters_in block_filters_in = int( - filters * (filters_rate ** (down_blocks + stem_blocks - 1 - block)) + filters * (filters_rate ** (down_blocks + 0 - 1 - block)) ) block_filters_out = block_filters_in diff --git a/sleap_nn/architectures/unet.py b/sleap_nn/architectures/unet.py index 65a2d761..08ee16c2 100644 --- a/sleap_nn/architectures/unet.py +++ b/sleap_nn/architectures/unet.py @@ -26,11 +26,9 @@ class UNet(nn.Module): kernel_size: Size of the convolutional kernels. Default is 3. filters: Number of filters for the initial block. Default is 32. filters_rate: Factor to adjust the number of filters per block. Default is 1.5. - stem_blocks: Number of initial stem blocks. Default is 0. down_blocks: Number of downsampling blocks. Default is 4. up_blocks: Number of upsampling blocks in the decoder. Default is 3. convs_per_block: Number of convolutional layers per block. Default is 2. - middle_block: Whether to include a middle block in the encoder. Default is True. Attributes: Inherits all attributes from torch.nn.Module. 
@@ -42,11 +40,9 @@ def __init__( kernel_size: int = 3, filters: int = 32, filters_rate: int = 1.5, - stem_blocks: int = 0, down_blocks: int = 4, up_blocks: int = 3, convs_per_block: int = 2, - middle_block: bool = True, ) -> None: """Initialize the class.""" super().__init__() @@ -56,10 +52,8 @@ def __init__( filters=filters, down_blocks=down_blocks, filters_rate=filters_rate, - stem_blocks=stem_blocks, convs_per_block=convs_per_block, kernel_size=kernel_size, - middle_block=middle_block, ) current_stride = int( @@ -73,7 +67,7 @@ def __init__( ) ) - x_in_shape = int(filters * (filters_rate ** (down_blocks + stem_blocks))) + x_in_shape = int(filters * (filters_rate ** (down_blocks + 0))) self.dec = Decoder( x_in_shape=x_in_shape, diff --git a/tests/architectures/test_unet.py b/tests/architectures/test_unet.py index 3e3d565e..7e68f14b 100644 --- a/tests/architectures/test_unet.py +++ b/tests/architectures/test_unet.py @@ -14,23 +14,18 @@ def test_unet_reference(): filters_rate = 2 kernel_size = 3 down_blocks = 4 - stem_blocks = 0 up_blocks = 4 convs_per_block = 2 - middle_block = True unet = UNet( in_channels=in_channels, filters=filters, filters_rate=filters_rate, down_blocks=down_blocks, - stem_blocks=stem_blocks, up_blocks=up_blocks, ) - in_channels = int( - filters * (filters_rate ** (down_blocks + stem_blocks - 1 - up_blocks + 1)) - ) + in_channels = int(filters * (filters_rate ** (down_blocks + 0 - 1 - up_blocks + 1))) model = nn.Sequential( *[ unet, @@ -76,10 +71,8 @@ def test_unet_reference(): down_blocks=down_blocks, filters_rate=filters_rate, current_stride=2, - stem_blocks=stem_blocks, convs_per_block=convs_per_block, kernel_size=kernel_size, - middle_block=middle_block, ) enc = enc.to(device) From dbccdcf7f69d55c5e436310eb2689f3f50d5493f Mon Sep 17 00:00:00 2001 From: alckasoc Date: Tue, 29 Aug 2023 19:59:39 -0700 Subject: [PATCH 37/55] updated augmentation random erase default values --- sleap_nn/data/augmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index fb095405..7602981c 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -160,8 +160,8 @@ def __init__( contrast_p: float = 0.5, brightness: Optional[float] = 0.0, brightness_p: float = 0.5, - erase_scale: Optional[Tuple[float, float]] = (0.02, 0.1), - erase_ratio: Optional[Tuple[float, float]] = (0.3, 1.6), + erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01), + erase_ratio: Optional[Tuple[float, float]] = (1, 1), erase_p: float = 0.5, mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None, mixup_p: float = 0.5, From 029a5454d4ef03be98e63e5acc1e668efd4a745a Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 16:05:40 -0700 Subject: [PATCH 38/55] created data/pipelines.py --- sleap_nn/data/pipelines.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 sleap_nn/data/pipelines.py diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py new file mode 100644 index 00000000..57e30c2c --- /dev/null +++ b/sleap_nn/data/pipelines.py @@ -0,0 +1,5 @@ +"""This module defines high level pipeline configurations from providers/transformers. + +This allows for convenient ways to configure individual variants of common pipelines, as +well as to define training vs inference versions based on the same configurations. 
+""" From 3e5ae68991b65f41656614359c3e55ce1b383fc2 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 17:53:13 -0700 Subject: [PATCH 39/55] added base config in config/data; temporary till config system settled --- sleap_nn/config/__init__.py | 1 + sleap_nn/config/data.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 sleap_nn/config/__init__.py create mode 100644 sleap_nn/config/data.py diff --git a/sleap_nn/config/__init__.py b/sleap_nn/config/__init__.py new file mode 100644 index 00000000..16afead9 --- /dev/null +++ b/sleap_nn/config/__init__.py @@ -0,0 +1 @@ +"""Modules relating to configuring data pipelines.""" diff --git a/sleap_nn/config/data.py b/sleap_nn/config/data.py new file mode 100644 index 00000000..a105c15d --- /dev/null +++ b/sleap_nn/config/data.py @@ -0,0 +1,14 @@ +"""This module implements base configurations for data pipelines.""" + +from omegaconf import OmegaConf + +# Base TopDownConfmapsPipeline data config. +base_topdown_data_config = OmegaConf.create( + { + "preproccessing": { + "crop_hw": (160, 160), + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)}, + } +) From 1b8002bbbbe42c5af84313ee5f5167628f99ae2b Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 17:54:22 -0700 Subject: [PATCH 40/55] updated variable defaults to 0 and edited variable names in augmentation --- sleap_nn/data/augmentation.py | 44 +++++++++++++++++---------------- tests/data/test_augmentation.py | 8 +++--- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index 7602981c..3382ceae 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -100,9 +100,6 @@ class KorniaAugmenter(IterDataPipe): Attributes: source_dp: The input `IterDataPipe` with examples that contain `"instances"` and `"image"` keys. - crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], - then out_h = size[0], out_w = size[1]. - crop_p: Probability of applying random crop. rotation: Angles in degrees as a scalar float of the amount of rotation. A random angle in `(-rotation, rotation)` will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. @@ -125,11 +122,14 @@ class KorniaAugmenter(IterDataPipe): contrast_p: Probability of applying random contrast. brightness: The brightness factor to apply Default: `(1.0, 1.0)`. brightness_p: Probability of applying random brightness. - erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`. - erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`. + erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`. + erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`. erase_p: Probability of applying random erase. mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`. mixup_p: Probability of applying random mixup v2. + random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], + then out_h = size[0], out_w = size[1]. + random_crop_p: Probability of applying random crop. 
Notes: This block expects the "image" and "instances" keys to be present in the input @@ -150,23 +150,23 @@ def __init__( rotation: Optional[float] = 15.0, scale: Optional[float] = 0.05, translate: Optional[Tuple[float, float]] = (0.02, 0.02), - affine_p: float = 0.5, + affine_p: float = 0.0, uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04), - uniform_noise_p: float = 0.5, + uniform_noise_p: float = 0.0, gaussian_noise_mean: Optional[float] = 0.02, gaussian_noise_std: Optional[float] = 0.004, - gaussian_noise_p: float = 0.5, + gaussian_noise_p: float = 0.0, contrast: Optional[Tuple[float, float]] = (0.5, 2.0), - contrast_p: float = 0.5, + contrast_p: float = 0.0, brightness: Optional[float] = 0.0, - brightness_p: float = 0.5, + brightness_p: float = 0.0, erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01), erase_ratio: Optional[Tuple[float, float]] = (1, 1), - erase_p: float = 0.5, + erase_p: float = 0.0, mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None, - mixup_p: float = 0.5, - crop_hw: Tuple[int, int] = (0, 0), - crop_p: float = 0.0, + mixup_p: float = 0.0, + random_crop_hw: Tuple[int, int] = (0, 0), + random_crop_p: float = 0.0, ): """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp @@ -188,8 +188,8 @@ def __init__( self.erase_p = erase_p self.mixup_lambda = mixup_lambda self.mixup_p = mixup_p - self.crop_hw = crop_hw - self.crop_p = crop_p + self.random_crop_hw = random_crop_hw + self.random_crop_p = random_crop_p aug_stack = [] if self.affine_p > 0: @@ -259,19 +259,21 @@ def __init__( same_on_batch=True, ) ) - if self.crop_p > 0: - if self.crop_hw[0] > 0 and self.crop_hw[1] > 0: + if self.random_crop_p > 0: + if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0: aug_stack.append( K.augmentation.RandomCrop( - size=self.crop_hw, + size=self.random_crop_hw, pad_if_needed=True, - p=self.crop_p, + p=self.random_crop_p, keepdim=True, same_on_batch=True, ) ) else: - raise ValueError(f"crop_hw height and width must be greater than 0.") + raise ValueError( + f"random_crop_hw height and width must be greater than 0." + ) self.augmenter = AugmentationSequential( *aug_stack, diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index a9e907a2..a5dbf9a2 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -47,8 +47,8 @@ def test_kornia_augmentation(minimal_instance): erase_p=1.0, mixup_p=1.0, mixup_lambda=(0.0, 1.0), - crop_hw=(384, 384), - crop_p=1.0, + random_crop_hw=(384, 384), + random_crop_p=1.0, ) # Test all augmentations. 
@@ -68,6 +68,6 @@ def test_kornia_augmentation(minimal_instance): ): p = KorniaAugmenter( p, - crop_hw=(0, 0), - crop_p=1.0, + random_crop_hw=(0, 0), + random_crop_p=1.0, ) From f1c64f45c0d952da62693ee4cee40b7904705dd3 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 17:55:21 -0700 Subject: [PATCH 41/55] updated parameter names in data/instance_cropping --- sleap_nn/data/instance_cropping.py | 19 ++++++------------- tests/data/test_confmaps.py | 2 +- tests/data/test_instance_cropping.py | 2 +- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 81f02056..9f80e56d 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -1,5 +1,5 @@ """Handle cropping of instances.""" -from typing import Optional +from typing import Optional, Tuple import numpy as np import sleap_io as sio @@ -58,20 +58,13 @@ class InstanceCropper(IterDataPipe): Attributes: source_dp: The previous `DataPipe` with samples that contain an `instances` key. - crop_width: Width of the crop in pixels - crop_height: Height of the crop in pixels + crop_hw: Height and Width of the crop in pixels """ - def __init__( - self, - source_dp: IterDataPipe, - crop_width: int, - crop_height: int, - ): + def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]): """Initialize InstanceCropper with the source `DataPipe.""" self.source_dp = source_dp - self.crop_width = crop_width - self.crop_height = crop_height + self.crop_hw = crop_hw def __iter__(self): """Generate instance cropped examples.""" @@ -82,10 +75,10 @@ def __iter__(self): for instance, centroid in zip(instances[0], centroids[0]): # Generate bounding boxes from centroid. bbox = torch.unsqueeze( - make_centered_bboxes(centroid, self.crop_height, self.crop_width), 0 + make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0 ) # (frames, 4, 2) - box_size = (self.crop_height, self.crop_width) + box_size = (self.crop_hw[0], self.crop_hw[1]) # Generate cropped image of shape (frames, channels, crop_height, crop_width) instance_image = crop_and_resize( diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 21531009..f7821e6d 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -12,7 +12,7 @@ def test_confmaps(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) datapipe = Normalizer(datapipe) - datapipe = InstanceCropper(datapipe, 100, 100) + datapipe = InstanceCropper(datapipe, (100, 100)) datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1) sample = next(iter(datapipe1)) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 966e5572..29fd91f3 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -26,7 +26,7 @@ def test_instance_cropper(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) datapipe = Normalizer(datapipe) - datapipe = InstanceCropper(datapipe, 100, 100) + datapipe = InstanceCropper(datapipe, (100, 100)) sample = next(iter(datapipe)) # Test shapes. 
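Aside (a minimal sketch, not part of the patch): the renamed `crop_hw` argument in the same provider chain the tests use; the `.slp` path is hypothetical.

dp = LabelsReader.from_filename("labels.slp")  # hypothetical file
dp = InstanceCentroidFinder(dp)
dp = Normalizer(dp)
dp = InstanceCropper(dp, crop_hw=(100, 100))
ex = next(iter(dp))
print(ex["instance_image"].shape)  # (1, 1, 100, 100) for one grayscale frame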
From 2a2267485e0cc83cb7aaf2b399d948699ddb1a78 Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 17:55:45 -0700 Subject: [PATCH 42/55] added data/pipelines topdown pipeline make_base_pipeline --- sleap_nn/data/pipelines.py | 86 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py index 57e30c2c..1a188bfc 100644 --- a/sleap_nn/data/pipelines.py +++ b/sleap_nn/data/pipelines.py @@ -3,3 +3,89 @@ This allows for convenient ways to configure individual variants of common pipelines, as well as to define training vs inference versions based on the same configurations. """ +import torch +from omegaconf.dictconfig import DictConfig +from torch.utils.data.datapipes.datapipe import IterDataPipe + +from sleap_nn.data.augmentation import KorniaAugmenter +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper +from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.providers import LabelsReader + + +class SleapDataset(IterDataPipe): + """Returns image and corresponding heatmap for the DataLoader. + + This class is to return the image and its corresponding confidence map + to load the dataset with the DataLoader class + + Attributes: + source_dp: The previous `DataPipe` with samples that contain an `instances` key. + """ + + def __init__(self, source_dp: IterDataPipe): + """Initialize SleapDataset with the source `DataPipe.""" + self.dp = source_dp + + def __iter__(self): + """Return a tuple with the cropped image and the heatmap.""" + for example in self.dp: + # reszie_img = kornia.geometry.transform.resize(example["crop_img"], (512, 512)) + if len(example["instance_image"].shape) == 4: + example["instance_image"] = example["instance_image"].squeeze(dim=0) + torch.cuda.empty_cache() + # yield example["instance_image"].cuda(), example["confidence_maps"].cuda() + yield example["instance_image"], example["confidence_maps"] + + +class TopdownConfmapsPipeline: + """Pipeline builder for instance-centered confidence map models. + + Attributes: + data_config: Data-related configuration. + optimization_config: Optimization-related configuration. + instance_confmap_head: Instantiated head describing the output centered + confidence maps tensor. + offsets_head: Optional head describing the offset refinement maps. + """ + + def __init__(self, data_config: DictConfig) -> None: + """Initialize the data config.""" + self.data_config = data_config + + def make_base_pipeline( + self, data_provider: IterDataPipe, filename: str + ) -> IterDataPipe: + """Create base pipeline with input data only. + + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + filename: A string path to the name of the `.slp` file. + + Returns: + An `IterDataPipe` instance configured to produce input examples. 
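Aside (hypothetical downstream usage, not part of the patch): the datapipe returned by `make_base_pipeline` yields `(instance_image, confidence_maps)` tuples, so it can be wrapped directly in a `DataLoader`. The config, path, and batch size here are assumptions; shapes follow the tests added later in this series.

from torch.utils.data import DataLoader

pipeline = TopdownConfmapsPipeline(data_config=my_config)  # my_config: an OmegaConf DictConfig
datapipe = pipeline.make_base_pipeline(data_provider=LabelsReader, filename="labels.slp")
loader = DataLoader(datapipe, batch_size=4)  # assumes >= 4 instances in the file
images, confmaps = next(iter(loader))        # (4, 1, 160, 160), (4, 2, 80, 80)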
+ """ + datapipe = data_provider.from_filename(filename=filename) + datapipe = Normalizer(datapipe) + + datapipe = InstanceCentroidFinder(datapipe) + datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw) + + if self.data_config.augmentation_config.random_crop: + datapipe = KorniaAugmenter( + datapipe, + random_crop_hw=self.data_config.augmentation_config.random_crop_hw, + random_crop_p=1.0, + ) + + datapipe = ConfidenceMapGenerator( + datapipe, + sigma=self.data_config.preprocessing.conf_map_gen.sigma, + output_stride=self.data_config.preprocessing.conf_map_gen.output_stride, + ) + datapipe = SleapDataset(datapipe) + + return datapipe From f3ddf2f7cf71b7087a888032e485d2e2e5c41c5e Mon Sep 17 00:00:00 2001 From: alckasoc Date: Wed, 30 Aug 2023 18:20:04 -0700 Subject: [PATCH 43/55] added test_pipelines --- sleap_nn/config/data.py | 2 +- tests/data/test_pipelines.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 tests/data/test_pipelines.py diff --git a/sleap_nn/config/data.py b/sleap_nn/config/data.py index a105c15d..c8f81143 100644 --- a/sleap_nn/config/data.py +++ b/sleap_nn/config/data.py @@ -5,7 +5,7 @@ # Base TopDownConfmapsPipeline data config. base_topdown_data_config = OmegaConf.create( { - "preproccessing": { + "preprocessing": { "crop_hw": (160, 160), "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, }, diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py new file mode 100644 index 00000000..023bafb9 --- /dev/null +++ b/tests/data/test_pipelines.py @@ -0,0 +1,35 @@ +import torch + +from sleap_nn.config.data import base_topdown_data_config +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper +from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.pipelines import SleapDataset, TopdownConfmapsPipeline +from sleap_nn.data.providers import LabelsReader + + +def test_sleap_dataset(minimal_instance): + datapipe = LabelsReader.from_filename(filename=minimal_instance) + datapipe = Normalizer(datapipe) + datapipe = InstanceCentroidFinder(datapipe) + datapipe = InstanceCropper(datapipe, (160, 160)) + datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2) + datapipe = SleapDataset(datapipe) + + sample = next(iter(datapipe)) + assert len(sample) == 2 + assert sample[0].shape == (1, 160, 160) + assert sample[1].shape == (2, 80, 80) + + +def test_topdownconfmapspipeline(minimal_instance): + pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) + datapipe = pipeline.make_base_pipeline( + data_provider=LabelsReader, filename=minimal_instance + ) + + sample = next(iter(datapipe)) + assert len(sample) == 2 + assert sample[0].shape == (1, 160, 160) + assert sample[1].shape == (2, 80, 80) \ No newline at end of file From c861c7263e3751fd60e1fbebeea2cc667fb312ca Mon Sep 17 00:00:00 2001 From: alckasoc Date: Mon, 4 Sep 2023 23:42:21 -0700 Subject: [PATCH 44/55] removed configs --- sleap_nn/config/__init__.py | 1 - sleap_nn/config/data.py | 14 -------------- tests/data/test_pipelines.py | 14 ++++++++++++-- 3 files changed, 12 insertions(+), 17 deletions(-) delete mode 100644 sleap_nn/config/__init__.py delete mode 100644 sleap_nn/config/data.py diff --git a/sleap_nn/config/__init__.py b/sleap_nn/config/__init__.py deleted file mode 100644 index 16afead9..00000000 --- a/sleap_nn/config/__init__.py +++ 
/dev/null @@ -1 +0,0 @@ -"""Modules relating to configuring data pipelines.""" diff --git a/sleap_nn/config/data.py b/sleap_nn/config/data.py deleted file mode 100644 index c8f81143..00000000 --- a/sleap_nn/config/data.py +++ /dev/null @@ -1,14 +0,0 @@ -"""This module implements base configurations for data pipelines.""" - -from omegaconf import OmegaConf - -# Base TopDownConfmapsPipeline data config. -base_topdown_data_config = OmegaConf.create( - { - "preprocessing": { - "crop_hw": (160, 160), - "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, - }, - "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)}, - } -) diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py index 023bafb9..feb6da6e 100644 --- a/tests/data/test_pipelines.py +++ b/tests/data/test_pipelines.py @@ -1,6 +1,6 @@ import torch +from omegaconf import OmegaConf -from sleap_nn.config.data import base_topdown_data_config from sleap_nn.data.confidence_maps import ConfidenceMapGenerator from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper @@ -24,6 +24,16 @@ def test_sleap_dataset(minimal_instance): def test_topdownconfmapspipeline(minimal_instance): + base_topdown_data_config = OmegaConf.create( + { + "preprocessing": { + "crop_hw": (160, 160), + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)}, + } + ) + pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) datapipe = pipeline.make_base_pipeline( data_provider=LabelsReader, filename=minimal_instance @@ -32,4 +42,4 @@ def test_topdownconfmapspipeline(minimal_instance): sample = next(iter(datapipe)) assert len(sample) == 2 assert sample[0].shape == (1, 160, 160) - assert sample[1].shape == (2, 80, 80) \ No newline at end of file + assert sample[1].shape == (2, 80, 80) From 31aadc182607cb94333088e1c3a92e076e8f3b4d Mon Sep 17 00:00:00 2001 From: alckasoc Date: Tue, 5 Sep 2023 19:45:30 -0700 Subject: [PATCH 45/55] updated augmentation class --- sleap_nn/data/augmentation.py | 31 ++++++++++++++---- sleap_nn/data/instance_centroids.py | 8 ++--- sleap_nn/data/instance_cropping.py | 24 +++++++------- sleap_nn/data/pipelines.py | 47 +++++++++++++++++++++++++--- tests/data/test_instance_cropping.py | 2 +- tests/data/test_pipelines.py | 34 +++++++++++++++++++- 6 files changed, 117 insertions(+), 29 deletions(-) diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index 3382ceae..b445bd6a 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -285,9 +285,28 @@ def __init__( def __iter__(self): """Return an example dictionary with the augmented image and instances.""" for ex in self.source_dp: - inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2) - image, instances = ex["image"], ex["instances"].reshape( - inst_shape[0], -1, 2 - ) - aug_image, aug_instances = self.augmenter(image, instances) - yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)} + if "instance_image" in ex and "instance" in ex: + inst_shape = ex["instance"].shape + # (B, channels, height, width), (1, num_nodes, 2) + image, instances = ex["instance_image"], ex["instance"].unsqueeze(0) + aug_image, aug_instances = self.augmenter(image, instances) + ex.update( + { + "instance_image": aug_image, + "instance": aug_instances.reshape(*inst_shape), + } + ) + elif "image" in ex and "instances" in ex: + inst_shape = 
ex["instances"].shape # (B, num_instances, num_nodes, 2) + image, instances = ex["image"], ex["instances"].reshape( + inst_shape[0], -1, 2 + ) # (B, channels, height, width), (B, num_instances x num_nodes, 2) + + aug_image, aug_instances = self.augmenter(image, instances) + ex.update( + { + "image": aug_image, + "instances": aug_instances.reshape(*inst_shape), + } + ) + yield ex diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index b15deb91..6a2596db 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -86,8 +86,8 @@ def __init__( def __iter__(self): """Add `"centroids"` key to example.""" - for example in self.source_dp: - example["centroids"] = find_centroids( - example["instances"], anchor_ind=self.anchor_ind + for ex in self.source_dp: + ex["centroids"] = find_centroids( + ex["instances"], anchor_ind=self.anchor_ind ) - yield example + yield ex diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 9f80e56d..3682505f 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -68,33 +68,33 @@ def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]): def __iter__(self): """Generate instance cropped examples.""" - for example in self.source_dp: - image = example["image"] # (frames, channels, height, width) - instances = example["instances"] # (frames, n_instances, n_nodes, 2) - centroids = example["centroids"] # (frames, n_instances, 2) + for ex in self.source_dp: + image = ex["image"] # (B, channels, height, width) + instances = ex["instances"] # (B, n_instances, num_nodes, 2) + centroids = ex["centroids"] # (B, n_instances, 2) for instance, centroid in zip(instances[0], centroids[0]): # Generate bounding boxes from centroid. - bbox = torch.unsqueeze( + instance_bbox = torch.unsqueeze( make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0 - ) # (frames, 4, 2) + ) # (B, 4, 2) box_size = (self.crop_hw[0], self.crop_hw[1]) - # Generate cropped image of shape (frames, channels, crop_height, crop_width) + # Generate cropped image of shape (B, channels, crop_height, crop_width) instance_image = crop_and_resize( image, - boxes=bbox, + boxes=instance_bbox, size=box_size, ) # Access top left point (x,y) of bounding box and subtract this offset from # position of nodes. 
-                point = bbox[0][0]
+                point = instance_bbox[0][0]
                 center_instance = instance - point

                 instance_example = {
-                    "instance_image": instance_image,  # (frames, channels, crop_height, crop_width)
-                    "bbox": bbox,  # (frames, 4, 2)
-                    "instance": center_instance,  # (n_instances, 2)
+                    "instance_image": instance_image,  # (B, channels, crop_height, crop_width)
+                    "instance_bbox": instance_bbox,  # (B, 4, 2)
+                    "instance": center_instance,  # (num_nodes, 2)
                 }
                 yield instance_example
diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index 1a188bfc..354109b0 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -32,11 +32,9 @@ def __init__(self, source_dp: IterDataPipe):
     def __iter__(self):
         """Return a tuple with the cropped image and the heatmap."""
         for example in self.dp:
-            # reszie_img = kornia.geometry.transform.resize(example["crop_img"], (512, 512))
             if len(example["instance_image"].shape) == 4:
                 example["instance_image"] = example["instance_image"].squeeze(dim=0)
             torch.cuda.empty_cache()
-            # yield example["instance_image"].cuda(), example["confidence_maps"].cuda()
             yield example["instance_image"], example["confidence_maps"]
@@ -74,11 +72,11 @@ def make_base_pipeline(
         datapipe = InstanceCentroidFinder(datapipe)
         datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw)

-        if self.data_config.augmentation_config.random_crop:
+        if self.data_config.augmentation_config.random_crop.random_crop_p:
             datapipe = KorniaAugmenter(
                 datapipe,
-                random_crop_hw=self.data_config.augmentation_config.random_crop_hw,
-                random_crop_p=1.0,
+                random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
+                random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
             )

         datapipe = ConfidenceMapGenerator(
@@ -89,3 +87,42 @@ def make_base_pipeline(
         datapipe = SleapDataset(datapipe)

         return datapipe
+
+    def make_training_pipeline(
+        self, data_provider: IterDataPipe, filename: str
+    ) -> IterDataPipe:
+        """Create training pipeline with input data and training-time augmentations.
+
+        Args:
+            data_provider: A `Provider` that generates data examples, typically a
+                `LabelsReader` instance.
+            filename: A string path to the name of the `.slp` file.
+
+        Returns:
+            An `IterDataPipe` instance configured to produce augmented training examples.
+        """
+        datapipe = data_provider.from_filename(filename=filename)
+        datapipe = Normalizer(datapipe)
+
+        datapipe = InstanceCentroidFinder(datapipe)
+        datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw)
+
+        if self.data_config.augmentation_config.random_crop.random_crop_p:
+            datapipe = KorniaAugmenter(
+                datapipe,
+                random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
+                random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
+            )
+
+        datapipe = KorniaAugmenter(
+            datapipe, **dict(self.data_config.augmentation_config.augmentations)
+        )
+
+        datapipe = ConfidenceMapGenerator(
+            datapipe,
+            sigma=self.data_config.preprocessing.conf_map_gen.sigma,
+            output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
+        )
+        datapipe = SleapDataset(datapipe)
+
+        return datapipe
diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py
index 29fd91f3..78114992 100644
--- a/tests/data/test_instance_cropping.py
+++ b/tests/data/test_instance_cropping.py
@@ -32,7 +32,7 @@ def test_instance_cropper(minimal_instance):
     # Test shapes.
     assert sample["instance"].shape == (2, 2)
     assert sample["instance_image"].shape == (1, 1, 100, 100)
-    assert sample["bbox"].shape == (1, 4, 2)
+    assert sample["instance_bbox"].shape == (1, 4, 2)

     # Test samples.
     gt = torch.Tensor(
diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index feb6da6e..5e71e0aa 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -30,7 +30,29 @@ def test_topdownconfmapspipeline(minimal_instance):
                 "crop_hw": (160, 160),
                 "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
             },
-            "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)},
+            "augmentation_config": {
+                "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)},
+                "augmentations": {
+                    "rotation": 15.0,
+                    "scale": 0.05,
+                    "translate": (0.02, 0.02),
+                    "affine_p": 0.0,
+                    "uniform_noise": (0.0, 0.04),
+                    "uniform_noise_p": 0.0,
+                    "gaussian_noise_mean": 0.02,
+                    "gaussian_noise_std": 0.004,
+                    "gaussian_noise_p": 0.0,
+                    "contrast": (0.5, 2.0),
+                    "contrast_p": 0.0,
+                    "brightness": 0.0,
+                    "brightness_p": 0.0,
+                    "erase_scale": (0.0001, 0.01),
+                    "erase_ratio": (1, 1),
+                    "erase_p": 0.0,
+                    "mixup_lambda": None,
+                    "mixup_p": 0.0,
+                },
+            },
         }
     )
@@ -43,3 +65,13 @@ def test_topdownconfmapspipeline(minimal_instance):
     sample = next(iter(datapipe))
     assert len(sample) == 2
     assert sample[0].shape == (1, 160, 160)
     assert sample[1].shape == (2, 80, 80)
+
+    pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config)
+    datapipe = pipeline.make_training_pipeline(
+        data_provider=LabelsReader, filename=minimal_instance
+    )
+
+    sample = next(iter(datapipe))
+    assert len(sample) == 2
+    assert sample[0].shape == (1, 160, 160)
+    assert sample[1].shape == (2, 80, 80)

From 663015588cdf08e42cd0f8aa4441d0160df980d1 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 5 Sep 2023 22:39:50 -0700
Subject: [PATCH 46/55] modified test

---
 tests/data/test_pipelines.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index 5e71e0aa..b6f953a3 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -31,26 +31,26 @@ def test_topdownconfmapspipeline(minimal_instance):
                 "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
             },
             "augmentation_config": {
-                "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)},
+                "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)},
                 "augmentations": {
                     "rotation": 15.0,
                     "scale": 0.05,
                     "translate": (0.02, 0.02),
-                    "affine_p": 0.0,
+                    "affine_p": 0.5,
                     "uniform_noise": (0.0, 0.04),
-                    "uniform_noise_p": 0.0,
+                    "uniform_noise_p": 0.5,
                     "gaussian_noise_mean": 0.02,
                     "gaussian_noise_std": 0.004,
-                    "gaussian_noise_p": 0.0,
+                    "gaussian_noise_p": 0.5,
                     "contrast": (0.5, 2.0),
-                    "contrast_p": 0.0,
+                    "contrast_p": 0.5,
                     "brightness": 0.0,
-                    "brightness_p": 0.0,
+                    "brightness_p": 0.5,
                     "erase_scale": (0.0001, 0.01),
                     "erase_ratio": (1, 1),
-                    "erase_p": 0.0,
+                    "erase_p": 0.5,
                     "mixup_lambda": None,
-                    "mixup_p": 0.0,
+                    "mixup_p": 0.5,
                 },
             },

From 55cf1a93e60de21e91d8de003e8ea339e45999af Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 6 Sep 2023 14:29:26 -0700
Subject: [PATCH 47/55] updated pipelines docstring

---
 sleap_nn/data/pipelines.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index 354109b0..56d31822 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -43,10 +43,6 @@ class TopdownConfmapsPipeline:

     Attributes:
         data_config: Data-related configuration.
-        optimization_config: Optimization-related configuration.
-        instance_confmap_head: Instantiated head describing the output centered
-            confidence maps tensor.
-        offsets_head: Optional head describing the offset refinement maps.
     """

     def __init__(self, data_config: DictConfig) -> None:

From 9715c01a284b82701ba5824f3fd3e9072aad5564 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 6 Sep 2023 14:40:50 -0700
Subject: [PATCH 48/55] removed make_base_pipeline and updated tests

---
 sleap_nn/data/pipelines.py   | 42 ++----------------------------------
 tests/data/test_pipelines.py | 11 +---------
 2 files changed, 5 insertions(+), 48 deletions(-)

diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index 56d31822..c7b44f8a 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -49,41 +49,6 @@ def __init__(self, data_config: DictConfig) -> None:
         """Initialize the data config."""
         self.data_config = data_config

-    def make_base_pipeline(
-        self, data_provider: IterDataPipe, filename: str
-    ) -> IterDataPipe:
-        """Create base pipeline with input data only.
-
-        Args:
-            data_provider: A `Provider` that generates data examples, typically a
-                `LabelsReader` instance.
-            filename: A string path to the name of the `.slp` file.
-
-        Returns:
-            An `IterDataPipe` instance configured to produce input examples.
-        """
-        datapipe = data_provider.from_filename(filename=filename)
-        datapipe = Normalizer(datapipe)
-
-        datapipe = InstanceCentroidFinder(datapipe)
-        datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw)
-
-        if self.data_config.augmentation_config.random_crop.random_crop_p:
-            datapipe = KorniaAugmenter(
-                datapipe,
-                random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
-                random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
-            )
-
-        datapipe = ConfidenceMapGenerator(
-            datapipe,
-            sigma=self.data_config.preprocessing.conf_map_gen.sigma,
-            output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
-        )
-        datapipe = SleapDataset(datapipe)
-
-        return datapipe
-
     def make_training_pipeline(
         self, data_provider: IterDataPipe, filename: str
     ) -> IterDataPipe:
@@ -110,9 +75,10 @@ def make_training_pipeline(
             random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
         )

-        datapipe = KorniaAugmenter(
-            datapipe, **dict(self.data_config.augmentation_config.augmentations)
-        )
+        if self.data_config.augmentation_config.use_augmentations:
+            datapipe = KorniaAugmenter(
+                datapipe, **dict(self.data_config.augmentation_config.augmentations)
+            )

         datapipe = ConfidenceMapGenerator(
             datapipe,
diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index b6f953a3..cbc61ea4 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -32,6 +32,7 @@ def test_topdownconfmapspipeline(minimal_instance):
             },
             "augmentation_config": {
                 "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)},
+                "use_augmentations": False,
                 "augmentations": {
                     "rotation": 15.0,
                     "scale": 0.05,
@@ -56,16 +57,6 @@ def test_topdownconfmapspipeline(minimal_instance):
         }
     )

-    pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config)
-    datapipe = pipeline.make_base_pipeline(
-        data_provider=LabelsReader, filename=minimal_instance
-    )
-
-    sample = next(iter(datapipe))
-    assert len(sample) == 2
-    assert sample[0].shape == (1, 160, 160)
-    assert sample[1].shape == (2, 80, 80)
-
     pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config)
     datapipe = pipeline.make_training_pipeline(
         data_provider=LabelsReader, filename=minimal_instance

From 7deec6535da7f50829c2a3354d6ad547f1ceb795 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 6 Sep 2023 14:44:00 -0700
Subject: [PATCH 49/55] removed empty_cache in SleapDataset

---
 sleap_nn/data/pipelines.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index c7b44f8a..e6722e99 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -34,7 +34,6 @@ def __iter__(self):
         for example in self.dp:
             if len(example["instance_image"].shape) == 4:
                 example["instance_image"] = example["instance_image"].squeeze(dim=0)
-            torch.cuda.empty_cache()
             yield example["instance_image"], example["confidence_maps"]

From b1ef93cecfe2a88426e53a95a9a449fb8747c378 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Wed, 6 Sep 2023 17:08:14 -0700
Subject: [PATCH 50/55] updated test_pipelines

---
 tests/data/test_pipelines.py | 43 ++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index cbc61ea4..572c39c3 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -66,3 +66,46 @@ def test_topdownconfmapspipeline(minimal_instance):
     assert len(sample) == 2
     assert sample[0].shape == (1, 160, 160)
     assert sample[1].shape == (2, 80, 80)
+
+    base_topdown_data_config = OmegaConf.create(
+        {
+            "preprocessing": {
+                "crop_hw": (160, 160),
+                "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
+            },
+            "augmentation_config": {
+                "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)},
+                "use_augmentations": True,
+                "augmentations": {
+                    "rotation": 15.0,
+                    "scale": 0.05,
+                    "translate": (0.02, 0.02),
+                    "affine_p": 0.5,
+                    "uniform_noise": (0.0, 0.04),
+                    "uniform_noise_p": 0.5,
+                    "gaussian_noise_mean": 0.02,
+                    "gaussian_noise_std": 0.004,
+                    "gaussian_noise_p": 0.5,
+                    "contrast": (0.5, 2.0),
+                    "contrast_p": 0.5,
+                    "brightness": 0.0,
+                    "brightness_p": 0.5,
+                    "erase_scale": (0.0001, 0.01),
+                    "erase_ratio": (1, 1),
+                    "erase_p": 0.5,
+                    "mixup_lambda": None,
+                    "mixup_p": 0.5,
+                },
+            },
+        }
+    )
+
+    pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config)
+    datapipe = pipeline.make_training_pipeline(
+        data_provider=LabelsReader, filename=minimal_instance
+    )
+
+    sample = next(iter(datapipe))
+    assert len(sample) == 2
+    assert sample[0].shape == (1, 160, 160)
+    assert sample[1].shape == (2, 80, 80)

From ae523d11c05e67a9aceba14963c6376e49b99a77 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 12 Sep 2023 09:17:32 -0700
Subject: [PATCH 51/55] updated sleapdataset to return a dict

---
 sleap_nn/data/pipelines.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index e6722e99..b75d7cc3 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -30,11 +30,21 @@ def __init__(self, source_dp: IterDataPipe):
         self.dp = source_dp

     def __iter__(self):
-        """Return a tuple with the cropped image and the heatmap."""
+        """Return a dictionary with the relevant outputs.
+
+        This dictionary includes:
+        - image: the full frame image
+        - instances: all keypoints of all instances in the frame image
+        - centroids: all centroids of all instances in the frame image
+        - instance: the individual instance's keypoints
+        - instance_bbox: the individual instance's bbox
+        - instance_image: the individual instance's cropped image
+        - confidence_maps: the individual instance's heatmap
+        """
         for example in self.dp:
             if len(example["instance_image"].shape) == 4:
                 example["instance_image"] = example["instance_image"].squeeze(dim=0)
-            yield example["instance_image"], example["confidence_maps"]
+            yield example

From b4219379a41f54bcc1b5782cf5bf3e9b923e4962 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 12 Sep 2023 10:32:29 -0700
Subject: [PATCH 52/55] added key filter transformer block, removed sleap
 dataset, added type hinting

---
 sleap_nn/architectures/common.py     |  4 +-
 sleap_nn/data/augmentation.py        |  4 +-
 sleap_nn/data/confidence_maps.py     |  8 ++--
 sleap_nn/data/general.py             | 38 +++++++++++++++++++
 sleap_nn/data/instance_centroids.py  |  6 +--
 sleap_nn/data/instance_cropping.py   |  7 +++-
 sleap_nn/data/normalization.py       |  6 ++-
 sleap_nn/data/pipelines.py           | 36 +-----------------
 sleap_nn/data/providers.py           |  7 ++--
 tests/data/test_instance_cropping.py | 14 ++++++-
 tests/data/test_pipelines.py         | 56 ++++++++++++++++++++++------
 11 files changed, 123 insertions(+), 63 deletions(-)
 create mode 100644 sleap_nn/data/general.py

diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py
index 1865e471..839d4fea 100644
--- a/sleap_nn/architectures/common.py
+++ b/sleap_nn/architectures/common.py
@@ -1,4 +1,6 @@
 """Common utilities for architecture and model building."""
+from typing import List
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module:
     return activations[activation]


-def get_children_layers(model: torch.nn.Module):
+def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
     """Recursively retrieves a flattened list of all children modules and submodules within the given model.

     Args:
diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py
index b445bd6a..d38d7910 100644
--- a/sleap_nn/data/augmentation.py
+++ b/sleap_nn/data/augmentation.py
@@ -167,7 +167,7 @@ def __init__(
         mixup_p: float = 0.0,
         random_crop_hw: Tuple[int, int] = (0, 0),
         random_crop_p: float = 0.0,
-    ):
+    ) -> None:
         """Initialize the block and the augmentation pipeline."""
         self.source_dp = source_dp
         self.rotation = rotation
@@ -282,7 +282,7 @@ def __init__(
             same_on_batch=True,
         )

-    def __iter__(self):
+    def __iter__(self) -> Dict[str, torch.Tensor]:
         """Return an example dictionary with the augmented image and instances."""
         for ex in self.source_dp:
             if "instance_image" in ex and "instance" in ex:
diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py
index 42b58310..deae59a4 100644
--- a/sleap_nn/data/confidence_maps.py
+++ b/sleap_nn/data/confidence_maps.py
@@ -1,5 +1,5 @@
 """Generate confidence maps."""
-from typing import Optional
+from typing import Dict, Optional

 import sleap_io as sio
 import torch
@@ -10,7 +10,7 @@

 def make_confmaps(
     points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
-):
+) -> torch.Tensor:
     """Make confidence maps from a set of points from a single instance.

     Args:
@@ -70,7 +70,7 @@ def __init__(
         output_stride: int = 1,
         instance_key: str = "instance",
         image_key: str = "instance_image",
-    ):
+    ) -> None:
         """Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
         self.source_dp = source_dp
         self.sigma = sigma
@@ -78,7 +78,7 @@ def __init__(
         self.instance_key = instance_key
         self.image_key = image_key

-    def __iter__(self):
+    def __iter__(self) -> Dict[str, torch.Tensor]:
         """Generate confidence maps for each example."""
         for example in self.source_dp:
             instance = example[self.instance_key]
diff --git a/sleap_nn/data/general.py b/sleap_nn/data/general.py
new file mode 100644
index 00000000..251b2bd0
--- /dev/null
+++ b/sleap_nn/data/general.py
@@ -0,0 +1,38 @@
+"""General purpose transformers for common pipeline processing tasks."""
+from typing import Callable, Dict, List, Text
+
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+
+
+class KeyFilter(IterDataPipe):
+    """Transformer for filtering example keys."""
+
+    def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None:
+        """Initialize KeyFilter with the source `DataPipe."""
+        self.dp = source_dp
+        self.keep_keys = keep_keys
+
+    def __iter__(self):
+        """Return a dictionary filtered for the relevant outputs.
+
+        The input dictionary includes:
+        - image: the full frame image
+        - instances: all keypoints of all instances in the frame image
+        - centroids: all centroids of all instances in the frame image
+        - instance: the individual instance's keypoints
+        - instance_bbox: the individual instance's bbox
+        - instance_image: the individual instance's cropped image
+        - confidence_maps: the individual instance's heatmap
+        """
+        for example in self.dp:
+            if self.keep_keys is None:
+                # If keep_keys is not provided, yield the entire example.
+                yield example
+            else:
+                # Filter the example dictionary based on keep_keys.
+                filtered_example = {
+                    key: value
+                    for key, value in example.items()
+                    if key in self.keep_keys
+                }
+                yield filtered_example
diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py
index 6a2596db..f32cba52 100644
--- a/sleap_nn/data/instance_centroids.py
+++ b/sleap_nn/data/instance_centroids.py
@@ -1,5 +1,5 @@
 """Handle calculation of instance centroids."""
-from typing import Optional
+from typing import Dict, Optional

 import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -79,12 +79,12 @@ def __init__(
         self,
         source_dp: IterDataPipe,
         anchor_ind: Optional[int] = None,
-    ):
+    ) -> None:
         """Initialize InstanceCentroidFinder with the source `DataPipe."""
         self.source_dp = source_dp
         self.anchor_ind = anchor_ind

-    def __iter__(self):
+    def __iter__(self) -> Dict[str, torch.Tensor]:
         """Add `"centroids"` key to example."""
         for ex in self.source_dp:
             ex["centroids"] = find_centroids(
diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
index 3682505f..f53d0c73 100644
--- a/sleap_nn/data/instance_cropping.py
+++ b/sleap_nn/data/instance_cropping.py
@@ -93,8 +93,11 @@ def __iter__(self):
                 center_instance = instance - point

                 instance_example = {
-                    "instance_image": instance_image,  # (B, channels, crop_height, crop_width)
+                    "instance_image": instance_image.squeeze(
+                        0
+                    ),  # (B=1, channels, crop_height, crop_width)
                     "instance_bbox": instance_bbox,  # (B, 4, 2)
                     "instance": center_instance,  # (num_nodes, 2)
                 }
-                yield instance_example
+                ex.update(instance_example)
+                yield ex
diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py
index 8c150ebc..1b6f61c7 100644
--- a/sleap_nn/data/normalization.py
+++ b/sleap_nn/data/normalization.py
@@ -1,4 +1,6 @@
 """This module implements data pipeline blocks for normalization operations."""
+from typing import Dict
+
 import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe

@@ -16,11 +18,11 @@ class Normalizer(IterDataPipe):
     def __init__(
         self,
         source_dp: IterDataPipe,
-    ):
+    ) -> None:
         """Initialize the block."""
         self.source_dp = source_dp

-    def __iter__(self):
+    def __iter__(self) -> Dict[str, torch.Tensor]:
         """Return an example dictionary with the augmented image and instance."""
         for ex in self.source_dp:
             if not torch.is_floating_point(ex["image"]):
diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py
index b75d7cc3..a443d114 100644
--- a/sleap_nn/data/pipelines.py
+++ b/sleap_nn/data/pipelines.py
@@ -9,42 +9,10 @@

 from sleap_nn.data.augmentation import KorniaAugmenter
 from sleap_nn.data.confidence_maps import ConfidenceMapGenerator
+from sleap_nn.data.general import KeyFilter
 from sleap_nn.data.instance_centroids import InstanceCentroidFinder
 from sleap_nn.data.instance_cropping import InstanceCropper
 from sleap_nn.data.normalization import Normalizer
-from sleap_nn.data.providers import LabelsReader
-
-
-class SleapDataset(IterDataPipe):
-    """Returns image and corresponding heatmap for the DataLoader.
-
-    This class is to return the image and its corresponding confidence map
-    to load the dataset with the DataLoader class
-
-    Attributes:
-        source_dp: The previous `DataPipe` with samples that contain an `instances` key.
-    """
-
-    def __init__(self, source_dp: IterDataPipe):
-        """Initialize SleapDataset with the source `DataPipe."""
-        self.dp = source_dp
-
-    def __iter__(self):
-        """Return a dictionary with the relevant outputs.
-
-        This dictionary includes:
-        - image: the full frame image
-        - instances: all keypoints of all instances in the frame image
-        - centroids: all centroids of all instances in the frame image
-        - instance: the individual instance's keypoints
-        - instance_bbox: the individual instance's bbox
-        - instance_image: the individual instance's cropped image
-        - confidence_maps: the individual instance's heatmap
-        """
-        for example in self.dp:
-            if len(example["instance_image"].shape) == 4:
-                example["instance_image"] = example["instance_image"].squeeze(dim=0)
-            yield example


 class TopdownConfmapsPipeline:
@@ -94,6 +62,6 @@ def make_training_pipeline(
             sigma=self.data_config.preprocessing.conf_map_gen.sigma,
             output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
         )
-        datapipe = SleapDataset(datapipe)
+        datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys)

         return datapipe
diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py
index d090775c..f8ce2121 100644
--- a/sleap_nn/data/providers.py
+++ b/sleap_nn/data/providers.py
@@ -1,4 +1,5 @@
 """This module implements pipeline blocks for reading input data such as labels."""
+from typing import Dict
 import numpy as np
 import sleap_io as sio
 import torch
@@ -16,17 +17,17 @@ class LabelsReader(IterDataPipe):
         accessed through a torchdata DataPipe
     """

-    def __init__(self, labels: sio.Labels):
+    def __init__(self, labels: sio.Labels) -> None:
         """Initialize labels attribute of the class."""
         self.labels = labels

     @classmethod
-    def from_filename(cls, filename: str):
+    def from_filename(cls, filename: str) -> "LabelsReader":
         """Create LabelsReader from a .slp filename."""
         labels = sio.load_slp(filename)
         return cls(labels)

-    def __iter__(self):
+    def __iter__(self) -> Dict[str, torch.Tensor]:
         """Return an example dictionary containing the following elements.

         "image": A torch.Tensor containing full raw frame image as a uint8 array
diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py
index 78114992..662f709a 100644
--- a/tests/data/test_instance_cropping.py
+++ b/tests/data/test_instance_cropping.py
@@ -29,9 +29,21 @@ def test_instance_cropper(minimal_instance):
     datapipe = InstanceCropper(datapipe, (100, 100))
     sample = next(iter(datapipe))

+    gt_sample_keys = [
+        "image",
+        "instances",
+        "centroids",
+        "instance",
+        "instance_bbox",
+        "instance_image",
+    ]
+
     # Test shapes.
+    assert len(sample.keys()) == 6
+    for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
+        assert gt_key == key
     assert sample["instance"].shape == (2, 2)
-    assert sample["instance_image"].shape == (1, 1, 100, 100)
+    assert sample["instance_image"].shape == (1, 100, 100)
     assert sample["instance_bbox"].shape == (1, 4, 2)

     # Test samples.
diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index 572c39c3..15d7ebbf 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -2,10 +2,11 @@
 import torch
 from omegaconf import OmegaConf

 from sleap_nn.data.confidence_maps import ConfidenceMapGenerator
+from sleap_nn.data.general import KeyFilter
 from sleap_nn.data.instance_centroids import InstanceCentroidFinder
 from sleap_nn.data.instance_cropping import InstanceCropper
 from sleap_nn.data.normalization import Normalizer
-from sleap_nn.data.pipelines import SleapDataset, TopdownConfmapsPipeline
+from sleap_nn.data.pipelines import TopdownConfmapsPipeline
 from sleap_nn.data.providers import LabelsReader

@@ -15,17 +16,31 @@ def test_sleap_dataset(minimal_instance):
     datapipe = InstanceCentroidFinder(datapipe)
     datapipe = InstanceCropper(datapipe, (160, 160))
     datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2)
-    datapipe = SleapDataset(datapipe)
+    datapipe = KeyFilter(datapipe, keep_keys=None)
+
+    gt_sample_keys = [
+        "image",
+        "instances",
+        "centroids",
+        "instance",
+        "instance_bbox",
+        "instance_image",
+        "confidence_maps",
+    ]

     sample = next(iter(datapipe))
-    assert len(sample) == 2
-    assert sample[0].shape == (1, 160, 160)
-    assert sample[1].shape == (2, 80, 80)
+    assert len(sample.keys()) == len(gt_sample_keys)
+
+    for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
+        assert gt_key == key
+    assert sample["instance_image"].shape == (1, 160, 160)
+    assert sample["confidence_maps"].shape == (2, 80, 80)


 def test_topdownconfmapspipeline(minimal_instance):
     base_topdown_data_config = OmegaConf.create(
         {
+            "general": {"keep_keys": ["instance_image", "confidence_maps"]},
             "preprocessing": {
                 "crop_hw": (160, 160),
                 "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
@@ -62,13 +77,19 @@ def test_topdownconfmapspipeline(minimal_instance):
         data_provider=LabelsReader, filename=minimal_instance
     )

+    gt_sample_keys = ["instance_image", "confidence_maps"]
+
     sample = next(iter(datapipe))
-    assert len(sample) == 2
-    assert sample[0].shape == (1, 160, 160)
-    assert sample[1].shape == (2, 80, 80)
+    assert len(sample.keys()) == len(gt_sample_keys)
+
+    for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
+        assert gt_key == key
+    assert sample["instance_image"].shape == (1, 160, 160)
+    assert sample["confidence_maps"].shape == (2, 80, 80)

     base_topdown_data_config = OmegaConf.create(
         {
+            "general": {"keep_keys": None},
             "preprocessing": {
                 "crop_hw": (160, 160),
                 "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
@@ -105,7 +126,20 @@ def test_topdownconfmapspipeline(minimal_instance):
         data_provider=LabelsReader, filename=minimal_instance
     )

+    gt_sample_keys = [
+        "image",
+        "instances",
+        "centroids",
+        "instance",
+        "instance_bbox",
+        "instance_image",
+        "confidence_maps",
+    ]
+
     sample = next(iter(datapipe))
-    assert len(sample) == 2
-    assert sample[0].shape == (1, 160, 160)
-    assert sample[1].shape == (2, 80, 80)
+    assert len(sample.keys()) == len(gt_sample_keys)
+
+    for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
+        assert gt_key == key
+    assert sample["instance_image"].shape == (1, 160, 160)
+    assert sample["confidence_maps"].shape == (2, 80, 80)

From fe61f15f0ea3ea404dae5feb676f737d1a938201 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 12 Sep 2023 10:45:27 -0700
Subject: [PATCH 53/55] updated type hints

---
 sleap_nn/data/augmentation.py        | 4 ++--
 sleap_nn/data/confidence_maps.py     | 4 ++--
 sleap_nn/data/general.py             | 5 +++--
 sleap_nn/data/instance_centroids.py  | 4 ++--
 sleap_nn/data/instance_cropping.py   | 6 +++---
 sleap_nn/data/normalization.py       | 4 ++--
 sleap_nn/data/providers.py           | 5 +++--
 tests/data/test_instance_cropping.py | 2 +-
 8 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py
index d38d7910..6db595c1 100644
--- a/sleap_nn/data/augmentation.py
+++ b/sleap_nn/data/augmentation.py
@@ -1,5 +1,5 @@
 """This module implements data pipeline blocks for augmentation operations."""
-from typing import Any, Dict, Optional, Text, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Text, Tuple, Union

 import kornia as K
 import torch
@@ -282,7 +282,7 @@ def __init__(
             same_on_batch=True,
         )

-    def __iter__(self) -> Dict[str, torch.Tensor]:
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Return an example dictionary with the augmented image and instances."""
         for ex in self.source_dp:
             if "instance_image" in ex and "instance" in ex:
diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py
index deae59a4..dbdbf722 100644
--- a/sleap_nn/data/confidence_maps.py
+++ b/sleap_nn/data/confidence_maps.py
@@ -1,5 +1,5 @@
 """Generate confidence maps."""
-from typing import Dict, Optional
+from typing import Dict, Iterator, Optional

 import sleap_io as sio
 import torch
@@ -78,7 +78,7 @@ def __init__(
         self.instance_key = instance_key
         self.image_key = image_key

-    def __iter__(self) -> Dict[str, torch.Tensor]:
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Generate confidence maps for each example."""
         for example in self.source_dp:
             instance = example[self.instance_key]
diff --git a/sleap_nn/data/general.py b/sleap_nn/data/general.py
index 251b2bd0..9977504a 100644
--- a/sleap_nn/data/general.py
+++ b/sleap_nn/data/general.py
@@ -1,6 +1,7 @@
 """General purpose transformers for common pipeline processing tasks."""
-from typing import Callable, Dict, List, Text
+from typing import Callable, Dict, Iterator, List, Text

+import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe


@@ -12,7 +13,7 @@ def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> Non
         self.dp = source_dp
         self.keep_keys = keep_keys

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Return a dictionary filtered for the relevant outputs.

         The input dictionary includes:
diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py
index f32cba52..6240fd35 100644
--- a/sleap_nn/data/instance_centroids.py
+++ b/sleap_nn/data/instance_centroids.py
@@ -1,5 +1,5 @@
 """Handle calculation of instance centroids."""
-from typing import Dict, Optional
+from typing import Dict, Iterator, Optional

 import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -84,7 +84,7 @@ def __init__(
         self.source_dp = source_dp
         self.anchor_ind = anchor_ind

-    def __iter__(self) -> Dict[str, torch.Tensor]:
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Add `"centroids"` key to example."""
         for ex in self.source_dp:
             ex["centroids"] = find_centroids(
diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
index f53d0c73..6f57599e 100644
--- a/sleap_nn/data/instance_cropping.py
+++ b/sleap_nn/data/instance_cropping.py
@@ -1,5 +1,5 @@
 """Handle cropping of instances."""
-from typing import Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple

 import numpy as np
 import sleap_io as sio
@@ -61,12 +61,12 @@ class InstanceCropper(IterDataPipe):
         crop_hw: Height and Width of the crop in pixels
     """

-    def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]):
+    def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]) -> None:
         """Initialize InstanceCropper with the source `DataPipe."""
         self.source_dp = source_dp
         self.crop_hw = crop_hw

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Generate instance cropped examples."""
         for ex in self.source_dp:
             image = ex["image"]  # (B, channels, height, width)
diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py
index 1b6f61c7..ae7a7b6d 100644
--- a/sleap_nn/data/normalization.py
+++ b/sleap_nn/data/normalization.py
@@ -1,5 +1,5 @@
 """This module implements data pipeline blocks for normalization operations."""
-from typing import Dict
+from typing import Dict, Iterator

 import torch
 from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -22,7 +22,7 @@ def __init__(
         """Initialize the block."""
         self.source_dp = source_dp

-    def __iter__(self) -> Dict[str, torch.Tensor]:
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Return an example dictionary with the augmented image and instance."""
         for ex in self.source_dp:
             if not torch.is_floating_point(ex["image"]):
diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py
index f8ce2121..6d778102 100644
--- a/sleap_nn/data/providers.py
+++ b/sleap_nn/data/providers.py
@@ -1,5 +1,6 @@
 """This module implements pipeline blocks for reading input data such as labels."""
-from typing import Dict
+from typing import Dict, Iterator
+
 import numpy as np
 import sleap_io as sio
 import torch
@@ -27,7 +28,7 @@ def from_filename(cls, filename: str) -> "LabelsReader":
         labels = sio.load_slp(filename)
         return cls(labels)

-    def __iter__(self) -> Dict[str, torch.Tensor]:
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Return an example dictionary containing the following elements.

         "image": A torch.Tensor containing full raw frame image as a uint8 array
diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py
index 662f709a..8f112991 100644
--- a/tests/data/test_instance_cropping.py
+++ b/tests/data/test_instance_cropping.py
@@ -39,7 +39,7 @@ def test_instance_cropper(minimal_instance):
     ]

     # Test shapes.
-    assert len(sample.keys()) == 6
+    assert len(sample.keys()) == len(gt_sample_keys)
     for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
         assert gt_key == key
     assert sample["instance"].shape == (2, 2)

From 0214abb559695dd81ab675e78458c91e2efbb175 Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 12 Sep 2023 10:53:56 -0700
Subject: [PATCH 54/55] added coderabbit suggestions

---
 sleap_nn/data/general.py           | 2 +-
 sleap_nn/data/instance_cropping.py | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/sleap_nn/data/general.py b/sleap_nn/data/general.py
index 9977504a..06c9b7de 100644
--- a/sleap_nn/data/general.py
+++ b/sleap_nn/data/general.py
@@ -11,7 +11,7 @@ class KeyFilter(IterDataPipe):
     def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None:
         """Initialize KeyFilter with the source `DataPipe."""
         self.dp = source_dp
-        self.keep_keys = keep_keys
+        self.keep_keys = set(keep_keys) if keep_keys else None

     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Return a dictionary filtered for the relevant outputs.
diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
index 6f57599e..20b6855a 100644
--- a/sleap_nn/data/instance_cropping.py
+++ b/sleap_nn/data/instance_cropping.py
@@ -93,9 +93,7 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
             center_instance = instance - point

             instance_example = {
-                "instance_image": instance_image.squeeze(
-                    0
-                ),  # (B=1, channels, crop_height, crop_width)
+                "instance_image": instance_image.squeeze(),  # (B=1, channels, crop_height, crop_width)
                 "instance_bbox": instance_bbox,  # (B, 4, 2)
                 "instance": center_instance,  # (num_nodes, 2)
             }

From e3b28da3b52eaaf34e4d5af2a6c7c10509a6ecbc Mon Sep 17 00:00:00 2001
From: alckasoc
Date: Tue, 12 Sep 2023 10:56:05 -0700
Subject: [PATCH 55/55] fixed small squeeze issue

---
 sleap_nn/data/instance_cropping.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py
index 20b6855a..6f57599e 100644
--- a/sleap_nn/data/instance_cropping.py
+++ b/sleap_nn/data/instance_cropping.py
@@ -93,7 +93,9 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
             center_instance = instance - point

             instance_example = {
-                "instance_image": instance_image.squeeze(),  # (B=1, channels, crop_height, crop_width)
+                "instance_image": instance_image.squeeze(
+                    0
+                ),  # (B=1, channels, crop_height, crop_width)
                 "instance_bbox": instance_bbox,  # (B, 4, 2)
                 "instance": center_instance,  # (num_nodes, 2)
             }
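
Note on the final state of the series: after PATCH 55, training examples are produced by `TopdownConfmapsPipeline.make_training_pipeline` and filtered down to the requested keys by `KeyFilter`. A minimal end-to-end sketch follows, mirroring the inline config used in tests/data/test_pipelines.py; the `"minimal.pkg.slp"` path and the `DataLoader` wrapping are illustrative assumptions, not part of the patches:

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from sleap_nn.data.pipelines import TopdownConfmapsPipeline
from sleap_nn.data.providers import LabelsReader

# Config mirrors the structure the tests build inline (PATCH 44 removed sleap_nn/config).
data_config = OmegaConf.create(
    {
        "general": {"keep_keys": ["instance_image", "confidence_maps"]},
        "preprocessing": {
            "crop_hw": (160, 160),
            "conf_map_gen": {"sigma": 1.5, "output_stride": 2},
        },
        "augmentation_config": {
            "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)},
            "use_augmentations": False,
            "augmentations": {},  # unused while use_augmentations is False
        },
    }
)

pipeline = TopdownConfmapsPipeline(data_config=data_config)
datapipe = pipeline.make_training_pipeline(
    data_provider=LabelsReader, filename="minimal.pkg.slp"  # hypothetical path
)

# Each example is a dict holding only the keep_keys entries:
#   "instance_image" -> (1, 160, 160), "confidence_maps" -> (2, 80, 80).
# An IterDataPipe is an IterableDataset, so it can be wrapped in a DataLoader;
# the default collate function stacks each key across the batch.
loader = DataLoader(datapipe, batch_size=4)
batch = next(iter(loader))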