diff --git a/pyproject.toml b/pyproject.toml index 82d1cd76..e39fdc88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ Repository = "https://github.com/talmolab/sleap-nn" line-length = 88 [tool.ruff] -output-format = "github" select = [ "D", # pydocstyle ] diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index 6bffd7a5..ddaa6ead 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -127,6 +127,9 @@ class KorniaAugmenter(IterDataPipe): random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], then out_h = size[0], out_w = size[1]. random_crop_p: Probability of applying random crop. + image_key: Key of the image to augment. Use `image` (default) for full frames or + `instance_image` when the KorniaAugmenter follows the InstanceCropper. + instance_key: Key of the points to augment: `instances` (default) or `instance`. Notes: This block expects the "image" and "instances" keys to be present in the input @@ -164,6 +167,8 @@ def __init__( mixup_p: float = 0.0, random_crop_hw: Tuple[int, int] = (0, 0), random_crop_p: float = 0.0, + image_key: str = "image", + instance_key: str = "instances", ) -> None: """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp @@ -187,6 +192,8 @@ def __init__( self.mixup_p = mixup_p self.random_crop_hw = random_crop_hw self.random_crop_p = random_crop_p + self.image_key = image_key + self.instance_key = instance_key aug_stack = [] if self.affine_p > 0: @@ -282,28 +289,22 @@ def __init__( def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary with the augmented image and instances.""" for ex in self.source_dp: - if "instance_image" in ex and "instance" in ex: - inst_shape = ex["instance"].shape - # (B, channels, height, width), (1, num_nodes, 2) - image, instances = ex["instance_image"], ex["instance"].unsqueeze(0) - aug_image, aug_instances = self.augmenter(image, instances) - ex.update( - { - "instance_image": aug_image, - "instance": 
aug_instances.reshape(*inst_shape), - } - ) - elif "image" in ex and "instances" in ex: - inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2) - image, instances = ex["image"], ex["instances"].reshape( - inst_shape[0], -1, 2 - ) # (B, channels, height, width), (B, num_instances x num_nodes, 2) + inst_shape = ex[self.instance_key].shape + # Before (self.image_key="image"): (B=1, C, H, W), (B=1, num_instances, num_nodes, 2) + # or + # Before (self.image_key="instance_image"): (B=1, C, crop_H, crop_W), (B=1, num_nodes, 2) + image, instances = ex[self.image_key], ex[self.instance_key].reshape( + inst_shape[0], -1, 2 + ) # (B=1, C, H, W), (B=1, num_instances * num_nodes, 2) OR (B=1, num_nodes, 2) - aug_image, aug_instances = self.augmenter(image, instances) - ex.update( - { - "image": aug_image, - "instances": aug_instances.reshape(*inst_shape), - } - ) + aug_image, aug_instances = self.augmenter(image, instances) + ex.update( + { + self.image_key: aug_image, + self.instance_key: aug_instances.reshape(*inst_shape), + } + ) + # After (self.image_key="image"): (B=1, C, H, W), (B=1, num_instances, num_nodes, 2) + # or + # After (self.image_key="instance_image"): (B=1, C, crop_H, crop_W), (B=1, num_nodes, 2) yield ex diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index 89478f8c..c97d9e30 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -1,7 +1,6 @@ """Generate confidence maps.""" -from typing import Dict, Iterator, Optional +from typing import Dict, Iterator -import sleap_io as sio import torch from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -9,14 +8,14 @@ def make_confmaps( - points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float + points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float ) -> torch.Tensor: - """Make confidence maps from a set of points from a single instance. 
+ """Make confidence maps from a batch of points for multiple instances. Args: - points: A tensor of points of shape `(n_nodes, 2)` and dtype `torch.float32` where - the last axis corresponds to (x, y) pixel coordinates on the image. These - can contain NaNs to indicate missing points. + points_batch: A tensor of points of shape `(batch_size, n_nodes, 2)` and dtype `torch.float32` where + the last axis corresponds to (x, y) pixel coordinates on the image for each instance. + These can contain NaNs to indicate missing points. xv: Sampling grid vector for x-coordinates of shape `(grid_width,)` and dtype `torch.float32`. This can be generated by `sleap.nn.data.utils.make_grid_vectors`. @@ -27,21 +26,24 @@ def make_confmaps( confidence maps. Returns: - Confidence maps as a tensor of shape `(n_nodes, grid_height, grid_width)` of + Confidence maps as a tensor of shape `(batch_size, n_nodes, grid_height, grid_width)` of dtype `torch.float32`. """ - x = torch.reshape(points[:, 0], (-1, 1, 1)) - y = torch.reshape(points[:, 1], (-1, 1, 1)) + batch_size, n_nodes, _ = points_batch.shape + + x = torch.reshape(points_batch[:, :, 0], (batch_size, n_nodes, 1, 1)) + y = torch.reshape(points_batch[:, :, 1], (batch_size, n_nodes, 1, 1)) + + xv_reshaped = torch.reshape(xv, (1, 1, 1, -1)) + yv_reshaped = torch.reshape(yv, (1, 1, -1, 1)) + cm = torch.exp( - -( - (torch.reshape(xv, (1, 1, -1)) - x) ** 2 - + (torch.reshape(yv, (1, -1, 1)) - y) ** 2 - ) - / (2 * sigma**2) + -((xv_reshaped - x) ** 2 + (yv_reshaped - y) ** 2) / (2 * sigma**2) ) # Replace NaNs with 0. cm = torch.nan_to_num(cm) + return cm @@ -59,8 +61,8 @@ class ConfidenceMapGenerator(IterDataPipe): generate confidence maps. output_stride: The relative stride to use when generating confidence maps. A larger stride will generate smaller confidence maps. - instance_key: The name of the key where the instance points (n_instances, 2) are. 
image_key: The name of the key where the image (frames, channels, crop_height, crop_width) is. + instance_key: The name of the key where the instance points (n_instances, 2) are. """ def __init__( @@ -68,27 +70,33 @@ def __init__( source_dp: IterDataPipe, sigma: int = 1.5, output_stride: int = 1, - instance_key: str = "instance", - image_key: str = "instance_image", + image_key: str = "image", + instance_key: str = "instances", ) -> None: """Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride.""" self.source_dp = source_dp self.sigma = sigma self.output_stride = output_stride - self.instance_key = instance_key self.image_key = image_key + self.instance_key = instance_key def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Generate confidence maps for each example.""" for example in self.source_dp: instance = example[self.instance_key] + if self.instance_key == "instances": + instance = instance.view(instance.shape[0], -1, 2) + width = example[self.image_key].shape[-1] height = example[self.image_key].shape[-2] xv, yv = make_grid_vectors(height, width, self.output_stride) confidence_maps = make_confmaps( - instance, xv, yv, self.sigma + instance, + xv, + yv, + self.sigma, ) # (n_nodes, height, width) example["confidence_maps"] = confidence_maps diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index 6240fd35..f541b321 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -89,5 +89,5 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: for ex in self.source_dp: ex["centroids"] = find_centroids( ex["instances"], anchor_ind=self.anchor_ind - ) + ) # (B=1, num_instances, 2) yield ex diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 4de67b4d..f78ba65d 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -67,9 +67,9 @@ def __init__(self, source_dp: IterDataPipe, crop_hw: 
Tuple[int, int]) -> None: def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Generate instance cropped examples.""" for ex in self.source_dp: - image = ex["image"] # (B=1, channels, height, width) - instances = ex["instances"] # (B=1, n_instances, num_nodes, 2) - centroids = ex["centroids"] # (B=1, n_instances, 2) + image = ex["image"] # (B=1, C, H, W) + instances = ex["instances"] # (B=1, num_instances, num_nodes, 2) + centroids = ex["centroids"] # (B=1, num_instances, 2) for instance, centroid in zip(instances[0], centroids[0]): # Generate bounding boxes from centroid. instance_bbox = torch.unsqueeze( @@ -78,7 +78,7 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: box_size = (self.crop_hw[0], self.crop_hw[1]) - # Generate cropped image of shape (B=1, channels, crop_height, crop_width) + # Generate cropped image of shape (B=1, C, crop_H, crop_W) instance_image = crop_and_resize( image, boxes=instance_bbox, @@ -91,11 +91,9 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: center_instance = instance - point instance_example = { - "instance_image": instance_image.squeeze( - 0 - ), # (B=1, channels, crop_height, crop_width) + "instance_image": instance_image, # (B=1, C, crop_H, crop_W) "instance_bbox": instance_bbox, # (B=1, 4, 2) - "instance": center_instance, # (num_nodes, 2) + "instance": center_instance.unsqueeze(0), # (B=1, num_nodes, 2) } ex.update(instance_example) yield ex diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py index e7776222..e2e959e5 100644 --- a/sleap_nn/data/pipelines.py +++ b/sleap_nn/data/pipelines.py @@ -24,26 +24,25 @@ def __init__(self, data_config: DictConfig) -> None: """Initialize the data config.""" self.data_config = data_config - def make_training_pipeline( - self, data_provider: IterDataPipe, filename: str - ) -> IterDataPipe: + def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: """Create training pipeline with input data only. 
Args: data_provider: A `Provider` that generates data examples, typically a `LabelsReader` instance. - filename: A string path to the name of the `.slp` file. Returns: An `IterDataPipe` instance configured to produce input examples. """ - datapipe = data_provider.from_filename(filename=filename) + datapipe = data_provider datapipe = Normalizer(datapipe) if self.data_config.augmentation_config.use_augmentations: datapipe = KorniaAugmenter( datapipe, **dict(self.data_config.augmentation_config.augmentations.intensity), + image_key="image", + instance_key="instances", ) datapipe = InstanceCentroidFinder( @@ -56,19 +55,92 @@ def make_training_pipeline( datapipe, random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + image_key="instance_image", + instance_key="instance", ) if self.data_config.augmentation_config.use_augmentations: datapipe = KorniaAugmenter( datapipe, **dict(self.data_config.augmentation_config.augmentations.geometric), + image_key="instance_image", + instance_key="instance", ) datapipe = ConfidenceMapGenerator( datapipe, sigma=self.data_config.preprocessing.conf_map_gen.sigma, output_stride=self.data_config.preprocessing.conf_map_gen.output_stride, + image_key="instance_image", + instance_key="instance", + ) + datapipe = KeyFilter( + datapipe, + keep_keys=[ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ], + ) + + return datapipe + + +class SingleInstanceConfmapsPipeline: + """Pipeline builder for single-instance confidence map models. + + Attributes: + data_config: Data-related configuration. + """ + + def __init__(self, data_config: DictConfig) -> None: + """Initialize the data config.""" + self.data_config = data_config + + def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: + """Create training pipeline with input data only. 
+ + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + + Returns: + An `IterDataPipe` instance configured to produce input examples. + """ + datapipe = data_provider + datapipe = Normalizer(datapipe) + + if self.data_config.augmentation_config.use_augmentations: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.augmentations.intensity), + **dict(self.data_config.augmentation_config.augmentations.geometric), + image_key="image", + instance_key="instances", + ) + + if self.data_config.augmentation_config.random_crop.random_crop_p: + datapipe = KorniaAugmenter( + datapipe, + random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, + random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + image_key="image", + instance_key="instances", + ) + + datapipe = ConfidenceMapGenerator( + datapipe, + sigma=self.data_config.preprocessing.conf_map_gen.sigma, + output_stride=self.data_config.preprocessing.conf_map_gen.output_stride, + image_key="image", + instance_key="instances", + ) + datapipe = KeyFilter( + datapipe, keep_keys=["image", "instances", "confidence_maps"] ) - datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys) return datapipe diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index bfafd42f..35fe5b87 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -18,6 +18,7 @@ class LabelsReader(IterDataPipe): labels: sleap_io.Labels object that contains LabeledFrames that will be accessed through a torchdata DataPipe user_instances_only: True if filter labels only to user instances else False. 
Default value True + """ def __init__(self, labels: sio.Labels, user_instances_only: bool = True): diff --git a/sleap_nn/evaluation.py b/sleap_nn/evaluation.py index c7072724..26388348 100644 --- a/sleap_nn/evaluation.py +++ b/sleap_nn/evaluation.py @@ -16,7 +16,7 @@ class MatchInstance: def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]: - """Function to get a list of instances of type MatchInstance from the Labeled Frame. + """Get a list of instances of type MatchInstance from the Labeled Frame. Args: labeled_frame: Input Labeled frame of type sio.LabeledFrame. @@ -555,7 +555,7 @@ def voc_metrics( } def mOKS(self): - """Returns the meanOKS value.""" + """Return the meanOKS value.""" pair_oks = np.array([oks for _, _, oks in self.positive_pairs]) return {"mOKS": pair_oks.mean()} diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index f7821e6d..fa3ad2e1 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -13,21 +13,33 @@ def test_confmaps(minimal_instance): datapipe = InstanceCentroidFinder(datapipe) datapipe = Normalizer(datapipe) datapipe = InstanceCropper(datapipe, (100, 100)) - datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1) + datapipe1 = ConfidenceMapGenerator( + datapipe, + sigma=1.5, + output_stride=1, + image_key="instance_image", + instance_key="instance", + ) sample = next(iter(datapipe1)) - assert sample["confidence_maps"].shape == (2, 100, 100) + assert sample["confidence_maps"].shape == (1, 2, 100, 100) assert torch.max(sample["confidence_maps"]) == torch.Tensor([0.9479378461837769]) - datapipe2 = ConfidenceMapGenerator(datapipe, sigma=3.0, output_stride=2) + datapipe2 = ConfidenceMapGenerator( + datapipe, + sigma=3.0, + output_stride=2, + image_key="instance_image", + instance_key="instance", + ) sample = next(iter(datapipe2)) - assert sample["confidence_maps"].shape == (2, 50, 50) + assert sample["confidence_maps"].shape == (1, 2, 50, 50) assert 
torch.max(sample["confidence_maps"]) == torch.Tensor([0.9867223501205444]) xv, yv = make_grid_vectors(2, 2, 1) points = torch.Tensor([[1.0, 1.0], [torch.nan, torch.nan]]) - cms = make_confmaps(points, xv, yv, 2.0) + cms = make_confmaps(points.unsqueeze(0), xv, yv, 2.0) gt = torch.Tensor( [ [ @@ -41,4 +53,4 @@ def test_confmaps(minimal_instance): ] ) - torch.testing.assert_close(gt, cms, atol=0.001, rtol=0.0) + torch.testing.assert_close(gt.unsqueeze(0), cms, atol=0.001, rtol=0.0) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 35518856..a5b31bdc 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -44,8 +44,8 @@ def test_instance_cropper(minimal_instance): assert len(sample.keys()) == len(gt_sample_keys) for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): assert gt_key == key - assert sample["instance"].shape == (2, 2) - assert sample["instance_image"].shape == (1, 100, 100) + assert sample["instance"].shape == (1, 2, 2) + assert sample["instance_image"].shape == (1, 1, 100, 100) assert sample["instance_bbox"].shape == (1, 4, 2) # Test samples. 
@@ -56,4 +56,4 @@ def test_instance_cropper(minimal_instance): ] ) centered_instance = sample["instance"] - assert torch.equal(centered_instance, gt) + assert torch.equal(centered_instance, gt.unsqueeze(0)) diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py index d80294a3..9504c604 100644 --- a/tests/data/test_pipelines.py +++ b/tests/data/test_pipelines.py @@ -1,12 +1,16 @@ -import torch from omegaconf import OmegaConf +import sleap_io as sio + from sleap_nn.data.confidence_maps import ConfidenceMapGenerator from sleap_nn.data.general import KeyFilter from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.pipelines import TopdownConfmapsPipeline +from sleap_nn.data.pipelines import ( + TopdownConfmapsPipeline, + SingleInstanceConfmapsPipeline, +) from sleap_nn.data.providers import LabelsReader @@ -15,7 +19,13 @@ def test_key_filter(minimal_instance): datapipe = Normalizer(datapipe) datapipe = InstanceCentroidFinder(datapipe) datapipe = InstanceCropper(datapipe, (160, 160)) - datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2) + datapipe = ConfidenceMapGenerator( + datapipe, + sigma=1.5, + output_stride=2, + image_key="instance_image", + instance_key="instance", + ) datapipe = KeyFilter(datapipe, keep_keys=None) gt_sample_keys = [ @@ -35,8 +45,8 @@ def test_key_filter(minimal_instance): for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): assert gt_key == key - assert sample["instance_image"].shape == (1, 160, 160) - assert sample["confidence_maps"].shape == (2, 80, 80) + assert sample["instance_image"].shape == (1, 1, 160, 160) + assert sample["confidence_maps"].shape == (1, 2, 80, 80) assert sample["frame_idx"] == 0 assert sample["video_idx"] == 0 @@ -44,14 +54,6 @@ def test_key_filter(minimal_instance): def test_topdownconfmapspipeline(minimal_instance): 
base_topdown_data_config = OmegaConf.create( { - "general": { - "keep_keys": [ - "instance_image", - "confidence_maps", - "video_idx", - "frame_idx", - ] - }, "preprocessing": { "anchor_ind": None, "crop_hw": (160, 160), @@ -89,23 +91,28 @@ def test_topdownconfmapspipeline(minimal_instance): ) pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) - datapipe = pipeline.make_training_pipeline( - data_provider=LabelsReader, filename=minimal_instance - ) - - gt_sample_keys = ["instance_image", "confidence_maps", "video_idx", "frame_idx"] + data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) + datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ] sample = next(iter(datapipe)) assert len(sample.keys()) == len(gt_sample_keys) for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): assert gt_key == key - assert sample["instance_image"].shape == (1, 160, 160) - assert sample["confidence_maps"].shape == (2, 80, 80) + assert sample["instance_image"].shape == (1, 1, 160, 160) + assert sample["confidence_maps"].shape == (1, 2, 80, 80) base_topdown_data_config = OmegaConf.create( { - "general": {"keep_keys": None}, "preprocessing": { "anchor_ind": None, "crop_hw": (160, 160), @@ -143,9 +150,8 @@ def test_topdownconfmapspipeline(minimal_instance): ) pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) - datapipe = pipeline.make_training_pipeline( - data_provider=LabelsReader, filename=minimal_instance - ) + data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) + datapipe = pipeline.make_training_pipeline(data_provider=data_provider) gt_sample_keys = [ "image", @@ -155,8 +161,6 @@ def test_topdownconfmapspipeline(minimal_instance): "instance_bbox", "instance_image", "confidence_maps", - "video_idx", - "frame_idx", ] sample = 
next(iter(datapipe)) @@ -164,5 +168,123 @@ def test_topdownconfmapspipeline(minimal_instance): for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): assert gt_key == key - assert sample["instance_image"].shape == (1, 160, 160) - assert sample["confidence_maps"].shape == (2, 80, 80) + assert sample["instance_image"].shape == (1, 1, 160, 160) + assert sample["confidence_maps"].shape == (1, 2, 80, 80) + + +def test_singleinstanceconfmapspipeline(minimal_instance): + labels = sio.load_slp(minimal_instance) + + # Making our minimal 2-instance example into a single instance example. + for lf in labels: + lf.instances = lf.instances[:1] + + base_singleinstance_data_config = OmegaConf.create( + { + "preprocessing": { + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": { + "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, + "use_augmentations": False, + "augmentations": { + "intensity": { + "uniform_noise": (0.0, 0.04), + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast": (0.5, 2.0), + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate": (0.02, 0.02), + "affine_p": 0.5, + "erase_scale": (0.0001, 0.01), + "erase_ratio": (1, 1), + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, + }, + }, + }, + } + ) + + pipeline = SingleInstanceConfmapsPipeline( + data_config=base_singleinstance_data_config + ) + data_provider = LabelsReader(labels=labels) + datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + + sample = next(iter(datapipe)) + + gt_sample_keys = [ + "image", + "instances", + "confidence_maps", + ] + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["image"].shape == (1, 1, 384, 384) + assert sample["confidence_maps"].shape == (1, 2, 192, 192) + + 
base_singleinstance_data_config = OmegaConf.create( + { + "preprocessing": { + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": { + "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, + "use_augmentations": True, + "augmentations": { + "intensity": { + "uniform_noise": (0.0, 0.04), + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast": (0.5, 2.0), + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate": (0.02, 0.02), + "affine_p": 0.5, + "erase_scale": (0.0001, 0.01), + "erase_ratio": (1, 1), + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, + }, + }, + }, + } + ) + + pipeline = SingleInstanceConfmapsPipeline( + data_config=base_singleinstance_data_config + ) + data_provider = LabelsReader(labels=labels) + datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + + sample = next(iter(datapipe)) + + gt_sample_keys = [ + "image", + "instances", + "confidence_maps", + ] + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["image"].shape == (1, 1, 160, 160) + assert sample["confidence_maps"].shape == (1, 2, 80, 80) diff --git a/tests/data/test_providers.py b/tests/data/test_providers.py index f49ee14a..badfc601 100644 --- a/tests/data/test_providers.py +++ b/tests/data/test_providers.py @@ -11,71 +11,3 @@ def test_providers(minimal_instance): instances, image = sample["instances"], sample["image"] assert image.shape == torch.Size([1, 1, 384, 384]) assert instances.shape == torch.Size([1, 2, 2, 2]) - - -def test_filter_user_instances(minimal_instance): - # Create sample Labels object. - - # Create skeleton. - skeleton = sio.Skeleton( - nodes=["head", "thorax", "abdomen"], - edges=[("head", "thorax"), ("thorax", "abdomen")], - ) - - # Get video. 
- min_labels = sio.load_slp(minimal_instance) - video = min_labels.videos[0] - - # Create user labelled instance. - user_inst = sio.Instance.from_numpy( - points=np.array( - [ - [11.4, 13.4], - [13.6, 15.1], - [0.3, 9.3], - ] - ), - skeleton=skeleton, - ) - - # Create Predicted Instance. - pred_inst = sio.PredictedInstance.from_numpy( - points=np.array( - [ - [10.2, 20.4], - [5.8, 15.1], - [0.3, 10.6], - ] - ), - skeleton=skeleton, - point_scores=np.array([0.5, 0.6, 0.8]), - instance_score=0.6, - ) - - # Create labeled frame. - user_lf = sio.LabeledFrame( - video=video, frame_idx=0, instances=[user_inst, pred_inst] - ) - pred_lf = sio.LabeledFrame(video=video, frame_idx=0, instances=[pred_inst]) - - # Create labels. - labels = sio.Labels( - videos=[video], skeletons=[skeleton], labeled_frames=[user_lf, pred_lf] - ) - - l = LabelsReader(labels, user_instances_only=True) - - # Check user instance filtering. - assert len(list(l)) == 1 - lf = next(iter(l)) - assert len(torch.squeeze(lf["instances"], dim=0)) == 1 - - l = LabelsReader(labels, user_instances_only=False) - assert len(list(l)) == 2 - lf = next(iter(l)) - assert len(torch.squeeze(lf["instances"], dim=0)) == 2 - - # Test with only Predicted instance. - labels = sio.Labels(videos=[video], skeletons=[skeleton], labeled_frames=[pred_lf]) - l = LabelsReader(labels, user_instances_only=True) - assert len(list(l)) == 0