From ce5e6c7f94d72da7d9b27d8b459f20be69faf2ed Mon Sep 17 00:00:00 2001 From: David Samy Date: Thu, 20 Jul 2023 14:01:05 -0700 Subject: [PATCH 1/5] Add confidence maps --- sleap_nn/data/confidence_maps.py | 119 +++++++++++++++++++++++++++++++ tests/data/test_confmaps.py | 28 ++++++++ 2 files changed, 147 insertions(+) create mode 100644 sleap_nn/data/confidence_maps.py create mode 100644 tests/data/test_confmaps.py diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py new file mode 100644 index 00000000..6f744052 --- /dev/null +++ b/sleap_nn/data/confidence_maps.py @@ -0,0 +1,119 @@ +"""Generate confidence maps.""" +from torch.utils.data.datapipes.datapipe import IterDataPipe +from typing import Optional +import sleap_io as sio +import torch + + +def make_confmaps( + points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float +): + """Make confidence maps from a set of points from a single instance. + + Args: + points: A tensor of points of shape `(n_nodes, 2)` and dtype `torch.float32` where + the last axis corresponds to (x, y) pixel coordinates on the image. These + can contain NaNs to indicate missing points. + xv: Sampling grid vector for x-coordinates of shape `(grid_width,)` and dtype + `torch.float32`. This can be generated by + `sleap.nn.data.utils.make_grid_vectors`. + yv: Sampling grid vector for y-coordinates of shape `(grid_height,)` and dtype + `torch.float32`. This can be generated by + `sleap.nn.data.utils.make_grid_vectors`. + sigma: Standard deviation of the 2D Gaussian distribution sampled to generate + confidence maps. + + Returns: + Confidence maps as a tensor of shape `(grid_height, grid_width, n_nodes)` of + dtype `torch.float32`. + """ + x = torch.reshape(points[:, 0], (1, 1, -1)) + y = torch.reshape(points[:, 1], (1, 1, -1)) + cm = torch.exp( + -( + (torch.reshape(xv, (1, -1, 1)) - x) ** 2 + + (torch.reshape(yv, (-1, 1, 1)) - y) ** 2 + ) + / (2 * sigma**2) + ) + + # Replace NaNs with 0. + cm = torch.where(torch.isnan(cm), 0.0, cm) + return cm + + +def make_grid_vectors(image_height: int, image_width: int, output_stride: int): + """Make sampling grid vectors from image dimensions. + + Args: + image_height: Height of the image grid that will be sampled, specified as a + scalar integer. + image_width: width of the image grid that will be sampled, specified as a + scalar integer. + output_stride: Sampling step size, specified as a scalar integer. + + Returns: + Tuple of grid vectors (xv, yv). These are tensors of dtype torch.float32 with + shapes (grid_width,) and (grid_height,) respectively. + + The grid dimensions are calculated as: + grid_width = image_width // output_stride + grid_height = image_height // output_stride + """ + xv = torch.arange(0, image_width, step=output_stride).to( + torch.float32 + ) # (image_width,) + yv = torch.arange(0, image_height, step=output_stride).to( + torch.float32 + ) # (image_height,) + return xv, yv + + +class ConfidenceMapGenerator(IterDataPipe): + """DataPipe for generating confidence maps. + + This DataPipe will generate confidence maps for examples from the input pipeline. + Input examples must contain image of shape (frames, channels, crop_height, crop_width) + and instance of shape (n_instances, 2). + + Attributes: + source_dp: The input `IterDataPipe` with examples that contain an instance and + an image. + sigma: The standard deviation of the Gaussian distribution that is used to + generate confidence maps. + output_stride: The relative stride to use when generating confidence maps. + A larger stride will generate smaller confidence maps. + instance_key: The name of the key where the instance points are. + image_key: The name of the key where the image is. + """ + + def __init__( + self, + source_dp: IterDataPipe, + sigma: int = 1.5, + output_stride: int = 1, + instance_key: str = "instance", + image_key: str = "instance_image", + ): + """Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride.""" + self.source_dp = source_dp + self.sigma = sigma + self.output_stride = output_stride + self.instance_key = instance_key + self.image_key = image_key + + def __iter__(self): + """Generate confidence maps for each example.""" + for example in self.source_dp: + instance = example[self.instance_key] + width = example[self.image_key].shape[-1] + height = example[self.image_key].shape[-2] + + xv, yv = make_grid_vectors(height, width, self.output_stride) + + confidence_maps = make_confmaps( + instance, xv, yv, self.sigma + ) # (height, width, n_nodes) + + example["confidence_maps"] = confidence_maps + yield example diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py new file mode 100644 index 00000000..ab26cf64 --- /dev/null +++ b/tests/data/test_confmaps.py @@ -0,0 +1,28 @@ +from sleap_nn.data.providers import LabelsReader +import torch +from sleap_nn.data.instance_cropping import make_centered_bboxes, InstanceCropper +from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator + + +def test_confmaps(minimal_instance): + datapipe = LabelsReader.from_filename(minimal_instance) + datapipe = InstanceCentroidFinder(datapipe) + datapipe = Normalizer(datapipe) + datapipe = InstanceCropper(datapipe, 100, 100) + datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1) + sample = next(iter(datapipe1)) + + assert sample["confidence_maps"].shape == (100, 100, 2) + assert torch.max(sample["confidence_maps"]) == torch.Tensor( + [0.989626109600067138671875] + ) + + datapipe2 = ConfidenceMapGenerator(datapipe, sigma=3.0, output_stride=2) + sample = next(iter(datapipe2)) + + assert sample["confidence_maps"].shape == (50, 50, 2) + assert torch.max(sample["confidence_maps"]) == torch.Tensor( + [0.99739634990692138671875] + ) From 5262a6ff9443bf631f9ca888b4c0c34814154766 Mon Sep 17 00:00:00 2001 From: David Samy <96505813+davidasamy@users.noreply.github.com> Date: Fri, 4 Aug 2023 10:33:58 -0700 Subject: [PATCH 2/5] Channel first order Co-authored-by: Talmo Pereira --- sleap_nn/data/confidence_maps.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index 6f744052..f7dd8b0e 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -27,12 +27,12 @@ def make_confmaps( Confidence maps as a tensor of shape `(grid_height, grid_width, n_nodes)` of dtype `torch.float32`. """ - x = torch.reshape(points[:, 0], (1, 1, -1)) - y = torch.reshape(points[:, 1], (1, 1, -1)) + x = torch.reshape(points[:, 0], (-1, 1, 1)) + y = torch.reshape(points[:, 1], (-1, 1, 1)) cm = torch.exp( -( - (torch.reshape(xv, (1, -1, 1)) - x) ** 2 - + (torch.reshape(yv, (-1, 1, 1)) - y) ** 2 + (torch.reshape(xv, (1, 1, -1)) - x) ** 2 + + (torch.reshape(yv, (1, -1, 1)) - y) ** 2 ) / (2 * sigma**2) ) From e8cb308c168abc27c3bee7a2324643497c215053 Mon Sep 17 00:00:00 2001 From: David Samy Date: Fri, 4 Aug 2023 11:13:16 -0700 Subject: [PATCH 3/5] Add test for replacing nans. --- tests/data/test_confmaps.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index ab26cf64..a65470c4 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -3,7 +3,11 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes, InstanceCropper from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.confidence_maps import ( + ConfidenceMapGenerator, + make_confmaps, + make_grid_vectors, +) def test_confmaps(minimal_instance): @@ -14,7 +18,7 @@ def test_confmaps(minimal_instance): datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1) sample = next(iter(datapipe1)) - assert sample["confidence_maps"].shape == (100, 100, 2) + assert sample["confidence_maps"].shape == (2, 100, 100) assert torch.max(sample["confidence_maps"]) == torch.Tensor( [0.989626109600067138671875] ) @@ -22,7 +26,25 @@ def test_confmaps(minimal_instance): datapipe2 = ConfidenceMapGenerator(datapipe, sigma=3.0, output_stride=2) sample = next(iter(datapipe2)) - assert sample["confidence_maps"].shape == (50, 50, 2) + assert sample["confidence_maps"].shape == (2, 50, 50) assert torch.max(sample["confidence_maps"]) == torch.Tensor( [0.99739634990692138671875] ) + + xv, yv = make_grid_vectors(2, 2, 1) + points = torch.Tensor([[1.0, 1.0], [torch.nan, torch.nan]]) + cms = make_confmaps(points, xv, yv, 2.0) + gt = torch.Tensor( + [ + [ + [0.77880078554153442382812500000, 0.88249689340591430664062500000], + [0.88249689340591430664062500000, 1.00000000000000000000000000000], + ], + [ + [0.00000000000000000000000000000, 0.00000000000000000000000000000], + [0.00000000000000000000000000000, 0.00000000000000000000000000000], + ], + ] + ) + + assert torch.equal(gt, cms) From a3bb725e847bc8fccd1e031ebceae6d134ec81b3 Mon Sep 17 00:00:00 2001 From: David Samy Date: Mon, 7 Aug 2023 16:53:21 -0700 Subject: [PATCH 4/5] Address changes --- sleap_nn/data/confidence_maps.py | 6 +++--- tests/data/test_confmaps.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index f7dd8b0e..2c83a546 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -24,7 +24,7 @@ def make_confmaps( confidence maps. Returns: - Confidence maps as a tensor of shape `(grid_height, grid_width, n_nodes)` of + Confidence maps as a tensor of shape `(n_nodes, grid_height, grid_width)` of dtype `torch.float32`. """ x = torch.reshape(points[:, 0], (-1, 1, 1)) @@ -83,8 +83,8 @@ class ConfidenceMapGenerator(IterDataPipe): generate confidence maps. output_stride: The relative stride to use when generating confidence maps. A larger stride will generate smaller confidence maps. - instance_key: The name of the key where the instance points are. - image_key: The name of the key where the image is. + instance_key: The name of the key where the instance points (n_instances, 2) are. + image_key: The name of the key where the image (frames, channels, crop_height, crop_width) is. """ def __init__( diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index a65470c4..c93ff269 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -37,14 +37,14 @@ def test_confmaps(minimal_instance): gt = torch.Tensor( [ [ - [0.77880078554153442382812500000, 0.88249689340591430664062500000], - [0.88249689340591430664062500000, 1.00000000000000000000000000000], + [0.7788, 0.8824], + [0.8824, 1.0000], ], [ - [0.00000000000000000000000000000, 0.00000000000000000000000000000], - [0.00000000000000000000000000000, 0.00000000000000000000000000000], + [0.0000, 0.0000], + [0.0000, 0.0000], ], ] ) - assert torch.equal(gt, cms) + torch.testing.assert_close(gt, cms, atol=0.001, rtol=0.0) From 0c0061bde41dbc39e9e92e6dcf579afa3c76e010 Mon Sep 17 00:00:00 2001 From: David Samy <96505813+davidasamy@users.noreply.github.com> Date: Tue, 8 Aug 2023 13:58:50 -0700 Subject: [PATCH 5/5] Update sleap_nn/data/confidence_maps.py Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- sleap_nn/data/confidence_maps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index 2c83a546..e2b924f9 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -113,7 +113,7 @@ def __iter__(self): confidence_maps = make_confmaps( instance, xv, yv, self.sigma - ) # (height, width, n_nodes) + ) # (n_nodes, height, width) example["confidence_maps"] = confidence_maps yield example