Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Confidence Map Generation #11

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Generate confidence maps."""
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Optional
import sleap_io as sio
import torch


def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
):
"""Make confidence maps from a set of points from a single instance.

Args:
points: A tensor of points of shape `(n_nodes, 2)` and dtype `torch.float32` where
the last axis corresponds to (x, y) pixel coordinates on the image. These
can contain NaNs to indicate missing points.
xv: Sampling grid vector for x-coordinates of shape `(grid_width,)` and dtype
`torch.float32`. This can be generated by
`sleap.nn.data.utils.make_grid_vectors`.
yv: Sampling grid vector for y-coordinates of shape `(grid_height,)` and dtype
`torch.float32`. This can be generated by
`sleap.nn.data.utils.make_grid_vectors`.
sigma: Standard deviation of the 2D Gaussian distribution sampled to generate
confidence maps.

Returns:
Confidence maps as a tensor of shape `(grid_height, grid_width, n_nodes)` of
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
dtype `torch.float32`.
"""
x = torch.reshape(points[:, 0], (1, 1, -1))
y = torch.reshape(points[:, 1], (1, 1, -1))
cm = torch.exp(
-(
(torch.reshape(xv, (1, -1, 1)) - x) ** 2
+ (torch.reshape(yv, (-1, 1, 1)) - y) ** 2
)
/ (2 * sigma**2)
)
davidasamy marked this conversation as resolved.
Show resolved Hide resolved

# Replace NaNs with 0.
cm = torch.where(torch.isnan(cm), 0.0, cm)
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
return cm


def make_grid_vectors(image_height: int, image_width: int, output_stride: int):
"""Make sampling grid vectors from image dimensions.

Args:
image_height: Height of the image grid that will be sampled, specified as a
scalar integer.
image_width: width of the image grid that will be sampled, specified as a
scalar integer.
output_stride: Sampling step size, specified as a scalar integer.

Returns:
Tuple of grid vectors (xv, yv). These are tensors of dtype torch.float32 with
shapes (grid_width,) and (grid_height,) respectively.

The grid dimensions are calculated as:
grid_width = image_width // output_stride
grid_height = image_height // output_stride
"""
xv = torch.arange(0, image_width, step=output_stride).to(
torch.float32
) # (image_width,)
yv = torch.arange(0, image_height, step=output_stride).to(
torch.float32
) # (image_height,)
return xv, yv


class ConfidenceMapGenerator(IterDataPipe):
"""DataPipe for generating confidence maps.

This DataPipe will generate confidence maps for examples from the input pipeline.
Input examples must contain image of shape (frames, channels, crop_height, crop_width)
and instance of shape (n_instances, 2).

Attributes:
source_dp: The input `IterDataPipe` with examples that contain an instance and
an image.
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps.
output_stride: The relative stride to use when generating confidence maps.
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
A larger stride will generate smaller confidence maps.
instance_key: The name of the key where the instance points are.
image_key: The name of the key where the image is.
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
source_dp: IterDataPipe,
sigma: int = 1.5,
output_stride: int = 1,
instance_key: str = "instance",
image_key: str = "instance_image",
):
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride
self.instance_key = instance_key
self.image_key = image_key

def __iter__(self):
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example[self.instance_key]
width = example[self.image_key].shape[-1]
height = example[self.image_key].shape[-2]

xv, yv = make_grid_vectors(height, width, self.output_stride)

confidence_maps = make_confmaps(
instance, xv, yv, self.sigma
) # (height, width, n_nodes)
davidasamy marked this conversation as resolved.
Show resolved Hide resolved

example["confidence_maps"] = confidence_maps
yield example
28 changes: 28 additions & 0 deletions tests/data/test_confmaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from sleap_nn.data.providers import LabelsReader
import torch
from sleap_nn.data.instance_cropping import make_centered_bboxes, InstanceCropper
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.confidence_maps import ConfidenceMapGenerator


def test_confmaps(minimal_instance):
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
datapipe = Normalizer(datapipe)
datapipe = InstanceCropper(datapipe, 100, 100)
datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1)
sample = next(iter(datapipe1))

assert sample["confidence_maps"].shape == (100, 100, 2)
assert torch.max(sample["confidence_maps"]) == torch.Tensor(
[0.989626109600067138671875]
)

datapipe2 = ConfidenceMapGenerator(datapipe, sigma=3.0, output_stride=2)
sample = next(iter(datapipe2))

assert sample["confidence_maps"].shape == (50, 50, 2)
assert torch.max(sample["confidence_maps"]) == torch.Tensor(
[0.99739634990692138671875]
)