From 1c51d26e85a22d4d32219e29a7cb0aad361754fc Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Wed, 25 Sep 2024 23:40:12 -0700
Subject: [PATCH 01/11] Add litdata to trainer

---
 tests/training/test_model_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py
index 898ba955..325e3a10 100644
--- a/tests/training/test_model_trainer.py
+++ b/tests/training/test_model_trainer.py
@@ -46,14 +46,14 @@ def test_create_data_loader(config, tmp_path: str):
     shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
     shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())

-    # test exception
-    config_copy = config.copy()
-    head_config = config_copy.model_config.head_configs.centered_instance
-    del config_copy.model_config.head_configs.centered_instance
-    OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config)
-    model_trainer = ModelTrainer(config_copy)
-    with pytest.raises(Exception):
-        model_trainer._create_data_loaders()
+# # test exception
+# config_copy = config.copy()
+# head_config = config_copy.model_config.head_configs.centered_instance
+# del config_copy.model_config.head_configs.centered_instance
+# OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config)
+# model_trainer = ModelTrainer(config_copy)
+# with pytest.raises(Exception):
+#     model_trainer._create_data_loaders()


 def test_wandb():

From d7240412fc42acdc6e974433399e69a0ceb8b0d6 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Thu, 26 Sep 2024 09:16:48 -0700
Subject: [PATCH 02/11] Add tests for data loaders

---
 tests/training/test_model_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py
index 325e3a10..491f3329 100644
--- a/tests/training/test_model_trainer.py
+++ b/tests/training/test_model_trainer.py
@@ -88,8 +88,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str):
     assert not (
         Path(config.trainer_config.save_ckpt_path).joinpath("best.ckpt").exists()
     )
-    shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
-    shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())
+    # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
+    # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())

     # update save_ckpt to True
     OmegaConf.update(config, "trainer_config.save_ckpt", True)
@@ -162,8 +162,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str):
     assert abs(df.loc[0, "learning_rate"] - config.trainer_config.optimizer.lr) <= 1e-4
     assert not df.val_loss.isnull().all()
     assert not df.train_loss.isnull().all()
-    shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
-    shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())
+    # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
+    # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())

     # check resume training
     config_copy = config.copy()
@@ -185,8 +185,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str):
         Path(config_copy.trainer_config.save_ckpt_path).joinpath("best.ckpt")
     )
     assert checkpoint["epoch"] == 3
-    shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
-    shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())
+    #
shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) training_config = OmegaConf.load( f"{config_copy.trainer_config.save_ckpt_path}/training_config.yaml" @@ -214,8 +214,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_early_stopping.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 1 - shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # For Single instance model single_instance_config = config.copy() From 65407acc8422e49d6ef31b8cc9a1e95b82c64f9a Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 26 Sep 2024 09:30:45 -0700 Subject: [PATCH 03/11] Fix tests --- tests/training/test_model_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 491f3329..325e3a10 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -88,8 +88,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert not ( Path(config.trainer_config.save_ckpt_path).joinpath("best.ckpt").exists() ) - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # update save_ckpt to True OmegaConf.update(config, "trainer_config.save_ckpt", True) @@ -162,8 +162,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert abs(df.loc[0, "learning_rate"] - config.trainer_config.optimizer.lr) <= 1e-4 assert not df.val_loss.isnull().all() assert not df.train_loss.isnull().all() - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # check resume training config_copy = config.copy() @@ -185,8 +185,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_copy.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 3 - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) training_config = OmegaConf.load( f"{config_copy.trainer_config.save_ckpt_path}/training_config.yaml" @@ -214,8 +214,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_early_stopping.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 1 - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) 
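# The two rmtree calls around this point recur after every training run in these
# tests; a small helper along these lines (hypothetical, not part of this patch)
# could express the litdata chunk-directory cleanup in one place:
#
#     import shutil
#     from pathlib import Path
#
#     def cleanup_chunk_dirs(dir_path: str) -> None:
#         """Remove the train/val chunk directories written by litdata."""
#         for name in ("train_chunks", "val_chunks"):
#             chunk_dir = Path(dir_path) / name
#             if chunk_dir.exists():
#                 shutil.rmtree(chunk_dir.as_posix())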
+ shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # For Single instance model single_instance_config = config.copy() From 38419557d2ed75caf9346efb870a49cfc60b8d8b Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 26 Sep 2024 09:43:23 -0700 Subject: [PATCH 04/11] Remove files in trainer --- tests/training/test_model_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 325e3a10..491f3329 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -88,8 +88,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert not ( Path(config.trainer_config.save_ckpt_path).joinpath("best.ckpt").exists() ) - shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # update save_ckpt to True OmegaConf.update(config, "trainer_config.save_ckpt", True) @@ -162,8 +162,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert abs(df.loc[0, "learning_rate"] - config.trainer_config.optimizer.lr) <= 1e-4 assert not df.val_loss.isnull().all() assert not df.train_loss.isnull().all() - shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # check resume training config_copy = config.copy() @@ -185,8 +185,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_copy.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 3 - shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) training_config = OmegaConf.load( f"{config_copy.trainer_config.save_ckpt_path}/training_config.yaml" @@ -214,8 +214,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_early_stopping.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 1 - shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # For Single instance model single_instance_config = config.copy() From 493ebbc34132261a97b050da6fcda5c173f3d74b Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 26 Sep 2024 10:43:46 -0700 Subject: [PATCH 05/11] Remove shutil.rmtree --- tests/training/test_model_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 491f3329..325e3a10 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -88,8 +88,8 @@ def test_trainer(config, 
tmp_path: str, minimal_instance_bottomup_ckpt: str): assert not ( Path(config.trainer_config.save_ckpt_path).joinpath("best.ckpt").exists() ) - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # update save_ckpt to True OmegaConf.update(config, "trainer_config.save_ckpt", True) @@ -162,8 +162,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert abs(df.loc[0, "learning_rate"] - config.trainer_config.optimizer.lr) <= 1e-4 assert not df.val_loss.isnull().all() assert not df.train_loss.isnull().all() - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # check resume training config_copy = config.copy() @@ -185,8 +185,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_copy.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 3 - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) training_config = OmegaConf.load( f"{config_copy.trainer_config.save_ckpt_path}/training_config.yaml" @@ -214,8 +214,8 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_early_stopping.trainer_config.save_ckpt_path).joinpath("best.ckpt") ) assert checkpoint["epoch"] == 1 - # shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) - # shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix()) + shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix()) # For Single instance model single_instance_config = config.copy() From 676f28ff3d3598f32c20c7e53de3b0bccf05953b Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 27 Sep 2024 09:47:14 -0700 Subject: [PATCH 06/11] Remove iterdatapipe in inference --- sleap_nn/data/providers.py | 302 +++++++----- sleap_nn/inference/predictors.py | 448 +++++++++--------- .../minimal_instance/training_config.yaml | 2 +- tests/inference/test_bottomup.py | 26 +- tests/inference/test_predictors.py | 3 +- tests/inference/test_single_instance.py | 39 +- 6 files changed, 443 insertions(+), 377 deletions(-) diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index b15eeb8a..70c27d86 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -93,125 +93,125 @@ def process_lf( return ex -class LabelsReader(IterDataPipe): - """IterDataPipe for reading frames from Labels object. - - This IterDataPipe will produce examples containing a frame and an sleap_io.Instance - from a sleap_io.Labels instance. - - Attributes: - labels: sleap_io.Labels object that contains LabeledFrames that will be - accessed through a torchdata DataPipe. - user_instances_only: True if filter labels only to user instances else False. 
- Default value True - instances_key: True if `instances` key needs to be present in the data pipeline. - When this is set to True, the instances are appended with NaNs to have same - number of instances to enable batching. Default: False. - """ - - def __init__( - self, - labels: sio.Labels, - user_instances_only: bool = True, - instances_key: bool = True, - ): - """Initialize labels attribute of the class.""" - self.labels = copy.deepcopy(labels) - self.max_instances = get_max_instances(labels) - self.instances_key = instances_key - - # Filter to user instances - if user_instances_only: - filtered_lfs = [] - for lf in self.labels: - if lf.user_instances is not None and len(lf.user_instances) > 0: - lf.instances = lf.user_instances - filtered_lfs.append(lf) - self.labels = sio.Labels( - videos=self.labels.videos, - skeletons=self.labels.skeletons, - labeled_frames=filtered_lfs, - ) - - @property - def edge_inds(self) -> list: - """Returns list of edge indices.""" - return self.labels.skeletons[0].edge_inds - - @property - def max_height_and_width(self) -> Tuple[int, int]: - """Return `(height, width)` that is the maximum of all videos.""" - return max(video.shape[1] for video in self.labels.videos), max( - video.shape[2] for video in self.labels.videos - ) - - @classmethod - def from_filename( - cls, - filename: str, - user_instances_only: bool = True, - instances_key: bool = True, - ): - """Create LabelsReader from a .slp filename.""" - labels = sio.load_slp(filename) - return cls(labels, user_instances_only, instances_key) - - def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - """Return an example dictionary containing the following elements. - - "image": A torch.Tensor containing full raw frame image as a uint8 array - of shape (n_samples, channels, height, width). - "instances": Keypoint coordinates for all instances in the frame as a - float32 torch.Tensor of shape (n_samples, n_instances, n_nodes, 2). - """ - for lf in self.labels: - image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW - - instances = [] - for inst in lf: - if not inst.is_empty: - instances.append(inst.numpy()) - instances = np.stack(instances, axis=0) - - # Add singleton time dimension for single frames. - image = np.expand_dims(image, axis=0) # (1, C, H, W) - img_height, img_width = image.shape[-2:] - instances = np.expand_dims( - instances, axis=0 - ) # (1, num_instances, num_nodes, 2) - - instances = torch.from_numpy(instances.astype("float32")) - num_instances, nodes = instances.shape[1:3] - ex = { - "image": torch.from_numpy(image), - "video_idx": torch.tensor( - self.labels.videos.index(lf.video), dtype=torch.int32 - ), - "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32), - "num_instances": num_instances, - } - ex["orig_size"] = torch.Tensor([img_height, img_width]) - - if self.instances_key: - nans = torch.full( - (1, np.abs(self.max_instances - num_instances), nodes, 2), torch.nan - ) - ex["instances"] = torch.cat([instances, nans], dim=1) - - yield ex +# class LabelsReader(IterDataPipe): +# """IterDataPipe for reading frames from Labels object. + +# This IterDataPipe will produce examples containing a frame and an sleap_io.Instance +# from a sleap_io.Labels instance. + +# Attributes: +# labels: sleap_io.Labels object that contains LabeledFrames that will be +# accessed through a torchdata DataPipe. +# user_instances_only: True if filter labels only to user instances else False. +# Default value True +# instances_key: True if `instances` key needs to be present in the data pipeline. 
+# When this is set to True, the instances are appended with NaNs to have same +# number of instances to enable batching. Default: False. +# """ + +# def __init__( +# self, +# labels: sio.Labels, +# user_instances_only: bool = True, +# instances_key: bool = True, +# ): +# """Initialize labels attribute of the class.""" +# self.labels = copy.deepcopy(labels) +# self.max_instances = get_max_instances(labels) +# self.instances_key = instances_key + +# # Filter to user instances +# if user_instances_only: +# filtered_lfs = [] +# for lf in self.labels: +# if lf.user_instances is not None and len(lf.user_instances) > 0: +# lf.instances = lf.user_instances +# filtered_lfs.append(lf) +# self.labels = sio.Labels( +# videos=self.labels.videos, +# skeletons=self.labels.skeletons, +# labeled_frames=filtered_lfs, +# ) + +# @property +# def edge_inds(self) -> list: +# """Returns list of edge indices.""" +# return self.labels.skeletons[0].edge_inds + +# @property +# def max_height_and_width(self) -> Tuple[int, int]: +# """Return `(height, width)` that is the maximum of all videos.""" +# return max(video.shape[1] for video in self.labels.videos), max( +# video.shape[2] for video in self.labels.videos +# ) + +# @classmethod +# def from_filename( +# cls, +# filename: str, +# user_instances_only: bool = True, +# instances_key: bool = True, +# ): +# """Create LabelsReader from a .slp filename.""" +# labels = sio.load_slp(filename) +# return cls(labels, user_instances_only, instances_key) + +# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: +# """Return an example dictionary containing the following elements. + +# "image": A torch.Tensor containing full raw frame image as a uint8 array +# of shape (n_samples, channels, height, width). +# "instances": Keypoint coordinates for all instances in the frame as a +# float32 torch.Tensor of shape (n_samples, n_instances, n_nodes, 2). +# """ +# for lf in self.labels: +# image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW + +# instances = [] +# for inst in lf: +# if not inst.is_empty: +# instances.append(inst.numpy()) +# instances = np.stack(instances, axis=0) + +# # Add singleton time dimension for single frames. +# image = np.expand_dims(image, axis=0) # (1, C, H, W) +# img_height, img_width = image.shape[-2:] +# instances = np.expand_dims( +# instances, axis=0 +# ) # (1, num_instances, num_nodes, 2) + +# instances = torch.from_numpy(instances.astype("float32")) +# num_instances, nodes = instances.shape[1:3] +# ex = { +# "image": torch.from_numpy(image), +# "video_idx": torch.tensor( +# self.labels.videos.index(lf.video), dtype=torch.int32 +# ), +# "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32), +# "num_instances": num_instances, +# } +# ex["orig_size"] = torch.Tensor([img_height, img_width]) + +# if self.instances_key: +# nans = torch.full( +# (1, np.abs(self.max_instances - num_instances), nodes, 2), torch.nan +# ) +# ex["instances"] = torch.cat([instances, nans], dim=1) + +# yield ex class VideoReader(Thread): """Thread module for reading frames from sleap-io Video object. This module will load the frames from video and pushes them as Tensors into a buffer - queue as a tuple in the format (image, frame index, (height, width)) which are then - batched and consumed during the inference process. + queue as a tuple in the format (image, frame index, video index, (height, width)) + which are then batched and consumed during the inference process. 
Attributes: - video: sleap_io.Video object that contains LabeledFrames that will be + video: sleap_io.Video object that contains images that will be accessed through a torchdata DataPipe. - frame_buffer: Maximum size of the frame buffer queue. + frame_buffer: Frame buffer queue. start_idx: start index of the frames to read. If None, 0 is set as the default. end_idx: end index of the frames to read. If None, length of the video is set as the default. @@ -248,12 +248,13 @@ def max_height_and_width(self) -> Tuple[int, int]: def from_filename( cls, filename: str, - frame_buffer: Queue, + queue_maxsize: int, start_idx: Optional[int] = None, end_idx: Optional[int] = None, ): """Create LabelsReader from a .slp filename.""" video = sio.load_video(filename) + frame_buffer = Queue(maxsize=queue_maxsize) return cls(video, frame_buffer, start_idx, end_idx) def run(self): @@ -266,9 +267,10 @@ def run(self): self.frame_buffer.put( ( - torch.from_numpy(img), - torch.tensor(idx, dtype=torch.int32), - torch.Tensor(img.shape[-2:]), + torch.from_numpy(img), # img + torch.tensor(idx, dtype=torch.int32), # frame idx + torch.tensor(0, dtype=torch.int32), # video idx + torch.Tensor(img.shape[-2:]), # orig shape ) ) @@ -277,3 +279,75 @@ def run(self): finally: self.frame_buffer.put((None, None, None)) + + +class LabelsReader(Thread): + """Thread module for reading images from sleap-io Labels object. + + This module will load the images from `.slp` files and pushes them as Tensors into a + buffer queue as a tuple in the format (image, frame index, video index, (height, width)) + which are then batched and consumed during the inference process. + + Attributes: + labels: sleap_io.Labels object that contains LabeledFrames that will be + accessed through a torchdata DataPipe. + frame_buffer: Frame buffer queue. + """ + + def __init__( + self, + labels: sio.Labels, + frame_buffer: Queue, + ): + """Initialize attribute of the class.""" + super().__init__() + self.labels = labels + self.frame_buffer = frame_buffer + + def total_len(self): + """Returns the total number of frames in the video.""" + return len(self.labels) + + @property + def max_height_and_width(self) -> Tuple[int, int]: + """Return `(height, width)` of frames in the video.""" + return max(video.shape[1] for video in self.labels.videos), max( + video.shape[2] for video in self.labels.videos + ) + + @classmethod + def from_filename( + cls, + filename: str, + queue_maxsize: int, + ): + """Create LabelsReader from a .slp filename.""" + labels = sio.load_slp(filename) + frame_buffer = Queue(maxsize=queue_maxsize) + return cls(labels, frame_buffer) + + def run(self): + """Adds frames to the buffer queue.""" + try: + for idx in range(self.total_len()): + lf = self.labels[idx] + img = lf.image + img = np.transpose(img, (2, 0, 1)) # convert H,W,C to C,H,W + img = np.expand_dims(img, axis=0) # (1, C, H, W) + + self.frame_buffer.put( + ( + torch.from_numpy(img), # img + torch.tensor(idx, dtype=torch.int32), # frame idx + torch.tensor( + self.labels.videos.index(lf.video), dtype=torch.int32 + ), # video idx + torch.Tensor(img.shape[-2:]), # orig shape + ) + ) + + except Exception as e: + print(f"Error when reading labelled frame. 
Stopping labels reader.\n{e}") + + finally: + self.frame_buffer.put((None, None, None, None)) diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 2f591e1e..deb7dfd2 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -10,20 +10,17 @@ import torch import attrs import lightning as L -from torch.utils.data.dataloader import DataLoader +import litdata as ld from omegaconf import OmegaConf from sleap_nn.data.providers import LabelsReader, VideoReader -from sleap_nn.data.resizing import ( - SizeMatcher, - Resizer, - PadToStride, - resize_image, - apply_pad_to_stride, +from sleap_nn.data.resizing import resize_image, apply_pad_to_stride, apply_sizematcher +from sleap_nn.data.normalization import ( + apply_normalization, + convert_to_grayscale, + convert_to_rgb, ) -from sleap_nn.data.normalization import Normalizer, convert_to_grayscale, convert_to_rgb -from sleap_nn.data.instance_centroids import InstanceCentroidFinder -from sleap_nn.data.instance_cropping import InstanceCropper -from sleap_nn.data.general import KeyFilter +from sleap_nn.data.instance_centroids import generate_centroids +from sleap_nn.data.instance_cropping import generate_crops from sleap_nn.inference.paf_grouping import PAFScorer from sleap_nn.training.model_trainer import ( TopDownCenteredInstanceModel, @@ -53,7 +50,7 @@ class Predictor(ABC): preprocess: Only for VideoReader provider. True if preprocessing (reszizing and apply_pad_to_stride) should be applied on the frames read in the video reader. Default: True. - video_preprocess_config: Preprocessing config for VideoReader with keys: [`batch_size`, + preprocess_config: Preprocessing config with keys: [`batch_size`, `scale`, `is_rgb`, `max_stride`]. Default: {"batch_size": 4, "scale": 1.0, "is_rgb": False, "max_stride": 1} provider: Provider for inference pipeline. One of ["LabelsReader", "VideoReader"]. @@ -66,14 +63,14 @@ class Predictor(ABC): """ preprocess: bool = True - video_preprocess_config: dict = { + preprocess_config: dict = { "batch_size": 4, "scale": 1.0, "is_rgb": False, "max_stride": 1, } provider: Union[LabelsReader, VideoReader] = LabelsReader - pipeline: Optional[Union[DataLoader, VideoReader]] = None + pipeline: Optional[Union[LabelsReader, VideoReader]] = None inference_model: Optional[ Union[ TopDownInferenceModel, SingleInstanceInferenceModel, BottomUpInferenceModel @@ -199,7 +196,14 @@ def data_config(self) -> OmegaConf: """Get the data parameters from the config.""" @abstractmethod - def make_pipeline(self, provider: str, data_path: str): + def make_pipeline( + self, + provider: str, + data_path: str, + queue_maxsize: int = 8, + video_start_idx=None, + video_end_idx=None, + ): """Create the data pipeline.""" @abstractmethod @@ -233,63 +237,52 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: self._initialize_inference_model() # Loop over data batches. 
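# The loop below consumes the provider thread's frame buffer directly. The
# producer/consumer protocol introduced in this patch is, schematically (a
# sketch, not part of the diff):
#
#   producer (VideoReader / LabelsReader thread):
#       frame_buffer.put((image, frame_idx, video_idx, orig_size))
#       ...
#       frame_buffer.put((None, None, None, None))  # end-of-stream sentinel
#
#   consumer (Predictor._predict_generator):
#       frame = self.pipeline.frame_buffer.get()
#       if frame[0] is None:   # sentinel reached; stop batching
#           done = True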
- if self.provider == "LabelsReader": - for ex in self.pipeline: + self.pipeline.start() + batch_size = self.preprocess_config["batch_size"] + done = False + while not done: + imgs = [] + fidxs = [] + vidxs = [] + org_szs = [] + for _ in range(batch_size): + frame = self.pipeline.frame_buffer.get() + if frame[0] is None: + done = True + break + imgs.append(frame[0].unsqueeze(dim=0)) + fidxs.append(frame[1]) + vidxs.append(frame[2]) + org_szs.append(frame[3].unsqueeze(dim=0)) + if imgs: + imgs = torch.concatenate(imgs, dim=0) + fidxs = torch.tensor(fidxs, dtype=torch.int32) + vidxs = torch.tensor(vidxs, dtype=torch.int32) + org_szs = torch.concatenate(org_szs, dim=0) + ex = { + "image": imgs, + "frame_idx": fidxs, + "video_idx": vidxs, + "orig_size": org_szs, + } + ex["image"] = apply_normalization(ex["image"]) + if self.preprocess_config["is_rgb"]: + ex["image"] = convert_to_rgb(ex["image"]) + else: + ex["image"] = convert_to_grayscale(ex["image"]) + if self.preprocess: + scale = self.preprocess_config["scale"] + if scale != 1.0: + ex["image"] = resize_image(ex["image"], scale) + ex["image"] = apply_pad_to_stride( + ex["image"], self.preprocess_config["max_stride"] + ) outputs_list = self.inference_model(ex) for output in outputs_list: output = self._convert_tensors_to_numpy(output) yield output - elif self.provider == "VideoReader": - # try: - self.pipeline.start() - batch_size = self.video_preprocess_config["batch_size"] - done = False - while not done: - imgs = [] - fidxs = [] - org_szs = [] - for _ in range(batch_size): - frame = self.pipeline.frame_buffer.get() - if frame[0] is None: - done = True - break - imgs.append(frame[0].unsqueeze(dim=0)) - fidxs.append(frame[1]) - org_szs.append(frame[2].unsqueeze(dim=0)) - if imgs: - imgs = torch.concatenate(imgs, dim=0) - fidxs = torch.tensor(fidxs, dtype=torch.int32) - org_szs = torch.concatenate(org_szs, dim=0) - ex = { - "image": imgs, - "frame_idx": fidxs, - "video_idx": torch.tensor([0] * batch_size, dtype=torch.int32), - "orig_size": org_szs, - } - if not torch.is_floating_point(ex["image"]): # normalization - ex["image"] = ex["image"].to(torch.float32) / 255.0 - if self.video_preprocess_config["is_rgb"]: - ex["image"] = convert_to_rgb(ex["image"]) - else: - ex["image"] = convert_to_grayscale(ex["image"]) - if self.preprocess: - scale = self.video_preprocess_config["scale"] - if scale != 1.0: - ex["image"] = resize_image(ex["image"], scale) - ex["image"] = apply_pad_to_stride( - ex["image"], self.video_preprocess_config["max_stride"] - ) - outputs_list = self.inference_model(ex) - for output in outputs_list: - output = self._convert_tensors_to_numpy(output) - yield output - - # except Exception as e: - # raise Exception(f"Error in VideoReader during data processing: {e}") - - # finally: - self.pipeline.join() + self.pipeline.join() def predict( self, @@ -554,86 +547,53 @@ def from_trained_models( obj._initialize_inference_model() return obj - def make_pipeline(self, provider: str, data_path: str, num_workers: int = 0): + def make_pipeline( + self, + provider: str, + data_path: str, + queue_maxsize: int = 8, + video_start_idx=None, + video_end_idx=None, + ): """Make a data loading pipeline. Args: provider: (str) Provider class to read the input sleap files. Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - num_workers: (int) Number of subprocesses to use for data loading. 0 means - that the data will be loaded in the main process. *Default*: 0. 
+ #TODO: Returns: - Torch DataLoader where each item is a dictionary with key `image` if provider - is LabelsReader. If provider is VideoReader, this method initiates the reader - class (doesn't return a pipeline) and the Thread is started in - Predictor._predict_generator() method. - - Notes: - This method creates the class attribute `pipeline` and will be - called automatically when predicting on data from a new source only when the - provider is LabelsReader. + This method initiates the reader class (doesn't return a pipeline) and the + Thread is started in Predictor._predict_generator() method. """ self.provider = provider # LabelsReader provider if self.provider == "LabelsReader": provider = LabelsReader - instances_key = True - - # no need of `instances` key for Centered-instance model - if self.centroid_config and self.confmap_config: - instances_key = False - data_provider = provider.from_filename( - data_path, instances_key=instances_key - ) - - self.videos = data_provider.labels.videos - - pipeline = Normalizer(data_provider, is_rgb=self.data_config.is_rgb) - pipeline = SizeMatcher( - pipeline, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, - provider=data_provider, - ) - - if not self.centroid_model: - pipeline = InstanceCentroidFinder( - pipeline, - anchor_ind=self.confmap_config.model_config.head_configs.centered_instance.confmaps.anchor_part, - ) - pipeline = InstanceCropper( - pipeline, - crop_hw=self.data_config.crop_hw, - ) - - pipeline = KeyFilter( - pipeline, - keep_keys=[ - "image", - "video_idx", - "frame_idx", - "centroid", - "instance", - "instance_bbox", - "instance_image", - "confidence_maps", - "num_instances", - "orig_size", - ], + max_stride = self.confmap_config.model_config.backbone_config.max_stride + scale = self.confmap_config.data_config.preprocessing.scale + if self.centroid_config is not None: + max_stride = ( + self.centroid_config.model_config.backbone_config.max_stride ) + scale = self.centroid_config.data_config.preprocessing.scale - # Remove duplicates. 
- self.pipeline = pipeline.sharding_filter() + self.preprocess = False + self.preprocess_config = { + "batch_size": self.batch_size, + "scale": scale, + "is_rgb": self.data_config.is_rgb, + "max_stride": max_stride, + } - self.pipeline = DataLoader( - self.pipeline, batch_size=self.batch_size, num_workers=num_workers + self.pipeline = provider.from_filename( + filename=data_path, + queue_maxsize=queue_maxsize, ) - - return self.pipeline + self.videos = self.pipeline.labels.videos # VideoReader provider elif self.provider == "VideoReader": @@ -645,7 +605,7 @@ class (doesn't return a pipeline) and the Thread is started in provider = VideoReader self.preprocess = False - self.video_preprocess_config = { + self.preprocess_config = { "batch_size": self.batch_size, "scale": self.centroid_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, @@ -654,14 +614,11 @@ class (doesn't return a pipeline) and the Thread is started in ), } - frame_queue = Queue( - maxsize=self.data_config.video_queue_maxsize if not None else 16 - ) self.pipeline = provider.from_filename( filename=data_path, - frame_buffer=frame_queue, - start_idx=self.data_config.videoreader_start_idx, - end_idx=self.data_config.videoreader_end_idx, + queue_maxsize=queue_maxsize, + start_idx=video_start_idx, + end_idx=video_end_idx, ) self.videos = [self.pipeline.video] @@ -745,6 +702,70 @@ def _make_labeled_frames_from_generator( ) return pred_labels + def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: + """Create a generator that yields batches of inference results. + + This method handles creating a pipeline object depending on the model type and + provider for loading the data, as well as looping over the batches and + running inference. + + Returns: + A generator yielding batches predicted results as dictionaries of numpy + arrays. + """ + # Initialize inference model if needed. + + if self.inference_model is None: + self._initialize_inference_model() + + # Loop over data batches. 
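# Batch assembly in the loop below concatenates per-frame tensors along the
# leading dimension; for batch_size=4 single-channel frames the shapes are
# (illustrative values, not from the patch):
#
#   imgs:    four (1, 1, 384, 384) tensors -> torch.concatenate -> (4, 1, 384, 384)
#   fidxs:   [f0, f1, f2, f3]              -> torch.tensor      -> (4,)
#   vidxs:   [v0, v1, v2, v3]              -> torch.tensor      -> (4,)
#   org_szs: four (1, 2) tensors           -> torch.concatenate -> (4, 2)
#
# A partial batch at the end of the stream is still processed (the `if imgs:` guard).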
+ self.pipeline.start() + batch_size = self.preprocess_config["batch_size"] + done = False + while not done: + imgs = [] + fidxs = [] + vidxs = [] + org_szs = [] + for _ in range(batch_size): + frame = self.pipeline.frame_buffer.get() + if frame[0] is None: + done = True + break + imgs.append(frame[0].unsqueeze(dim=0)) + fidxs.append(frame[1]) + vidxs.append(frame[2]) + org_szs.append(frame[3].unsqueeze(dim=0)) + if imgs: + imgs = torch.concatenate(imgs, dim=0) + fidxs = torch.tensor(fidxs, dtype=torch.int32) + vidxs = torch.tensor(vidxs, dtype=torch.int32) + org_szs = torch.concatenate(org_szs, dim=0) + ex = { + "image": imgs, + "frame_idx": fidxs, + "video_idx": vidxs, + "orig_size": org_szs, + } + ex["image"] = apply_normalization(ex["image"]) + if self.preprocess_config["is_rgb"]: + ex["image"] = convert_to_rgb(ex["image"]) + else: + ex["image"] = convert_to_grayscale(ex["image"]) + if self.preprocess: + scale = self.preprocess_config["scale"] + if scale != 1.0: + ex["image"] = resize_image(ex["image"], scale) + ex["image"] = apply_pad_to_stride( + ex["image"], self.preprocess_config["max_stride"] + ) + outputs_list = self.inference_model(ex) + for output in outputs_list: + output = self._convert_tensors_to_numpy(output) + yield output + + self.pipeline.join() + @attrs.define class SingleInstancePredictor(Predictor): @@ -877,60 +898,53 @@ def from_trained_models( obj._initialize_inference_model() return obj - def make_pipeline(self, provider: str, data_path: str, num_workers: int = 0): + def make_pipeline( + self, + provider: str, + data_path: str, + queue_maxsize: int = 8, + video_start_idx=None, + video_end_idx=None, + ): """Make a data loading pipeline. Args: provider: (str) Provider class to read the input sleap files. Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - num_workers: (int) Number of subprocesses to use for data loading. 0 means - that the data will be loaded in the main process. *Default*: 0. + #TODO Returns: - Torch DataLoader where each item is a dictionary with key `image` if provider - is LabelsReader. If provider is VideoReader, this method initiates the reader - class (doesn't return a pipeline) and the Thread is started in - Predictor._predict_generator() method. - - Notes: - This method creates the class attribute `pipeline` and will be - called automatically when predicting on data from a new source only when the - provider is LabelsReader. + This method initiates the reader class (doesn't return a pipeline) and the + Thread is started in Predictor._predict_generator() method. + """ self.provider = provider + + # LabelsReader provider if self.provider == "LabelsReader": provider = LabelsReader - data_provider = provider.from_filename(data_path) - self.videos = data_provider.labels.videos - pipeline = Normalizer(data_provider, is_rgb=self.data_config.is_rgb) - pipeline = SizeMatcher( - pipeline, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, - provider=data_provider, - ) - pipeline = Resizer( - pipeline, scale=self.confmap_config.data_config.preprocessing.scale - ) - pipeline = PadToStride( - pipeline, - max_stride=self.confmap_config.model_config.backbone_config.max_stride, - ) - # Remove duplicates. 
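# With the reader-thread API, pipeline construction is uniform across providers.
# An illustrative call on an already constructed SingleInstancePredictor (the
# paths are hypothetical; the signature is as defined in this patch):
#
#     predictor.make_pipeline(
#         provider="VideoReader",
#         data_path="video.mp4",
#         queue_maxsize=8,
#         video_start_idx=0,
#         video_end_idx=100,
#     )
#
# For .slp inputs the same call takes provider="LabelsReader" and a .slp
# data_path; the start/end indices are unused by that branch.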
- self.pipeline = pipeline.sharding_filter() + max_stride = self.confmap_config.model_config.backbone_config.max_stride - self.pipeline = DataLoader( - self.pipeline, batch_size=self.batch_size, num_workers=num_workers - ) + self.preprocess = False + self.preprocess_config = { + "batch_size": self.batch_size, + "scale": self.confmap_config.data_config.preprocessing.scale, + "is_rgb": self.data_config.is_rgb, + "max_stride": max_stride, + } - return self.pipeline + self.pipeline = provider.from_filename( + filename=data_path, + queue_maxsize=queue_maxsize, + ) + self.videos = self.pipeline.labels.videos elif self.provider == "VideoReader": provider = VideoReader self.preprocess = True - self.video_preprocess_config = { + self.preprocess_config = { "batch_size": self.batch_size, "scale": self.confmap_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, @@ -938,14 +952,12 @@ class (doesn't return a pipeline) and the Thread is started in self.confmap_config.model_config.backbone_config.max_stride ), } - frame_queue = Queue( - maxsize=self.data_config.video_queue_maxsize if not None else 16 - ) + self.pipeline = provider.from_filename( filename=data_path, - frame_buffer=frame_queue, - start_idx=self.data_config.videoreader_start_idx, - end_idx=self.data_config.videoreader_end_idx, + queue_maxsize=queue_maxsize, + start_idx=video_start_idx, + end_idx=video_end_idx, ) self.videos = [self.pipeline.video] @@ -1193,58 +1205,51 @@ def from_trained_models( obj._initialize_inference_model() return obj - def make_pipeline(self, provider: str, data_path: str, num_workers: int = 0): + def make_pipeline( + self, + provider: str, + data_path: str, + queue_maxsize: int = 8, + video_start_idx=None, + video_end_idx=None, + ): """Make a data loading pipeline. Args: provider: (str) Provider class to read the input sleap files. Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - num_workers: (int) Number of subprocesses to use for data loading. 0 means - that the data will be loaded in the main process. *Default*: 0. + #TODO Returns: - Torch DataLoader where each item is a dictionary with key `image` if provider - is LabelsReader. If provider is VideoReader, this method initiates the reader - class (doesn't return a pipeline) and the Thread is started in - Predictor._predict_generator() method. - - Notes: - This method creates the class attribute `pipeline` and will be - called automatically when predicting on data from a new source only when the - provider is LabelsReader. + This method initiates the reader class (doesn't return a pipeline) and the + Thread is started in Predictor._predict_generator() method. """ self.provider = provider + # LabelsReader provider if self.provider == "LabelsReader": provider = LabelsReader - data_provider = provider.from_filename(data_path) - self.videos = data_provider.labels.videos - pipeline = Normalizer(data_provider, is_rgb=self.data_config.is_rgb) - pipeline = SizeMatcher( - pipeline, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, - provider=data_provider, - ) - pipeline = Resizer( - pipeline, scale=self.bottomup_config.data_config.preprocessing.scale - ) + max_stride = self.bottomup_config.model_config.backbone_config.max_stride - pipeline = PadToStride(pipeline, max_stride=max_stride) - # Remove duplicates. 
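# Every predictor's make_pipeline now records the same four preprocessing knobs,
# consumed later by _predict_generator. The dict has this shape (values are the
# documented Predictor defaults; scale and max_stride are normally overridden
# from the training config):
#
#     preprocess_config = {
#         "batch_size": 4,    # frames assembled per inference batch
#         "scale": 1.0,       # input image scale factor
#         "is_rgb": False,    # convert frames to RGB instead of grayscale
#         "max_stride": 1,    # pad H/W to a multiple of the backbone max stride
#     }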
- self.pipeline = pipeline.sharding_filter() + self.preprocess = False + self.preprocess_config = { + "batch_size": self.batch_size, + "scale": self.bottomup_config.data_config.preprocessing.scale, + "is_rgb": self.data_config.is_rgb, + "max_stride": max_stride, + } - self.pipeline = DataLoader( - self.pipeline, batch_size=self.batch_size, num_workers=num_workers + self.pipeline = provider.from_filename( + filename=data_path, + queue_maxsize=queue_maxsize, ) - - return self.pipeline + self.videos = self.pipeline.labels.videos elif self.provider == "VideoReader": provider = VideoReader self.preprocess = True - self.video_preprocess_config = { + self.preprocess_config = { "batch_size": self.batch_size, "scale": self.bottomup_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, @@ -1252,14 +1257,12 @@ class (doesn't return a pipeline) and the Thread is started in self.bottomup_config.model_config.backbone_config.max_stride ), } - frame_queue = Queue( - maxsize=self.data_config.video_queue_maxsize if not None else 16 - ) + self.pipeline = provider.from_filename( filename=data_path, - frame_buffer=frame_queue, - start_idx=self.data_config.videoreader_start_idx, - end_idx=self.data_config.videoreader_end_idx, + queue_maxsize=queue_maxsize, + start_idx=video_start_idx, + end_idx=video_end_idx, ) self.videos = [self.pipeline.video] @@ -1369,10 +1372,9 @@ def main( is_rgb: bool = False, provider: str = "LabelsReader", batch_size: int = 4, - num_workers: int = 0, - video_queue_maxsize: int = 8, - videoreader_start_idx: int = 0, - videoreader_end_idx: int = 100, + queue_maxsize: int = 8, + videoreader_start_idx: Optional[int] = None, + videoreader_end_idx: Optional[int] = None, crop_hw: List[int] = (160, 160), peak_threshold: Union[float, List[float]] = 0.2, integral_refinement: str = None, @@ -1421,9 +1423,7 @@ def main( provider: (str) Provider class to read the input sleap files. Either "LabelsReader" or "VideoReader". Default: LabelsReader. batch_size: (int) Number of samples per batch. Default: 4. - num_workers: (int) Number of subprocesses to use for data loading. 0 means - that the data will be loaded in the main process. *Default*: 0. - video_queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. + queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. videoreader_start_idx: (int) Start index of the frames to read. Default: 0. videoreader_end_idx: (int) End index of the frames to read. Default: 100. crop_hw: List[int] Minimum height and width of the crop in pixels. Default: (160, 160). 
@@ -1516,7 +1516,7 @@ def main( } if provider == "VideoReader": - preprocess_config["video_queue_maxsize"] = video_queue_maxsize + preprocess_config["video_queue_maxsize"] = queue_maxsize preprocess_config["videoreader_start_idx"] = videoreader_start_idx preprocess_config["videoreader_end_idx"] = videoreader_end_idx @@ -1565,7 +1565,9 @@ def main( # initialize make_pipeline function - predictor.make_pipeline(provider, data_path, num_workers) + predictor.make_pipeline( + provider, data_path, queue_maxsize, videoreader_start_idx, videoreader_end_idx + ) # run predict output = predictor.predict( diff --git a/tests/assets/minimal_instance/training_config.yaml b/tests/assets/minimal_instance/training_config.yaml index cc3da057..e50c82ac 100755 --- a/tests/assets/minimal_instance/training_config.yaml +++ b/tests/assets/minimal_instance/training_config.yaml @@ -83,7 +83,7 @@ trainer_config: name: fly_unet_centered wandb_mode: '' api_key: '' - prv_runid: + prv_runid: null log_params: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad diff --git a/tests/inference/test_bottomup.py b/tests/inference/test_bottomup.py index 3659e5d3..523eb264 100644 --- a/tests/inference/test_bottomup.py +++ b/tests/inference/test_bottomup.py @@ -1,17 +1,27 @@ from omegaconf import OmegaConf import numpy as np -from sleap_nn.training.model_trainer import BottomUpModel, ModelTrainer +from pathlib import Path +import shutil +import sleap_io as sio +from sleap_nn.data.providers import process_lf +from sleap_nn.data.normalization import apply_normalization +from sleap_nn.training.model_trainer import BottomUpModel from sleap_nn.inference.paf_grouping import PAFScorer from sleap_nn.inference.bottomup import ( BottomUpInferenceModel, ) -def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): +def test_bottomup_inference_model( + minimal_instance, minimal_instance_bottomup_ckpt, tmp_path: str +): """Test BottomUpInferenceModel.""" train_config = OmegaConf.load( f"{minimal_instance_bottomup_ckpt}/training_config.yaml" ) + OmegaConf.update( + train_config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" + ) OmegaConf.update( train_config, "data_config.train_labels_path", @@ -22,10 +32,10 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): "data_config.val_labels_path", "./tests/assets/minimal_instance.pkg.slp", ) - # get dataloader - trainer = ModelTrainer(train_config) - trainer._create_data_loaders() - loader = trainer.val_data_loader + + labels = sio.load_slp(minimal_instance) + ex = process_lf(labels[0], 0, 2) + ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) torch_model = BottomUpModel.load_from_checkpoint( f"{minimal_instance_bottomup_ckpt}/best.ckpt", @@ -54,7 +64,7 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): return_confmaps=False, ) - output = inference_layer(next(iter(loader)))[0] + output = inference_layer(ex)[0] assert "confmaps" not in output.keys() assert output["pred_instance_peaks"].is_nested assert tuple(output["pred_instance_peaks"][0].shape)[1:] == (2, 2) @@ -83,7 +93,7 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): return_paf_graph=True, ) - output = inference_layer(next(iter(loader)))[0] + output = inference_layer(ex)[0] assert tuple(output["confmaps"].shape) == (1, 2, 192, 192) assert tuple(output["part_affinity_fields"].shape) == (1, 96, 96, 2) assert output["pred_instance_peaks"].is_nested diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py index 
46af0af9..87442784 100644 --- a/tests/inference/test_predictors.py +++ b/tests/inference/test_predictors.py @@ -262,7 +262,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): peak_threshold=0.3, ) assert isinstance(pred_labels, sio.Labels) - assert len(pred_labels) == 100 + assert len(pred_labels) == 1100 assert len(pred_labels[0].instances) == 1 lf = pred_labels[0] @@ -278,6 +278,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): provider="VideoReader", make_labels=False, peak_threshold=0.3, + videoreader_end_idx=100, ) assert isinstance(preds, list) assert len(preds) == 25 diff --git a/tests/inference/test_single_instance.py b/tests/inference/test_single_instance.py index c60f00f4..5aec0a6c 100644 --- a/tests/inference/test_single_instance.py +++ b/tests/inference/test_single_instance.py @@ -1,17 +1,15 @@ import sleap_io as sio from omegaconf import OmegaConf import numpy as np -from torch.utils.data.dataloader import DataLoader from sleap_nn.data.resizing import resize_image -from sleap_nn.data.providers import LabelsReader -from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride from sleap_nn.training.model_trainer import ( SingleInstanceModel, ) from sleap_nn.inference.single_instance import ( SingleInstanceInferenceModel, ) +from sleap_nn.data.providers import process_lf +from sleap_nn.data.normalization import apply_normalization def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt): @@ -42,25 +40,9 @@ def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt for lf in labels: lf.instances = lf.instances[:1] - provider_pipeline = LabelsReader(labels) - pipeline = Normalizer(provider_pipeline, is_rgb=False) - pipeline = SizeMatcher( - pipeline, - max_height=None, - max_width=None, - provider=provider_pipeline, - ) - - pipeline = Resizer(pipeline, scale=config.data_config.preprocessing.scale) - pipeline = PadToStride( - pipeline, max_stride=config.model_config.backbone_config.max_stride - ) + ex = process_lf(labels[0], 0, 2) + ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) - pipeline = pipeline.sharding_filter() - data_pipeline = DataLoader( - pipeline, - batch_size=4, - ) find_peaks_layer = SingleInstanceInferenceModel( torch_model=torch_model, output_stride=2, @@ -69,8 +51,7 @@ def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt ) outputs = [] - for x in data_pipeline: - outputs.append(find_peaks_layer(x)) + outputs.append(find_peaks_layer(ex)) keys = outputs[0][0].keys() assert "pred_instance_peaks" in keys and "pred_peak_values" in keys assert "pred_confmaps" not in keys @@ -87,9 +68,8 @@ def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt input_scale=0.5, ) outputs = [] - for x in data_pipeline: - x["image"] = resize_image(x["image"], 0.5) - outputs.append(find_peaks_layer(x)) + ex["image"] = resize_image(ex["image"], 0.5) + outputs.append(find_peaks_layer(ex)) for i in outputs: instance = i[0]["pred_instance_peaks"].numpy() @@ -100,7 +80,6 @@ def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt torch_model=torch_model, output_stride=2, peak_threshold=0, return_confmaps=True ) outputs = [] - for x in data_pipeline: - outputs.append(find_peaks_layer(x)) + outputs.append(find_peaks_layer(ex)) assert "pred_confmaps" in outputs[0][0].keys() - assert outputs[0][0]["pred_confmaps"].shape[-2:] == (192, 192) + 
assert outputs[0][0]["pred_confmaps"].shape[-2:] == (96, 96) From a138f01443db615d639f1a42d96ec92e7c82788b Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 27 Sep 2024 17:18:08 -0700 Subject: [PATCH 07/11] Generate gt crops in CentroidCrop --- sleap_nn/data/providers.py | 309 ++++++++++++++++------------- sleap_nn/inference/predictors.py | 190 ++++++++---------- sleap_nn/inference/topdown.py | 92 ++++++--- tests/data/test_providers.py | 65 +++++- tests/inference/test_predictors.py | 49 +++-- tests/inference/test_topdown.py | 191 ++++++++---------- 6 files changed, 486 insertions(+), 410 deletions(-) diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 70c27d86..7838ac01 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -93,119 +93,119 @@ def process_lf( return ex -# class LabelsReader(IterDataPipe): -# """IterDataPipe for reading frames from Labels object. - -# This IterDataPipe will produce examples containing a frame and an sleap_io.Instance -# from a sleap_io.Labels instance. - -# Attributes: -# labels: sleap_io.Labels object that contains LabeledFrames that will be -# accessed through a torchdata DataPipe. -# user_instances_only: True if filter labels only to user instances else False. -# Default value True -# instances_key: True if `instances` key needs to be present in the data pipeline. -# When this is set to True, the instances are appended with NaNs to have same -# number of instances to enable batching. Default: False. -# """ - -# def __init__( -# self, -# labels: sio.Labels, -# user_instances_only: bool = True, -# instances_key: bool = True, -# ): -# """Initialize labels attribute of the class.""" -# self.labels = copy.deepcopy(labels) -# self.max_instances = get_max_instances(labels) -# self.instances_key = instances_key - -# # Filter to user instances -# if user_instances_only: -# filtered_lfs = [] -# for lf in self.labels: -# if lf.user_instances is not None and len(lf.user_instances) > 0: -# lf.instances = lf.user_instances -# filtered_lfs.append(lf) -# self.labels = sio.Labels( -# videos=self.labels.videos, -# skeletons=self.labels.skeletons, -# labeled_frames=filtered_lfs, -# ) - -# @property -# def edge_inds(self) -> list: -# """Returns list of edge indices.""" -# return self.labels.skeletons[0].edge_inds - -# @property -# def max_height_and_width(self) -> Tuple[int, int]: -# """Return `(height, width)` that is the maximum of all videos.""" -# return max(video.shape[1] for video in self.labels.videos), max( -# video.shape[2] for video in self.labels.videos -# ) - -# @classmethod -# def from_filename( -# cls, -# filename: str, -# user_instances_only: bool = True, -# instances_key: bool = True, -# ): -# """Create LabelsReader from a .slp filename.""" -# labels = sio.load_slp(filename) -# return cls(labels, user_instances_only, instances_key) - -# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: -# """Return an example dictionary containing the following elements. - -# "image": A torch.Tensor containing full raw frame image as a uint8 array -# of shape (n_samples, channels, height, width). -# "instances": Keypoint coordinates for all instances in the frame as a -# float32 torch.Tensor of shape (n_samples, n_instances, n_nodes, 2). -# """ -# for lf in self.labels: -# image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW - -# instances = [] -# for inst in lf: -# if not inst.is_empty: -# instances.append(inst.numpy()) -# instances = np.stack(instances, axis=0) - -# # Add singleton time dimension for single frames. 
-# image = np.expand_dims(image, axis=0) # (1, C, H, W) -# img_height, img_width = image.shape[-2:] -# instances = np.expand_dims( -# instances, axis=0 -# ) # (1, num_instances, num_nodes, 2) - -# instances = torch.from_numpy(instances.astype("float32")) -# num_instances, nodes = instances.shape[1:3] -# ex = { -# "image": torch.from_numpy(image), -# "video_idx": torch.tensor( -# self.labels.videos.index(lf.video), dtype=torch.int32 -# ), -# "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32), -# "num_instances": num_instances, -# } -# ex["orig_size"] = torch.Tensor([img_height, img_width]) - -# if self.instances_key: -# nans = torch.full( -# (1, np.abs(self.max_instances - num_instances), nodes, 2), torch.nan -# ) -# ex["instances"] = torch.cat([instances, nans], dim=1) - -# yield ex +class LabelsReader(IterDataPipe): + """IterDataPipe for reading frames from Labels object. + + This IterDataPipe will produce examples containing a frame and an sleap_io.Instance + from a sleap_io.Labels instance. + + Attributes: + labels: sleap_io.Labels object that contains LabeledFrames that will be + accessed through a torchdata DataPipe. + user_instances_only: True if filter labels only to user instances else False. + Default value True + instances_key: True if `instances` key needs to be present in the data pipeline. + When this is set to True, the instances are appended with NaNs to have same + number of instances to enable batching. Default: False. + """ + + def __init__( + self, + labels: sio.Labels, + user_instances_only: bool = True, + instances_key: bool = True, + ): + """Initialize labels attribute of the class.""" + self.labels = copy.deepcopy(labels) + self.max_instances = get_max_instances(labels) + self.instances_key = instances_key + + # Filter to user instances + if user_instances_only: + filtered_lfs = [] + for lf in self.labels: + if lf.user_instances is not None and len(lf.user_instances) > 0: + lf.instances = lf.user_instances + filtered_lfs.append(lf) + self.labels = sio.Labels( + videos=self.labels.videos, + skeletons=self.labels.skeletons, + labeled_frames=filtered_lfs, + ) + + @property + def edge_inds(self) -> list: + """Returns list of edge indices.""" + return self.labels.skeletons[0].edge_inds + + @property + def max_height_and_width(self) -> Tuple[int, int]: + """Return `(height, width)` that is the maximum of all videos.""" + return max(video.shape[1] for video in self.labels.videos), max( + video.shape[2] for video in self.labels.videos + ) + + @classmethod + def from_filename( + cls, + filename: str, + user_instances_only: bool = True, + instances_key: bool = True, + ): + """Create LabelsReader from a .slp filename.""" + labels = sio.load_slp(filename) + return cls(labels, user_instances_only, instances_key) + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + """Return an example dictionary containing the following elements. + + "image": A torch.Tensor containing full raw frame image as a uint8 array + of shape (n_samples, channels, height, width). + "instances": Keypoint coordinates for all instances in the frame as a + float32 torch.Tensor of shape (n_samples, n_instances, n_nodes, 2). + """ + for lf in self.labels: + image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW + + instances = [] + for inst in lf: + if not inst.is_empty: + instances.append(inst.numpy()) + instances = np.stack(instances, axis=0) + + # Add singleton time dimension for single frames. 
+ image = np.expand_dims(image, axis=0) # (1, C, H, W) + img_height, img_width = image.shape[-2:] + instances = np.expand_dims( + instances, axis=0 + ) # (1, num_instances, num_nodes, 2) + + instances = torch.from_numpy(instances.astype("float32")) + num_instances, nodes = instances.shape[1:3] + ex = { + "image": torch.from_numpy(image), + "video_idx": torch.tensor( + self.labels.videos.index(lf.video), dtype=torch.int32 + ), + "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32), + "num_instances": num_instances, + } + ex["orig_size"] = torch.Tensor([img_height, img_width]) + + if self.instances_key: + nans = torch.full( + (1, np.abs(self.max_instances - num_instances), nodes, 2), torch.nan + ) + ex["instances"] = torch.cat([instances, nans], dim=1) + + yield ex class VideoReader(Thread): """Thread module for reading frames from sleap-io Video object. This module will load the frames from video and pushes them as Tensors into a buffer - queue as a tuple in the format (image, frame index, video index, (height, width)) + queue as a dictionary with (image, frame index, video index, (height, width)) which are then batched and consumed during the inference process. Attributes: @@ -266,43 +266,51 @@ def run(self): img = np.expand_dims(img, axis=0) # (1, C, H, W) self.frame_buffer.put( - ( - torch.from_numpy(img), # img - torch.tensor(idx, dtype=torch.int32), # frame idx - torch.tensor(0, dtype=torch.int32), # video idx - torch.Tensor(img.shape[-2:]), # orig shape - ) + { + "image": torch.from_numpy(img), + "frame_idx": torch.tensor(idx, dtype=torch.int32), + "video_idx": torch.tensor(0, dtype=torch.int32), + "orig_size": torch.Tensor(img.shape[-2:]), + } ) except Exception as e: print(f"Error when reading video frame. Stopping video reader.\n{e}") finally: - self.frame_buffer.put((None, None, None)) + self.frame_buffer.put( + { + "image": None, + "frame_idx": None, + "video_idx": None, + "orig_size": None, + } + ) -class LabelsReader(Thread): +class LabelReader(Thread): """Thread module for reading images from sleap-io Labels object. This module will load the images from `.slp` files and pushes them as Tensors into a - buffer queue as a tuple in the format (image, frame index, video index, (height, width)) + buffer queue as a dictionary with (image, frame index, video index, (height, width)) which are then batched and consumed during the inference process. Attributes: labels: sleap_io.Labels object that contains LabeledFrames that will be accessed through a torchdata DataPipe. frame_buffer: Frame buffer queue. + instances_key: If `True`, then instances are appended to the output dictionary. 
""" def __init__( - self, - labels: sio.Labels, - frame_buffer: Queue, + self, labels: sio.Labels, frame_buffer: Queue, instances_key: bool = False ): """Initialize attribute of the class.""" super().__init__() self.labels = labels self.frame_buffer = frame_buffer + self.instances_key = instances_key + self.max_instances = get_max_instances(self.labels) def total_len(self): """Returns the total number of frames in the video.""" @@ -317,14 +325,12 @@ def max_height_and_width(self) -> Tuple[int, int]: @classmethod def from_filename( - cls, - filename: str, - queue_maxsize: int, + cls, filename: str, queue_maxsize: int, instances_key: bool = False ): """Create LabelsReader from a .slp filename.""" labels = sio.load_slp(filename) frame_buffer = Queue(maxsize=queue_maxsize) - return cls(labels, frame_buffer) + return cls(labels, frame_buffer, instances_key) def run(self): """Adds frames to the buffer queue.""" @@ -335,19 +341,54 @@ def run(self): img = np.transpose(img, (2, 0, 1)) # convert H,W,C to C,H,W img = np.expand_dims(img, axis=0) # (1, C, H, W) - self.frame_buffer.put( - ( - torch.from_numpy(img), # img - torch.tensor(idx, dtype=torch.int32), # frame idx - torch.tensor( - self.labels.videos.index(lf.video), dtype=torch.int32 - ), # video idx - torch.Tensor(img.shape[-2:]), # orig shape - ) - ) + sample = { + "image": torch.from_numpy(img), + "frame_idx": torch.tensor(idx, dtype=torch.int32), + "video_idx": torch.tensor( + self.labels.videos.index(lf.video), dtype=torch.int32 + ), + "orig_size": torch.Tensor(img.shape[-2:]), + } + + if self.instances_key: + instances = [] + for inst in lf: + if not inst.is_empty: + instances.append(inst.numpy()) + instances = np.stack(instances, axis=0) + + # Add singleton time dimension for single frames. + instances = np.expand_dims( + instances, axis=0 + ) # (n_samples=1, num_instances, num_nodes, 2) + + instances = torch.from_numpy(instances.astype("float32")) + + num_instances, nodes = instances.shape[1:3] + + # append with nans for broadcasting + if self.max_instances != 1: + nans = torch.full( + (1, np.abs(self.max_instances - num_instances), nodes, 2), + torch.nan, + ) + instances = torch.cat( + [instances, nans], dim=1 + ) # (n_samples, max_instances, num_nodes, 2) + + sample["instances"] = instances + + self.frame_buffer.put(sample) except Exception as e: print(f"Error when reading labelled frame. Stopping labels reader.\n{e}") finally: - self.frame_buffer.put((None, None, None, None)) + self.frame_buffer.put( + { + "image": None, + "frame_idx": None, + "video_idx": None, + "orig_size": None, + } + ) diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index deb7dfd2..1aac99d1 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -12,8 +12,13 @@ import lightning as L import litdata as ld from omegaconf import OmegaConf -from sleap_nn.data.providers import LabelsReader, VideoReader -from sleap_nn.data.resizing import resize_image, apply_pad_to_stride, apply_sizematcher +from sleap_nn.data.providers import LabelReader, VideoReader +from sleap_nn.data.resizing import ( + resize_image, + apply_pad_to_stride, + apply_sizematcher, + apply_resizer, +) from sleap_nn.data.normalization import ( apply_normalization, convert_to_grayscale, @@ -53,13 +58,14 @@ class Predictor(ABC): preprocess_config: Preprocessing config with keys: [`batch_size`, `scale`, `is_rgb`, `max_stride`]. 
Default: {"batch_size": 4, "scale": 1.0, "is_rgb": False, "max_stride": 1} - provider: Provider for inference pipeline. One of ["LabelsReader", "VideoReader"]. - Default: LabelsReader. - pipeline: If provider is LabelsReader, pipeline is a `DataLoader` object. If provider + provider: Provider for inference pipeline. One of ["LabelReader", "VideoReader"]. + Default: LabelReader. + pipeline: If provider is LabelReader, pipeline is a `DataLoader` object. If provider is VideoReader, pipeline is an instance of `sleap_nn.data.providers.VideoReader` class. Default: None. inference_model: Instance of one of the inference models ["TopDownInferenceModel", "SingleInstanceInferenceModel", "BottomUpInferenceModel"]. Default: None. + instances_key: If `True`, then instances are appended to the data samples. """ preprocess: bool = True @@ -69,13 +75,14 @@ class Predictor(ABC): "is_rgb": False, "max_stride": 1, } - provider: Union[LabelsReader, VideoReader] = LabelsReader - pipeline: Optional[Union[LabelsReader, VideoReader]] = None + provider: Union[LabelReader, VideoReader] = LabelReader + pipeline: Optional[Union[LabelReader, VideoReader]] = None inference_model: Optional[ Union[ TopDownInferenceModel, SingleInstanceInferenceModel, BottomUpInferenceModel ] ] = None + instances_key: bool = False @classmethod def from_model_paths( @@ -245,26 +252,33 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: fidxs = [] vidxs = [] org_szs = [] + instances = [] for _ in range(batch_size): frame = self.pipeline.frame_buffer.get() - if frame[0] is None: + if frame["image"] is None: done = True break - imgs.append(frame[0].unsqueeze(dim=0)) - fidxs.append(frame[1]) - vidxs.append(frame[2]) - org_szs.append(frame[3].unsqueeze(dim=0)) + imgs.append(frame["image"].unsqueeze(dim=0)) + fidxs.append(frame["frame_idx"]) + vidxs.append(frame["video_idx"]) + org_szs.append(frame["orig_size"].unsqueeze(dim=0)) + if self.instances_key: + instances.append(frame["instances"].unsqueeze(dim=0)) if imgs: imgs = torch.concatenate(imgs, dim=0) fidxs = torch.tensor(fidxs, dtype=torch.int32) vidxs = torch.tensor(vidxs, dtype=torch.int32) org_szs = torch.concatenate(org_szs, dim=0) + if self.instances_key: + instances = torch.concatenate(instances, dim=0) ex = { "image": imgs, "frame_idx": fidxs, "video_idx": vidxs, "orig_size": org_szs, } + if self.instances_key: + ex["instances"] = instances ex["image"] = apply_normalization(ex["image"]) if self.preprocess_config["is_rgb"]: ex["image"] = convert_to_rgb(ex["image"]) @@ -273,7 +287,12 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: if self.preprocess: scale = self.preprocess_config["scale"] if scale != 1.0: - ex["image"] = resize_image(ex["image"], scale) + if self.instances_key: + ex["image"], ex["instances"] = apply_resizer( + ex["image"], ex["instances"] + ) + else: + ex["image"] = resize_image(ex["image"], scale) ex["image"] = apply_pad_to_stride( ex["image"], self.preprocess_config["max_stride"] ) @@ -386,6 +405,9 @@ def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" # Create an instance of CentroidLayer if centroid_config is not None return_crops = False + # if both centroid and centered-instance model are provided, set return crops to True + if self.confmap_model: + return_crops = True if isinstance(self.peak_threshold, list): centroid_peak_threshold = self.peak_threshold[0] centered_instance_peak_threshold = self.peak_threshold[1] @@ -394,14 +416,17 @@ def 
_initialize_inference_model(self): centered_instance_peak_threshold = self.peak_threshold if self.centroid_config is None: - centroid_crop_layer = None + + centroid_crop_layer = CentroidCrop( + use_gt_centroids=True, + crop_hw=self.data_config.crop_hw, + anchor_ind=self.confmap_config.model_config.head_configs.centered_instance.confmaps.anchor_part, + return_crops=return_crops, + ) + else: max_stride = self.centroid_config.model_config.backbone_config.max_stride - # if both centroid and centered-instance model are provided, set return crops to True - if self.confmap_model: - return_crops = True - # initialize centroid crop layer centroid_crop_layer = CentroidCrop( torch_model=self.centroid_model, @@ -415,11 +440,13 @@ def _initialize_inference_model(self): max_stride=max_stride, input_scale=self.centroid_config.data_config.preprocessing.scale, crop_hw=self.data_config.crop_hw, + use_gt_centroids=False, ) # Create an instance of FindInstancePeaks layer if confmap_config is not None if self.confmap_config is None: instance_peaks_layer = FindInstancePeaksGroundTruth() + self.instances_key = True else: max_stride = self.confmap_config.model_config.backbone_config.max_stride @@ -434,6 +461,11 @@ def _initialize_inference_model(self): input_scale=self.confmap_config.data_config.preprocessing.scale, ) + if self.centroid_config is None and self.confmap_config is not None: + self.instances_key = ( + True # we need `instances` to get ground-truth centroids + ) + # Initialize the inference model with centroid and instance peak layers self.inference_model = TopDownInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer @@ -490,7 +522,7 @@ def from_trained_models( An instance of `TopDownPredictor` with the loaded models. One of the two models can be left as `None` to perform inference with ground - truth data. This will only work with `LabelsReader` as the provider. + truth data. This will only work with `LabelReader` as the provider. """ if centroid_ckpt_path is not None: @@ -559,9 +591,11 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelsReader" or "VideoReader". + Either "LabelReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - #TODO: + queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. + video_start_idx: (int) Start index of the frames to read. Default: None. + video_end_idx: (int) End index of the frames to read. Default: None. 
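When only a centered-instance checkpoint is loaded, the crop layer built above runs with no network at all. A sketch of the equivalent standalone construction (the crop size and anchor index here are hypothetical; the kwargs match this diff):

    from sleap_nn.inference.topdown import CentroidCrop

    centroid_crop = CentroidCrop(
        use_gt_centroids=True,  # centroids come from ground-truth instances
        crop_hw=(160, 160),     # hypothetical crop size
        anchor_ind=0,           # hypothetical anchor node; bbox midpoint if absent
        return_crops=True,      # emit crops for the centered-instance head
    )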
Returns: This method initiates the reader class (doesn't return a pipeline) and the @@ -569,17 +603,18 @@ def make_pipeline( """ self.provider = provider - # LabelsReader provider - if self.provider == "LabelsReader": - provider = LabelsReader + # LabelReader provider + if self.provider == "LabelReader": + provider = LabelReader - max_stride = self.confmap_config.model_config.backbone_config.max_stride - scale = self.confmap_config.data_config.preprocessing.scale if self.centroid_config is not None: max_stride = ( self.centroid_config.model_config.backbone_config.max_stride ) scale = self.centroid_config.data_config.preprocessing.scale + else: + max_stride = self.confmap_config.model_config.backbone_config.max_stride + scale = self.confmap_config.data_config.preprocessing.scale self.preprocess = False self.preprocess_config = { @@ -592,6 +627,7 @@ def make_pipeline( self.pipeline = provider.from_filename( filename=data_path, queue_maxsize=queue_maxsize, + instances_key=self.instances_key, ) self.videos = self.pipeline.labels.videos @@ -624,7 +660,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -702,70 +738,6 @@ def _make_labeled_frames_from_generator( ) return pred_labels - def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: - """Create a generator that yields batches of inference results. - - This method handles creating a pipeline object depending on the model type and - provider for loading the data, as well as looping over the batches and - running inference. - - Returns: - A generator yielding batches predicted results as dictionaries of numpy - arrays. - """ - # Initialize inference model if needed. - - if self.inference_model is None: - self._initialize_inference_model() - - # Loop over data batches. - self.pipeline.start() - batch_size = self.preprocess_config["batch_size"] - done = False - while not done: - imgs = [] - fidxs = [] - vidxs = [] - org_szs = [] - for _ in range(batch_size): - frame = self.pipeline.frame_buffer.get() - if frame[0] is None: - done = True - break - imgs.append(frame[0].unsqueeze(dim=0)) - fidxs.append(frame[1]) - vidxs.append(frame[2]) - org_szs.append(frame[3].unsqueeze(dim=0)) - if imgs: - imgs = torch.concatenate(imgs, dim=0) - fidxs = torch.tensor(fidxs, dtype=torch.int32) - vidxs = torch.tensor(vidxs, dtype=torch.int32) - org_szs = torch.concatenate(org_szs, dim=0) - ex = { - "image": imgs, - "frame_idx": fidxs, - "video_idx": vidxs, - "orig_size": org_szs, - } - ex["image"] = apply_normalization(ex["image"]) - if self.preprocess_config["is_rgb"]: - ex["image"] = convert_to_rgb(ex["image"]) - else: - ex["image"] = convert_to_grayscale(ex["image"]) - if self.preprocess: - scale = self.preprocess_config["scale"] - if scale != 1.0: - ex["image"] = resize_image(ex["image"], scale) - ex["image"] = apply_pad_to_stride( - ex["image"], self.preprocess_config["max_stride"] - ) - outputs_list = self.inference_model(ex) - for output in outputs_list: - output = self._convert_tensors_to_numpy(output) - yield output - - self.pipeline.join() - @attrs.define class SingleInstancePredictor(Predictor): @@ -910,9 +882,11 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelsReader" or "VideoReader". + Either "LabelReader" or "VideoReader". 
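The duplicated `_predict_generator` removed above now lives only on the base `Predictor`; its core is the sentinel-terminated batch assembly sketched here (a self-contained toy, with a stub producer standing in for a reader thread):

    import torch
    from queue import Queue

    frame_buffer = Queue()
    frame_buffer.put({
        "image": torch.zeros(1, 1, 384, 384),
        "frame_idx": torch.tensor(0, dtype=torch.int32),
        "video_idx": torch.tensor(0, dtype=torch.int32),
        "orig_size": torch.Tensor([384, 384]),
    })
    frame_buffer.put(
        {"image": None, "frame_idx": None, "video_idx": None, "orig_size": None}
    )

    imgs, fidxs, vidxs, org_szs = [], [], [], []
    while True:
        frame = frame_buffer.get()
        if frame["image"] is None:  # sentinel ends the stream
            break
        imgs.append(frame["image"].unsqueeze(dim=0))
        fidxs.append(frame["frame_idx"])
        vidxs.append(frame["video_idx"])
        org_szs.append(frame["orig_size"].unsqueeze(dim=0))

    ex = {
        "image": torch.concatenate(imgs, dim=0),
        "frame_idx": torch.tensor(fidxs, dtype=torch.int32),
        "video_idx": torch.tensor(vidxs, dtype=torch.int32),
        "orig_size": torch.concatenate(org_szs, dim=0),
    }
    assert ex["image"].shape == (1, 1, 1, 384, 384)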
data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - #TODO + queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. + video_start_idx: (int) Start index of the frames to read. Default: None. + video_end_idx: (int) End index of the frames to read. Default: None. Returns: This method initiates the reader class (doesn't return a pipeline) and the @@ -921,9 +895,9 @@ def make_pipeline( """ self.provider = provider - # LabelsReader provider - if self.provider == "LabelsReader": - provider = LabelsReader + # LabelReader provider + if self.provider == "LabelReader": + provider = LabelReader max_stride = self.confmap_config.model_config.backbone_config.max_stride @@ -964,7 +938,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -1217,18 +1191,20 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelsReader" or "VideoReader". + Either "LabelReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - #TODO + queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. + video_start_idx: (int) Start index of the frames to read. Default: None. + video_end_idx: (int) End index of the frames to read. Default: None. Returns: This method initiates the reader class (doesn't return a pipeline) and the Thread is started in Predictor._predict_generator() method. """ self.provider = provider - # LabelsReader provider - if self.provider == "LabelsReader": - provider = LabelsReader + # LabelReader provider + if self.provider == "LabelReader": + provider = LabelReader max_stride = self.bottomup_config.model_config.backbone_config.max_stride @@ -1269,7 +1245,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -1370,7 +1346,7 @@ def main( max_width: int = None, max_height: int = None, is_rgb: bool = False, - provider: str = "LabelsReader", + provider: str = "LabelReader", batch_size: int = 4, queue_maxsize: int = 8, videoreader_start_idx: Optional[int] = None, @@ -1421,11 +1397,11 @@ def main( is set to False, then we convert the image to grayscale (single-channel) image. Default: False. provider: (str) Provider class to read the input sleap files. - Either "LabelsReader" or "VideoReader". Default: LabelsReader. + Either "LabelReader" or "VideoReader". Default: LabelReader. batch_size: (int) Number of samples per batch. Default: 4. queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. - videoreader_start_idx: (int) Start index of the frames to read. Default: 0. - videoreader_end_idx: (int) End index of the frames to read. Default: 100. + videoreader_start_idx: (int) Start index of the frames to read. Default: None. + videoreader_end_idx: (int) End index of the frames to read. Default: None. crop_hw: List[int] Minimum height and width of the crop in pixels. Default: (160, 160). peak_threshold: (float) Minimum confidence threshold. Peaks with values below this will be ignored. Default: 0.2. 
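End to end, the tests in this series call `main` as sketched here (checkpoint directory hypothetical, remaining options left at their defaults):

    from sleap_nn.inference.predictors import main

    pred_labels = main(
        model_paths=["path/to/ckpt_dir"],  # hypothetical checkpoint directory
        data_path="tests/assets/minimal_instance.pkg.slp",
        provider="LabelReader",
        make_labels=True,
        max_instances=6,
        peak_threshold=0.2,
    )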
This can also be `List[float]` for topdown diff --git a/sleap_nn/inference/topdown.py b/sleap_nn/inference/topdown.py index 04a79b23..6c1c1ad3 100644 --- a/sleap_nn/inference/topdown.py +++ b/sleap_nn/inference/topdown.py @@ -9,6 +9,7 @@ apply_pad_to_stride, ) from sleap_nn.inference.peak_finding import crop_bboxes +from sleap_nn.data.instance_centroids import generate_centroids from sleap_nn.data.instance_cropping import make_centered_bboxes from sleap_nn.inference.peak_finding import find_global_peaks, find_local_peaks @@ -47,12 +48,17 @@ class CentroidCrop(L.LightningModule): If > 1, this will pad the bottom and right of the images to ensure they meet this divisibility criteria. Padding is applied after the scaling specified in the `scale` attribute. + use_gt_centroids: If `True`, then the crops are generated using ground-truth centroids. + If `False`, then centroids are predicted using a trained centroid model. + anchor_ind: The index of the node to use as the anchor for the centroid. If not + provided or if not present in the instance, the midpoint of the bounding box + is used instead. """ def __init__( self, - torch_model: L.LightningModule, + torch_model: L.LightningModule = None, output_stride: int = 1, peak_threshold: float = 0.0, max_instances: Optional[int] = None, @@ -63,6 +69,8 @@ def __init__( crop_hw: tuple = (160, 160), input_scale: float = 1.0, max_stride: int = 1, + use_gt_centroids: bool = False, + anchor_ind: int = None, **kwargs, ): """Initialise the model attributes.""" @@ -78,6 +86,8 @@ def __init__( self.crop_hw = crop_hw self.input_scale = input_scale self.max_stride = max_stride + self.use_gt_centroids = use_gt_centroids + self.anchor_ind = anchor_ind def _generate_crops(self, inputs): """Generate Crops from the predicted centroids.""" @@ -152,6 +162,42 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: and (batch, max_instances) repsectively which could then to passed to FindInstancePeaksGroundTruth class. """ + if self.use_gt_centroids: + batch = inputs["video_idx"].shape[0] + centroids = generate_centroids( + inputs["instances"], anchor_ind=self.anchor_ind + ) + centroid_vals = torch.ones(centroids.shape)[..., 0] + self.refined_peaks_batched = [x[0] for x in centroids] + self.peak_vals_batched = [x[0] for x in centroid_vals] + + max_instances = ( + self.max_instances + if self.max_instances is not None + else inputs["instances"].shape[-3] + ) + + refined_peaks_with_nans = torch.zeros((batch, max_instances, 2)) + peak_vals_with_nans = torch.zeros((batch, max_instances)) + for ind, (r, p) in enumerate( + zip(self.refined_peaks_batched, self.peak_vals_batched) + ): + refined_peaks_with_nans[ind] = r + peak_vals_with_nans[ind] = p + + inputs.update( + { + "centroids": refined_peaks_with_nans.unsqueeze(dim=1), + "centroid_vals": peak_vals_with_nans, + } + ) + + if self.return_crops: + crops_dict = self._generate_crops(inputs) + return crops_dict + else: + return inputs + # Network forward pass. orig_image = inputs["image"] scaled_image = resize_image(orig_image, self.input_scale) @@ -469,38 +515,20 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: `"pred_peak_vals": (batch_size, n_nodes)`: Confidence values for the instance skeleton points. """ - batch_size = batch["video_idx"].shape[0] + if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): + if "instances" not in batch: + raise ValueError( + "Ground truth data was not detected... 
" + "Please load both models when predicting on non-ground-truth data." + ) + self.centroid_crop.eval() peaks_output = [] - if self.centroid_crop is None: - batch["centroid_val"] = torch.ones(batch_size) - if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): - if "instances" in batch: - peaks_output.append(self.instance_peaks(batch)) - else: - raise ValueError( - "Ground truth data was not detected... " - "Please load both models when predicting on non-ground-truth data." - ) - else: - self.instance_peaks.eval() - peaks_output.append(self.instance_peaks(batch)) + batch = self.centroid_crop(batch) + if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): + peaks_output.append(self.instance_peaks(batch)) else: - self.centroid_crop.eval() - if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): - if "instances" in batch: - max_inst = batch["instances"].shape[-3] - self.centroid_crop.max_instances = max_inst - else: - raise ValueError( - "Ground truth data was not detected... " - "Please load both models when predicting on non-ground-truth data." - ) - batch = self.centroid_crop(batch) - if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): - peaks_output.append(self.instance_peaks(batch)) - else: - for i in batch: - self.instance_peaks.eval() - peaks_output.append(self.instance_peaks(i)) + for i in batch: + self.instance_peaks.eval() + peaks_output.append(self.instance_peaks(i)) return peaks_output diff --git a/tests/data/test_providers.py b/tests/data/test_providers.py index 10b0c063..06e740e4 100644 --- a/tests/data/test_providers.py +++ b/tests/data/test_providers.py @@ -1,6 +1,6 @@ import torch -from sleap_nn.data.providers import LabelsReader, VideoReader, process_lf +from sleap_nn.data.providers import LabelsReader, LabelReader, VideoReader, process_lf from queue import Queue import sleap_io as sio import numpy as np @@ -36,11 +36,11 @@ def test_videoreader_provider(centered_instance_video): data = [] for i in range(batch_size): frame = reader.frame_buffer.get() - if frame[0] is None: + if frame["image"] is None: break data.append(frame) assert len(data) == batch_size - assert data[0][0].shape == (1, 1, 384, 384) + assert data[0]["image"].shape == (1, 1, 384, 384) except: raise finally: @@ -48,10 +48,9 @@ def test_videoreader_provider(centered_instance_video): assert reader.total_len() == 4 # check graceful stop (video has 1100 frames) - queue = Queue(maxsize=4) reader = VideoReader.from_filename( filename=centered_instance_video, - frame_buffer=queue, + queue_maxsize=4, start_idx=1099, end_idx=1104, ) @@ -61,11 +60,11 @@ def test_videoreader_provider(centered_instance_video): data = [] for i in range(batch_size): frame = reader.frame_buffer.get() - if frame[0] is None: + if frame["image"] is None: break data.append(frame) assert len(data) == 1 - assert data[0][0].shape == (1, 1, 384, 384) + assert data[0]["image"].shape == (1, 1, 384, 384) except: raise finally: @@ -81,11 +80,11 @@ def test_videoreader_provider(centered_instance_video): data = [] for i in range(batch_size): frame = reader.frame_buffer.get() - if frame[0] is None: + if frame["image"] is None: break data.append(frame) assert len(data) == batch_size - assert data[0][0].shape == (1, 1, 384, 384) + assert data[0]["image"].shape == (1, 1, 384, 384) except: raise finally: @@ -93,6 +92,54 @@ def test_videoreader_provider(centered_instance_video): assert reader.total_len() == 6 +def test_labelreader_provider(minimal_instance): + """Test LabelReader class.""" + labels = 
sio.load_slp(minimal_instance) + queue = Queue(maxsize=4) + reader = LabelReader(labels=labels, frame_buffer=queue, instances_key=False) + assert reader.max_height_and_width == (384, 384) + reader.start() + batch_size = 1 + try: + data = [] + for i in range(batch_size): + frame = reader.frame_buffer.get() + if frame["image"] is None: + break + data.append(frame) + assert len(data) == batch_size + assert data[0]["image"].shape == (1, 1, 384, 384) + assert "instances" not in data[0] + except: + raise + finally: + reader.join() + assert reader.total_len() == 1 + + # with instances key + reader = LabelReader.from_filename( + minimal_instance, queue_maxsize=4, instances_key=True + ) + assert reader.max_height_and_width == (384, 384) + reader.start() + batch_size = 1 + try: + data = [] + for i in range(batch_size): + frame = reader.frame_buffer.get() + if frame["image"] is None: + break + data.append(frame) + assert len(data) == batch_size + assert data[0]["image"].shape == (1, 1, 384, 384) + assert "instances" in data[0] + except: + raise + finally: + reader.join() + assert reader.total_len() == 1 + + def test_process_lf(minimal_instance): labels = sio.load_slp(minimal_instance) lf = labels[0] diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py index 87442784..8519ba68 100644 --- a/tests/inference/test_predictors.py +++ b/tests/inference/test_predictors.py @@ -17,10 +17,10 @@ def test_topdown_predictor( pred_labels = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", return_confmaps=False, make_labels=True, - peak_threshold=0.1, + peak_threshold=0.0, ) assert isinstance(pred_labels, sio.Labels) assert len(pred_labels) == 1 @@ -40,14 +40,15 @@ def test_topdown_predictor( preds = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=False, peak_threshold=0.0, integral_refinement="integral", - batch_size=1, ) assert isinstance(preds, list) - assert len(preds) == 2 + assert len(preds) == 1 + assert len(preds[0]["instance_image"]) == 2 + assert len(preds[0]["centroid"]) == 2 assert isinstance(preds[0], dict) assert "pred_confmaps" not in preds[0].keys() @@ -62,7 +63,7 @@ def test_topdown_predictor( preds = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=False, ) @@ -72,7 +73,7 @@ def test_topdown_predictor( pred_labels = main( model_paths=[minimal_instance_centroid_ckpt, minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=[0.0, 0.0], @@ -83,16 +84,19 @@ def test_topdown_predictor( assert len(pred_labels[0].instances) <= 6 # centroid model + max_instances = 6 pred_labels = main( model_paths=[minimal_instance_centroid_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=False, - max_instances=6, + max_instances=max_instances, peak_threshold=0.1, ) assert len(pred_labels) == 1 - assert pred_labels[0]["centroids"].shape == (1, 1, 2, 2) + assert ( + pred_labels[0]["centroids"].shape[-2] <= max_instances + ) # centroids (1,1,max_instances,2) # Provider = VideoReader # centroid + centered-instance model inference @@ -108,13 +112,14 @@ def test_topdown_predictor( 
videoreader_start_idx=0, videoreader_end_idx=100, ) + assert isinstance(pred_labels, sio.Labels) assert len(pred_labels) == 100 # Unrecognized provider with pytest.raises( Exception, - match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider", + match="Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider", ): pred_labels = main( model_paths=[minimal_instance_centroid_ckpt, minimal_instance_ckpt], @@ -179,7 +184,7 @@ def test_topdown_predictor( def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): """Test SingleInstancePredictor module.""" - # provider as LabelsReader + # provider as LabelReader _config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml") config = _config.copy() @@ -199,10 +204,10 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): pred_labels = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, - peak_threshold=0.3, + peak_threshold=0.1, max_height=500, max_width=500, ) @@ -222,7 +227,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): preds = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=False, peak_threshold=0.3, max_height=500, @@ -307,7 +312,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): # check if labels are created from ckpt with pytest.raises( Exception, - match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider", + match="Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider", ): preds = main( model_paths=[minimal_instance_ckpt], @@ -324,13 +329,13 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt): """Test BottomUpPredictor module.""" - # provider as LabelsReader + # provider as LabelReader # check if labels are created from ckpt pred_labels = main( model_paths=[minimal_instance_bottomup_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=0.03, @@ -351,7 +356,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt): preds = main( model_paths=[minimal_instance_bottomup_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=False, max_instances=6, peak_threshold=0.03, @@ -368,7 +373,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt): pred_labels = main( model_paths=[minimal_instance_bottomup_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=1.0, @@ -415,7 +420,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt): # unrecognized provider with pytest.raises( Exception, - match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider", + match="Provider not recognised. 
Please use either `LabelReader` or `VideoReader` as provider", ): preds = main( model_paths=[minimal_instance_bottomup_ckpt], @@ -430,7 +435,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt): pred_labels = main( model_paths=[minimal_instance_bottomup_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=0.03, diff --git a/tests/inference/test_topdown.py b/tests/inference/test_topdown.py index 8a59808d..3512a114 100644 --- a/tests/inference/test_topdown.py +++ b/tests/inference/test_topdown.py @@ -3,13 +3,13 @@ import numpy as np import torch from torch.utils.data.dataloader import DataLoader -from sleap_nn.data.providers import LabelsReader +import sleap_io as sio +from sleap_nn.data.providers import process_lf, LabelsReader from sleap_nn.data.resizing import resize_image -from sleap_nn.data.instance_centroids import InstanceCentroidFinder -from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.instance_centroids import InstanceCentroidFinder, generate_centroids +from sleap_nn.data.normalization import apply_normalization, Normalizer from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride -from sleap_nn.data.instance_centroids import InstanceCentroidFinder -from sleap_nn.data.instance_cropping import InstanceCropper +from sleap_nn.data.instance_cropping import InstanceCropper, generate_crops from sleap_nn.training.model_trainer import ( CentroidModel, ModelTrainer, @@ -24,7 +24,7 @@ def initialize_model(config, minimal_instance, minimal_instance_ckpt): - """Returns data loader, trained torch model and FindInstancePeaks layer to test InferenceModels.""" + """Returns trained torch model and FindInstancePeaks layer to test InferenceModels.""" # for centered instance model config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml") torch_model = TopDownCenteredInstanceModel.load_from_checkpoint( @@ -34,42 +34,16 @@ def initialize_model(config, minimal_instance, minimal_instance_ckpt): model_type="centered_instance", ) - data_provider = LabelsReader.from_filename(minimal_instance) - pipeline = Normalizer(data_provider, is_rgb=False) - pipeline = SizeMatcher( - pipeline, - provider=data_provider, - max_height=None, - max_width=None, - ) - pipeline = InstanceCentroidFinder( - pipeline, - anchor_ind=0, - ) - pipeline = InstanceCropper( - pipeline, - crop_hw=(160, 160), - ) - pipeline = Resizer( - pipeline, scale=1.0, image_key="instance_image", instances_key="instance" - ) - pipeline = PadToStride(pipeline, max_stride=16, image_key="instance_image") - - pipeline = pipeline.sharding_filter() - data_pipeline = DataLoader( - pipeline, - batch_size=4, - ) find_peaks_layer = FindInstancePeaks( torch_model=torch_model, output_stride=2, peak_threshold=0.0, return_confmaps=False, ) - return data_pipeline, torch_model, find_peaks_layer + return torch_model, find_peaks_layer -def test_centroid_inference_model(config): +def test_centroid_inference_model(config, minimal_instance): """Test CentroidCrop class to run inference on centroid models.""" OmegaConf.update( @@ -81,8 +55,14 @@ def test_centroid_inference_model(config): del config.model_config.head_configs.centroid["confmaps"].part_names trainer = ModelTrainer(config) - trainer._create_data_loaders() - loader = next(iter(trainer.val_data_loader)) + labels = sio.load_slp(minimal_instance) + ex = process_lf(labels[0], 0, 2) + ex["image"] = 
apply_normalization(ex["image"]).unsqueeze(dim=0) + ex["instances"] = ex["instances"].unsqueeze(dim=0) + ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0) + ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0) + ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0) + trainer._initialize_model() model = trainer.model @@ -99,12 +79,12 @@ def test_centroid_inference_model(config): crop_hw=(160, 160), ) - out = layer(loader) + out = layer(ex) assert tuple(out["centroids"].shape) == (1, 1, 6, 2) assert tuple(out["centroid_vals"].shape) == (1, 6) assert "instance_image" not in out.keys() - # return crops = False + # return crops = True layer = CentroidCrop( torch_model=model, peak_threshold=0.0, @@ -116,7 +96,7 @@ def test_centroid_inference_model(config): return_crops=True, crop_hw=(160, 160), ) - out = layer(loader) + out = layer(ex) assert len(out) == 1 out = out[0] assert tuple(out["centroid"].shape) == (2, 2) @@ -129,28 +109,21 @@ def test_find_instance_peaks_groundtruth( config, minimal_instance, minimal_instance_ckpt, minimal_instance_centroid_ckpt ): """Test FindInstancePeaksGroundTruth class for running inference on centroid model without centered instance model.""" - data_provider = LabelsReader.from_filename(minimal_instance, instances_key=True) - pipeline = SizeMatcher( - data_provider, - max_height=None, - max_width=None, - ) - pipeline = Normalizer(pipeline, is_rgb=False) - pipeline = InstanceCentroidFinder( - pipeline, - anchor_ind=0, - ) - - pipeline = pipeline.sharding_filter() - data_pipeline = DataLoader( - pipeline, - batch_size=4, - ) - - p = iter(data_pipeline) - example = next(p) + labels = sio.load_slp(minimal_instance) + ex = process_lf(labels[0], 0, 2) + ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) + ex["instances"] = ex["instances"].unsqueeze(dim=0) + ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0) + ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0) + ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0) + # ex["centroids"] = generate_centroids(ex["instances"], 0) + + example = ex topdown_inf_layer = TopDownInferenceModel( - centroid_crop=None, instance_peaks=FindInstancePeaksGroundTruth() + centroid_crop=CentroidCrop( + use_gt_centroids=True, anchor_ind=0, crop_hw=(160, 160) + ), + instance_peaks=FindInstancePeaksGroundTruth(), ) output = topdown_inf_layer(example)[0] assert torch.isclose( @@ -215,12 +188,38 @@ def test_find_instance_peaks_groundtruth( def test_find_instance_peaks(config, minimal_instance, minimal_instance_ckpt): """Test FindInstancePeaks class to run inference on the Centered instance model.""" - data_pipeline, torch_model, find_peaks_layer = initialize_model( + torch_model, find_peaks_layer = initialize_model( config, minimal_instance, minimal_instance_ckpt ) + labels = sio.load_slp(minimal_instance) + ex = process_lf(labels[0], 0, 2) + ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) + ex["instances"] = ex["instances"].unsqueeze(dim=0) + ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0) + ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0) + ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0) + ex["centroids"] = generate_centroids(ex["instances"], 0) + ex["instances"], centroids = ( + ex["instances"][0, 0], + ex["centroids"][0, 0], + ) # n_samples=1 + + for cnt, (instance, centroid) in enumerate(zip(ex["instances"], centroids)): + if cnt == ex["num_instances"]: + break + + res = generate_crops(ex["image"][0], instance, centroid, (160, 160)) + + res["frame_idx"] = ex["frame_idx"] + res["video_idx"] = ex["video_idx"] + 
res["num_instances"] = ex["num_instances"] + res["orig_size"] = ex["orig_size"] + res["instance_image"] = res["instance_image"].unsqueeze(dim=0) + + break + outputs = [] - for x in data_pipeline: - outputs.append(find_peaks_layer(x)) + outputs.append(find_peaks_layer(res)) keys = outputs[0].keys() assert "pred_instance_peaks" in keys and "pred_peak_values" in keys assert "pred_confmaps" not in keys @@ -236,8 +235,7 @@ def test_find_instance_peaks(config, minimal_instance, minimal_instance_ckpt): return_confmaps=False, ) outputs = [] - for x in data_pipeline: - outputs.append(find_peaks_layer(x)) + outputs.append(find_peaks_layer(res)) for i in outputs: instance = i["pred_instance_peaks"].numpy() assert np.all(np.isnan(instance)) @@ -251,8 +249,7 @@ def test_find_instance_peaks(config, minimal_instance, minimal_instance_ckpt): input_scale=0.5, ) outputs = [] - for x in data_pipeline: - outputs.append(find_peaks_layer(x)) + outputs.append(find_peaks_layer(res)) assert "pred_confmaps" in outputs[0].keys() assert outputs[0]["pred_confmaps"].shape[-2:] == (40, 40) @@ -262,44 +259,39 @@ def test_topdown_inference_model( ): """Test TopDownInferenceModel class for centroid and cenetered model inferences.""" # for centered instance model - loader, _, find_peaks_layer = initialize_model( + _, find_peaks_layer = initialize_model( config, minimal_instance, minimal_instance_ckpt ) - data_provider = LabelsReader.from_filename(minimal_instance, instances_key=True) - pipeline = SizeMatcher( - data_provider, - max_height=None, - max_width=None, - ) - pipeline = Normalizer(pipeline, is_rgb=False) - pipeline = InstanceCentroidFinder( - pipeline, - anchor_ind=0, - ) - - pipeline = pipeline.sharding_filter() - data_pipeline = DataLoader( - pipeline, - batch_size=4, - ) + labels = sio.load_slp(minimal_instance) + ex = process_lf(labels[0], 0, 2) + ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) + ex["instances"] = ex["instances"].unsqueeze(dim=0) + ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0) + ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0) + ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0) - # if centroid layer is none and centered-instance model. + # if gt centroids and centered-instance model. 
topdown_inf_layer = TopDownInferenceModel( - centroid_crop=None, instance_peaks=find_peaks_layer + centroid_crop=CentroidCrop( + use_gt_centroids=True, anchor_ind=0, crop_hw=(160, 160), return_crops=True + ), + instance_peaks=find_peaks_layer, ) outputs = [] - for x in loader: - outputs.append(topdown_inf_layer(x)) + outputs.append(topdown_inf_layer(ex)) for i in outputs[0]: assert i["centroid_val"][0] == 1 assert "pred_instance_peaks" in i and "pred_peak_values" in i - # if centroid layer is none and "instances" not in data + # if gt centroids and "instances" not in data topdown_inf_layer = TopDownInferenceModel( - centroid_crop=None, instance_peaks=FindInstancePeaksGroundTruth() + centroid_crop=CentroidCrop( + use_gt_centroids=True, anchor_ind=0, crop_hw=(160, 160), return_crops=True + ), + instance_peaks=FindInstancePeaksGroundTruth(), ) - example = next(iter(data_pipeline)) + example = ex del example["instances"] with pytest.raises( @@ -317,19 +309,6 @@ def test_topdown_inference_model( model_type="centroid", ) - data_provider = LabelsReader.from_filename(minimal_instance, instances_key=True) - pipeline = SizeMatcher( - data_provider, - max_height=None, - max_width=None, - ) - pipeline = Normalizer(pipeline, is_rgb=False) - pipeline = pipeline.sharding_filter() - data_pipeline = DataLoader( - pipeline, - batch_size=4, - ) - centroid_layer = CentroidCrop( torch_model=torch_model, peak_threshold=0.0, @@ -345,7 +324,7 @@ def test_topdown_inference_model( topdown_inf_layer = TopDownInferenceModel( centroid_crop=centroid_layer, instance_peaks=find_peaks_layer ) - outputs = topdown_inf_layer(next(iter(data_pipeline))) + outputs = topdown_inf_layer(ex) for i in outputs: assert i["instance_image"].shape[1:] == (1, 1, 160, 160) assert i["pred_instance_peaks"].shape[1:] == (2, 2) From a7d8e22e60c6befc8aa5b60f5aee34bdcead6041 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 27 Sep 2024 18:40:00 -0700 Subject: [PATCH 08/11] Fix tracker tests --- tests/tracking/candidates/test_fixed_window.py | 2 +- tests/tracking/candidates/test_local_queues.py | 2 +- tests/tracking/test_tracker.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tracking/candidates/test_fixed_window.py b/tests/tracking/candidates/test_fixed_window.py index 8c131714..980ac3b0 100644 --- a/tests/tracking/candidates/test_fixed_window.py +++ b/tests/tracking/candidates/test_fixed_window.py @@ -10,7 +10,7 @@ def get_pred_instances(minimal_instance_ckpt, n=10): result_labels = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=0.0, diff --git a/tests/tracking/candidates/test_local_queues.py b/tests/tracking/candidates/test_local_queues.py index 76efcdef..e6b5993c 100644 --- a/tests/tracking/candidates/test_local_queues.py +++ b/tests/tracking/candidates/test_local_queues.py @@ -10,7 +10,7 @@ def get_pred_instances(minimal_instance_ckpt, n=10): result_labels = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=0.0, diff --git a/tests/tracking/test_tracker.py b/tests/tracking/test_tracker.py index 99d95ef4..adc43a36 100644 --- a/tests/tracking/test_tracker.py +++ b/tests/tracking/test_tracker.py @@ -15,7 +15,7 @@ def get_pred_instances(minimal_instance_ckpt): result_labels = main( model_paths=[minimal_instance_ckpt], 
data_path="./tests/assets/minimal_instance.pkg.slp", - provider="LabelsReader", + provider="LabelReader", make_labels=True, max_instances=6, peak_threshold=0.0, From 3ab80078daea5b333e811de281c689d7bbe8248a Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Mon, 30 Sep 2024 11:06:51 -0700 Subject: [PATCH 09/11] Fix frame index --- sleap_nn/data/providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 7838ac01..6292fcf8 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -343,7 +343,7 @@ def run(self): sample = { "image": torch.from_numpy(img), - "frame_idx": torch.tensor(idx, dtype=torch.int32), + "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32), "video_idx": torch.tensor( self.labels.videos.index(lf.video), dtype=torch.int32 ), From 2d6bc83156fb0bcc167dda072a5f12f250867536 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 1 Oct 2024 16:43:33 -0700 Subject: [PATCH 10/11] Fix changes --- sleap_nn/data/pipelines.py | 4 +- sleap_nn/data/providers.py | 8 +-- sleap_nn/data/resizing.py | 4 +- sleap_nn/inference/predictors.py | 49 ++++++++++--------- sleap_nn/inference/topdown.py | 4 +- sleap_nn/training/model_trainer.py | 4 -- tests/data/test_augmentation.py | 8 +-- tests/data/test_confmaps.py | 8 +-- tests/data/test_edge_maps.py | 4 +- tests/data/test_instance_centroids.py | 4 +- tests/data/test_instance_cropping.py | 4 +- tests/data/test_normalization.py | 6 +-- tests/data/test_pipelines.py | 28 +++++------ tests/data/test_providers.py | 19 ++++--- tests/data/test_resizing.py | 8 +-- tests/fixtures/datasets.py | 2 +- tests/inference/test_predictors.py | 32 ++++++------ tests/inference/test_topdown.py | 4 +- .../tracking/candidates/test_fixed_window.py | 2 +- .../tracking/candidates/test_local_queues.py | 2 +- tests/tracking/test_tracker.py | 2 +- 21 files changed, 104 insertions(+), 102 deletions(-) diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py index 064991a2..b92f16b3 100644 --- a/sleap_nn/data/pipelines.py +++ b/sleap_nn/data/pipelines.py @@ -252,7 +252,7 @@ def make_training_pipeline( Args: data_provider: A `Provider` that generates data examples, typically a - `LabelsReader` instance. + `LabelsReaderDP` instance. use_augmentations: `True` if augmentations should be applied to the training pipeline, else `False`. Default: `False`. @@ -353,7 +353,7 @@ def make_training_pipeline( Args: data_provider: A `Provider` that generates data examples, typically a - `LabelsReader` instance. + `LabelsReaderDP` instance. use_augmentations: `True` if augmentations should be applied to the training pipeline, else `False`. Default: `False`. diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 6292fcf8..7a4a98d1 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -93,7 +93,7 @@ def process_lf( return ex -class LabelsReader(IterDataPipe): +class LabelsReaderDP(IterDataPipe): """IterDataPipe for reading frames from Labels object. 
This IterDataPipe will produce examples containing a frame and an sleap_io.Instance @@ -152,7 +152,7 @@ def from_filename( user_instances_only: bool = True, instances_key: bool = True, ): - """Create LabelsReader from a .slp filename.""" + """Create LabelsReaderDP from a .slp filename.""" labels = sio.load_slp(filename) return cls(labels, user_instances_only, instances_key) @@ -252,7 +252,7 @@ def from_filename( start_idx: Optional[int] = None, end_idx: Optional[int] = None, ): - """Create LabelsReader from a .slp filename.""" + """Create VideoReader from a .slp filename.""" video = sio.load_video(filename) frame_buffer = Queue(maxsize=queue_maxsize) return cls(video, frame_buffer, start_idx, end_idx) @@ -288,7 +288,7 @@ def run(self): ) -class LabelReader(Thread): +class LabelsReader(Thread): """Thread module for reading images from sleap-io Labels object. This module will load the images from `.slp` files and pushes them as Tensors into a diff --git a/sleap_nn/data/resizing.py b/sleap_nn/data/resizing.py index 4193b3ca..fc442fa9 100644 --- a/sleap_nn/data/resizing.py +++ b/sleap_nn/data/resizing.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F -from sleap_nn.data.providers import LabelsReader, VideoReader +from sleap_nn.data.providers import LabelsReaderDP, VideoReader import torchvision.transforms.v2.functional as tvf from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -230,7 +230,7 @@ class SizeMatcher(IterDataPipe): def __init__( self, source_datapipe: IterDataPipe, - provider: Optional[Union[LabelsReader, VideoReader]] = None, + provider: Optional[Union[LabelsReaderDP, VideoReader]] = None, max_height: Optional[int] = None, max_width: Optional[int] = None, ): diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 1aac99d1..0df59c07 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -12,7 +12,7 @@ import lightning as L import litdata as ld from omegaconf import OmegaConf -from sleap_nn.data.providers import LabelReader, VideoReader +from sleap_nn.data.providers import LabelsReader, VideoReader from sleap_nn.data.resizing import ( resize_image, apply_pad_to_stride, @@ -58,9 +58,9 @@ class Predictor(ABC): preprocess_config: Preprocessing config with keys: [`batch_size`, `scale`, `is_rgb`, `max_stride`]. Default: {"batch_size": 4, "scale": 1.0, "is_rgb": False, "max_stride": 1} - provider: Provider for inference pipeline. One of ["LabelReader", "VideoReader"]. - Default: LabelReader. - pipeline: If provider is LabelReader, pipeline is a `DataLoader` object. If provider + provider: Provider for inference pipeline. One of ["LabelsReader", "VideoReader"]. + Default: LabelsReader. + pipeline: If provider is LabelsReader, pipeline is a `DataLoader` object. If provider is VideoReader, pipeline is an instance of `sleap_nn.data.providers.VideoReader` class. Default: None. 
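After this patch the three providers settle into their final names; a quick import sketch of the roles per this diff:

    # training pipelines (torchdata IterDataPipe):
    from sleap_nn.data.providers import LabelsReaderDP

    # inference readers (threads feeding a frame-buffer Queue):
    from sleap_nn.data.providers import LabelsReader, VideoReader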
inference_model: Instance of one of the inference models ["TopDownInferenceModel", @@ -75,8 +75,8 @@ class Predictor(ABC): "is_rgb": False, "max_stride": 1, } - provider: Union[LabelReader, VideoReader] = LabelReader - pipeline: Optional[Union[LabelReader, VideoReader]] = None + provider: Union[LabelsReader, VideoReader] = LabelsReader + pipeline: Optional[Union[LabelsReader, VideoReader]] = None inference_model: Optional[ Union[ TopDownInferenceModel, SingleInstanceInferenceModel, BottomUpInferenceModel @@ -265,6 +265,7 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: if self.instances_key: instances.append(frame["instances"].unsqueeze(dim=0)) if imgs: + # TODO: all preprocessing should be moved into InferenceModels to be exportable. imgs = torch.concatenate(imgs, dim=0) fidxs = torch.tensor(fidxs, dtype=torch.int32) vidxs = torch.tensor(vidxs, dtype=torch.int32) @@ -522,7 +523,7 @@ def from_trained_models( An instance of `TopDownPredictor` with the loaded models. One of the two models can be left as `None` to perform inference with ground - truth data. This will only work with `LabelReader` as the provider. + truth data. This will only work with `LabelsReader` as the provider. """ if centroid_ckpt_path is not None: @@ -591,7 +592,7 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelReader" or "VideoReader". + Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. video_start_idx: (int) Start index of the frames to read. Default: None. @@ -603,9 +604,9 @@ def make_pipeline( """ self.provider = provider - # LabelReader provider - if self.provider == "LabelReader": - provider = LabelReader + # LabelsReader provider + if self.provider == "LabelsReader": + provider = LabelsReader if self.centroid_config is not None: max_stride = ( @@ -660,7 +661,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -882,7 +883,7 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelReader" or "VideoReader". + Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. video_start_idx: (int) Start index of the frames to read. Default: None. @@ -895,9 +896,9 @@ def make_pipeline( """ self.provider = provider - # LabelReader provider - if self.provider == "LabelReader": - provider = LabelReader + # LabelsReader provider + if self.provider == "LabelsReader": + provider = LabelsReader max_stride = self.confmap_config.model_config.backbone_config.max_stride @@ -938,7 +939,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -1191,7 +1192,7 @@ def make_pipeline( Args: provider: (str) Provider class to read the input sleap files. - Either "LabelReader" or "VideoReader". + Either "LabelsReader" or "VideoReader". data_path: (str) Path to `.slp` file or `.mp4` to run inference on. 
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. video_start_idx: (int) Start index of the frames to read. Default: None. @@ -1202,9 +1203,9 @@ def make_pipeline( Thread is started in Predictor._predict_generator() method. """ self.provider = provider - # LabelReader provider - if self.provider == "LabelReader": - provider = LabelReader + # LabelsReader provider + if self.provider == "LabelsReader": + provider = LabelsReader max_stride = self.bottomup_config.model_config.backbone_config.max_stride @@ -1245,7 +1246,7 @@ def make_pipeline( else: raise Exception( - "Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider" + "Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider" ) def _make_labeled_frames_from_generator( @@ -1346,7 +1347,7 @@ def main( max_width: int = None, max_height: int = None, is_rgb: bool = False, - provider: str = "LabelReader", + provider: str = "LabelsReader", batch_size: int = 4, queue_maxsize: int = 8, videoreader_start_idx: Optional[int] = None, @@ -1397,7 +1398,7 @@ def main( is set to False, then we convert the image to grayscale (single-channel) image. Default: False. provider: (str) Provider class to read the input sleap files. - Either "LabelReader" or "VideoReader". Default: LabelReader. + Either "LabelsReader" or "VideoReader". Default: LabelsReader. batch_size: (int) Number of samples per batch. Default: 4. queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8. videoreader_start_idx: (int) Start index of the frames to read. Default: None. diff --git a/sleap_nn/inference/topdown.py b/sleap_nn/inference/topdown.py index 6c1c1ad3..48a0bcfc 100644 --- a/sleap_nn/inference/topdown.py +++ b/sleap_nn/inference/topdown.py @@ -58,7 +58,7 @@ class CentroidCrop(L.LightningModule): def __init__( self, - torch_model: L.LightningModule = None, + torch_model: Optional[L.LightningModule] = None, output_stride: int = 1, peak_threshold: float = 0.0, max_instances: Optional[int] = None, @@ -70,7 +70,7 @@ def __init__( input_scale: float = 1.0, max_stride: int = 1, use_gt_centroids: bool = False, - anchor_ind: int = None, + anchor_ind: Optional[int] = None, **kwargs, ): """Initialise the model attributes.""" diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index a7f5978a..1038904b 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -9,11 +9,9 @@ import shutil import torch import sleap_io as sio -from torch.utils.data import DataLoader from omegaconf import OmegaConf import lightning as L import litdata as ld -from sleap_nn.data.providers import LabelsReader from sleap_nn.data.pipelines import ( TopdownConfmapsPipeline, SingleInstanceConfmapsPipeline, @@ -145,8 +143,6 @@ def _get_data_chunks(self, func, train_labels, val_labels): def _create_data_loaders(self): """Create a DataLoader for train, validation and test sets using the data_config.""" self.provider = self.config.data_config.provider - if self.provider == "LabelsReader": - self.provider = LabelsReader train_labels = sio.load_slp(self.config.data_config.train_labels_path) val_labels = sio.load_slp(self.config.data_config.val_labels_path) diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index e0a711bc..254c336e 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -10,12 +10,12 @@ from sleap_nn.data.normalization import apply_normalization from sleap_nn.data.providers import 
 from sleap_nn.data.providers import process_lf
 from sleap_nn.data.normalization import Normalizer
-from sleap_nn.data.providers import LabelsReader
+from sleap_nn.data.providers import LabelsReaderDP


 def test_uniform_noise(minimal_instance):
     """Test RandomUniformNoise module."""
-    p = LabelsReader.from_filename(minimal_instance)
+    p = LabelsReaderDP.from_filename(minimal_instance)
     p = Normalizer(p)

     sample = next(iter(p))
@@ -99,7 +99,7 @@ def test_apply_geometric_augmentation(minimal_instance):

 def test_kornia_augmentation(minimal_instance):
     """Test KorniaAugmenter module."""
-    p = LabelsReader.from_filename(minimal_instance)
+    p = LabelsReaderDP.from_filename(minimal_instance)
     p = Normalizer(p)
     p = KorniaAugmenter(
@@ -127,7 +127,7 @@ def test_kornia_augmentation(minimal_instance):
     assert pts.shape == (1, 2, 2, 2)

     # Test RandomCrop value error.
-    p = LabelsReader.from_filename(minimal_instance)
+    p = LabelsReaderDP.from_filename(minimal_instance)
     p = Normalizer(p)
     with pytest.raises(
         ValueError, match="crop_hw height and width must be greater than 0."
diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py
index 691f08d1..f09df6f4 100644
--- a/tests/data/test_confmaps.py
+++ b/tests/data/test_confmaps.py
@@ -12,14 +12,14 @@
 from sleap_nn.data.instance_cropping import InstanceCropper
 from sleap_nn.data.normalization import Normalizer
 from sleap_nn.data.resizing import Resizer
-from sleap_nn.data.providers import LabelsReader, process_lf
+from sleap_nn.data.providers import LabelsReaderDP, process_lf
 from sleap_nn.data.utils import make_grid_vectors
 import numpy as np


 def test_confmaps(minimal_instance):
     """Test ConfidenceMapGenerator module."""
-    datapipe = LabelsReader.from_filename(minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(minimal_instance)
     datapipe = InstanceCentroidFinder(datapipe)
     datapipe = Normalizer(datapipe)
     datapipe = InstanceCropper(datapipe, (100, 100))
@@ -72,7 +72,7 @@ def test_confmaps(minimal_instance):

 def test_multi_confmaps(minimal_instance):
     """Test MultiConfidenceMapGenerator module."""
     # centroids = True
-    datapipe = LabelsReader.from_filename(minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(minimal_instance)
     datapipe = Normalizer(datapipe)
     datapipe = InstanceCentroidFinder(datapipe)
     datapipe1 = MultiConfidenceMapGenerator(
@@ -107,7 +107,7 @@ def test_multi_confmaps(minimal_instance):
     torch.testing.assert_close(gt, cms[0][0], atol=0.001, rtol=0.0)

     # centroids = False (for instances)
-    datapipe = LabelsReader.from_filename(minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(minimal_instance)
     datapipe = Normalizer(datapipe)
     datapipe = Resizer(datapipe, scale=2)
     datapipe = InstanceCentroidFinder(datapipe)
diff --git a/tests/data/test_edge_maps.py b/tests/data/test_edge_maps.py
index d994932e..ecf32c7d 100644
--- a/tests/data/test_edge_maps.py
+++ b/tests/data/test_edge_maps.py
@@ -2,7 +2,7 @@
 import torch
 import sleap_io as sio
 from sleap_nn.data.utils import make_grid_vectors
-from sleap_nn.data.providers import LabelsReader, process_lf
+from sleap_nn.data.providers import LabelsReaderDP, process_lf
 from sleap_nn.data.edge_maps import (
     distance_to_edge,
     make_edge_maps,
@@ -190,7 +190,7 @@ def test_generate_pafs(minimal_instance):

 def test_part_affinity_fields_generator(minimal_instance):
-    provider = LabelsReader.from_filename(minimal_instance)
+    provider = LabelsReaderDP.from_filename(minimal_instance)
     paf_generator = PartAffinityFieldsGenerator(
         provider,
         sigma=8,
diff --git a/tests/data/test_instance_centroids.py b/tests/data/test_instance_centroids.py
index 245c0c45..3b4db2a7 100644
--- a/tests/data/test_instance_centroids.py
+++ b/tests/data/test_instance_centroids.py
@@ -4,7 +4,7 @@
     InstanceCentroidFinder,
     generate_centroids,
 )
-from sleap_nn.data.providers import LabelsReader, process_lf
+from sleap_nn.data.providers import LabelsReaderDP, process_lf


 def test_generate_centroids(minimal_instance):
@@ -38,7 +38,7 @@ def test_generate_centroids(minimal_instance):
 def test_instance_centroids(minimal_instance):
     """Test InstanceCentroidFinder and generate_centroids functions."""
     # Undefined anchor_ind.
-    datapipe = LabelsReader.from_filename(minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(minimal_instance)
     datapipe = InstanceCentroidFinder(datapipe)
     sample = next(iter(datapipe))
     instances = sample["instances"]
diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py
index c93df4ef..4d60aa3e 100644
--- a/tests/data/test_instance_cropping.py
+++ b/tests/data/test_instance_cropping.py
@@ -10,7 +10,7 @@
 )
 from sleap_nn.data.normalization import Normalizer, apply_normalization
 from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride
-from sleap_nn.data.providers import LabelsReader, process_lf
+from sleap_nn.data.providers import LabelsReaderDP, process_lf


 def test_find_instance_crop_size(minimal_instance):
@@ -44,7 +44,7 @@ def test_make_centered_bboxes():

 def test_instance_cropper(minimal_instance):
     """Test InstanceCropper module."""
-    provider = LabelsReader.from_filename(minimal_instance)
+    provider = LabelsReaderDP.from_filename(minimal_instance)
     provider.max_instances = 3
     datapipe = Normalizer(provider)
     datapipe = SizeMatcher(datapipe, provider)
diff --git a/tests/data/test_normalization.py b/tests/data/test_normalization.py
index bcf71837..a94fec83 100644
--- a/tests/data/test_normalization.py
+++ b/tests/data/test_normalization.py
@@ -6,12 +6,12 @@
     convert_to_grayscale,
     apply_normalization,
 )
-from sleap_nn.data.providers import LabelsReader
+from sleap_nn.data.providers import LabelsReaderDP


 def test_normalizer(minimal_instance):
     """Test Normalizer module."""
-    p = LabelsReader.from_filename(minimal_instance)
+    p = LabelsReaderDP.from_filename(minimal_instance)
     p = Normalizer(p)

     ex = next(iter(p))
@@ -19,7 +19,7 @@ def test_normalizer(minimal_instance):
     assert ex["image"].shape[-3] == 1

     # test is_rgb
-    p = LabelsReader.from_filename(minimal_instance)
+    p = LabelsReaderDP.from_filename(minimal_instance)
     p = Normalizer(p, is_rgb=True)

     ex = next(iter(p))
diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py
index e313be0b..78e92135 100644
--- a/tests/data/test_pipelines.py
+++ b/tests/data/test_pipelines.py
@@ -15,12 +15,12 @@
     CentroidConfmapsPipeline,
     BottomUpPipeline,
 )
-from sleap_nn.data.providers import LabelsReader
+from sleap_nn.data.providers import LabelsReaderDP


 def test_key_filter(minimal_instance):
     """Test KeyFilter module."""
-    datapipe = LabelsReader.from_filename(filename=minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(filename=minimal_instance)
     datapipe = Normalizer(datapipe)
     datapipe = SizeMatcher(datapipe)
     datapipe = Resizer(datapipe)
@@ -58,7 +58,7 @@ def test_key_filter(minimal_instance):
     assert sample["video_idx"] == 0
     assert sample["num_instances"] == 2

-    datapipe = LabelsReader.from_filename(filename=minimal_instance)
+    datapipe = LabelsReaderDP.from_filename(filename=minimal_instance)
     datapipe = Normalizer(datapipe)
     datapipe = SizeMatcher(datapipe)
     datapipe = Resizer(datapipe, keep_original=True)
@@ -113,7 +113,7 @@ def test_topdownconfmapspipeline(minimal_instance):
         confmap_head=confmap_head,
         crop_hw=crop_hw,
     )
-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -192,7 +192,7 @@ def test_topdownconfmapspipeline(minimal_instance):
         crop_hw=(100, 100),
     )

-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))
     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
         use_augmentations=base_topdown_data_config.use_augmentations_train,
@@ -272,7 +272,7 @@ def test_topdownconfmapspipeline(minimal_instance):
         crop_hw=(100, 100),
     )

-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))
     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
         use_augmentations=base_topdown_data_config.use_augmentations_train,
@@ -326,7 +326,7 @@ def test_singleinstanceconfmapspipeline(minimal_instance):
         max_stride=8,
         confmap_head=confmap_head,
     )
-    data_provider = LabelsReader(labels=labels)
+    data_provider = LabelsReaderDP(labels=labels)

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -401,7 +401,7 @@ def test_singleinstanceconfmapspipeline(minimal_instance):
         confmap_head=confmap_head,
     )

-    data_provider = LabelsReader(labels=labels)
+    data_provider = LabelsReaderDP(labels=labels)
     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
         use_augmentations=base_singleinstance_data_config.use_augmentations_train,
@@ -443,7 +443,7 @@ def test_centroidconfmapspipeline(minimal_instance):
     pipeline = CentroidConfmapsPipeline(
         data_config=base_centroid_data_config, max_stride=32, confmap_head=confmap_head
     )
-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -516,7 +516,7 @@ def test_centroidconfmapspipeline(minimal_instance):
         data_config=base_centroid_data_config, max_stride=32, confmap_head=confmap_head
     )

-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))
     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
         use_augmentations=base_centroid_data_config.use_augmentations_train,
@@ -563,7 +563,7 @@ def test_bottomuppipeline(minimal_instance):
         confmap_head=confmap_head,
         pafs_head=pafs_head,
     )
-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -607,7 +607,7 @@ def test_bottomuppipeline(minimal_instance):
         confmap_head=confmap_head,
         pafs_head=pafs_head,
     )
-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -685,7 +685,7 @@ def test_bottomuppipeline(minimal_instance):
         confmap_head=confmap_head,
         pafs_head=pafs_head,
     )
-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))

     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
@@ -765,7 +765,7 @@ def test_bottomuppipeline(minimal_instance):
         pafs_head=pafs_head,
     )

-    data_provider = LabelsReader(labels=sio.load_slp(minimal_instance))
+    data_provider = LabelsReaderDP(labels=sio.load_slp(minimal_instance))
     datapipe = pipeline.make_training_pipeline(
         data_provider=data_provider,
         use_augmentations=base_bottom_config.use_augmentations_train,
diff --git a/tests/data/test_providers.py b/tests/data/test_providers.py
index 06e740e4..bef1e94e 100644
--- a/tests/data/test_providers.py
+++ b/tests/data/test_providers.py
@@ -1,6 +1,11 @@
 import torch

-from sleap_nn.data.providers import LabelsReader, LabelReader, VideoReader, process_lf
+from sleap_nn.data.providers import (
+    LabelsReaderDP,
+    LabelsReader,
+    VideoReader,
+    process_lf,
+)
 from queue import Queue
 import sleap_io as sio
 import numpy as np
@@ -9,8 +14,8 @@


 def test_providers(minimal_instance):
-    """Test LabelsReader module."""
-    l = LabelsReader.from_filename(minimal_instance)
+    """Test LabelsReaderDP module."""
+    l = LabelsReaderDP.from_filename(minimal_instance)
     sample = next(iter(l))
     instances, image = sample["instances"], sample["image"]
     assert image.shape == torch.Size([1, 1, 384, 384])
@@ -92,11 +97,11 @@ def test_videoreader_provider(centered_instance_video):
     assert reader.total_len() == 6


-def test_labelreader_provider(minimal_instance):
-    """Test LabelReader class."""
+def test_labelsreader_provider(minimal_instance):
+    """Test LabelsReader class."""
     labels = sio.load_slp(minimal_instance)
     queue = Queue(maxsize=4)
-    reader = LabelReader(labels=labels, frame_buffer=queue, instances_key=False)
+    reader = LabelsReader(labels=labels, frame_buffer=queue, instances_key=False)
     assert reader.max_height_and_width == (384, 384)
     reader.start()
     batch_size = 1
@@ -117,7 +122,7 @@
     assert reader.total_len() == 1

     # with instances key
-    reader = LabelReader.from_filename(
+    reader = LabelsReader.from_filename(
         minimal_instance, queue_maxsize=4, instances_key=True
     )
     assert reader.max_height_and_width == (384, 384)
diff --git a/tests/data/test_resizing.py b/tests/data/test_resizing.py
index 0d926bea..f70bee99 100644
--- a/tests/data/test_resizing.py
+++ b/tests/data/test_resizing.py
@@ -1,6 +1,6 @@
 import torch

-from sleap_nn.data.providers import LabelsReader, process_lf
+from sleap_nn.data.providers import LabelsReaderDP, process_lf
 from sleap_nn.data.resizing import (
     SizeMatcher,
     Resizer,
@@ -16,7 +16,7 @@

 def test_sizematcher(minimal_instance):
     """Test SizeMatcher module for pad images to specified dimensions."""
-    l = LabelsReader.from_filename(minimal_instance)
+    l = LabelsReaderDP.from_filename(minimal_instance)
     pipe = SizeMatcher(l, provider=l)
     sample = next(iter(pipe))
     instances, image = sample["instances"], sample["image"]
@@ -49,7 +49,7 @@ def test_sizematcher(minimal_instance):

 def test_resizer(minimal_instance):
     """Test Resizer module for resizing images based on given scale."""
-    l = LabelsReader.from_filename(minimal_instance)
+    l = LabelsReaderDP.from_filename(minimal_instance)
     pipe = Resizer(l, scale=2, keep_original=False)
     sample = next(iter(pipe))
     image = sample["image"]
@@ -65,7 +65,7 @@ def test_resizer(minimal_instance):

 def test_padtostride(minimal_instance):
     """Test PadToStride module to pad images based on max stride."""
-    l = LabelsReader.from_filename(minimal_instance)
+    l = LabelsReaderDP.from_filename(minimal_instance)
     pipe = PadToStride(l, max_stride=200)
     sample = next(iter(pipe))
     image = sample["image"]
diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py
index 45d9ba20..d8424f69 100644
--- a/tests/fixtures/datasets.py
+++ b/tests/fixtures/datasets.py
@@ -48,7 +48,7 @@ def config(sleap_data_dir):
     init_config = OmegaConf.create(
         {
             "data_config": {
-                "provider": "LabelsReader",
+                "provider": "LabelsReaderDP",
                 "train_labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp",
                 "val_labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp",
                 "preprocessing": {
diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py
index 8519ba68..68d7cf0e 100644
--- a/tests/inference/test_predictors.py
+++ b/tests/inference/test_predictors.py
@@ -17,7 +17,7 @@ def test_topdown_predictor(
     pred_labels = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         return_confmaps=False,
         make_labels=True,
         peak_threshold=0.0,
@@ -40,7 +40,7 @@ def test_topdown_predictor(
     preds = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=False,
         peak_threshold=0.0,
         integral_refinement="integral",
@@ -63,7 +63,7 @@ def test_topdown_predictor(
     preds = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=False,
     )

@@ -73,7 +73,7 @@ def test_topdown_predictor(
     pred_labels = main(
         model_paths=[minimal_instance_centroid_ckpt, minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=[0.0, 0.0],
@@ -88,7 +88,7 @@ def test_topdown_predictor(
     pred_labels = main(
         model_paths=[minimal_instance_centroid_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=False,
         max_instances=max_instances,
         peak_threshold=0.1,
@@ -119,7 +119,7 @@ def test_topdown_predictor(
     # Unrecognized provider
     with pytest.raises(
         Exception,
-        match="Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider",
+        match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider",
    ):
         pred_labels = main(
             model_paths=[minimal_instance_centroid_ckpt, minimal_instance_ckpt],
@@ -184,7 +184,7 @@ def test_topdown_predictor(

 def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt):
     """Test SingleInstancePredictor module."""
-    # provider as LabelReader
+    # provider as LabelsReader

     _config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml")
     config = _config.copy()
@@ -204,7 +204,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt):
     pred_labels = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.1,
@@ -227,7 +227,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt):
     preds = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=False,
         peak_threshold=0.3,
         max_height=500,
@@ -312,7 +312,7 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt):
     # check if labels are created from ckpt
     with pytest.raises(
         Exception,
-        match="Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider",
+        match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider",
     ):
         preds = main(
             model_paths=[minimal_instance_ckpt],
@@ -329,13 +329,13 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt):

 def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt):
     """Test BottomUpPredictor module."""
-    # provider as LabelReader
+    # provider as LabelsReader

     # check if labels are created from ckpt
     pred_labels = main(
         model_paths=[minimal_instance_bottomup_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.03,
@@ -356,7 +356,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt):
     preds = main(
         model_paths=[minimal_instance_bottomup_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=False,
         max_instances=6,
         peak_threshold=0.03,
@@ -373,7 +373,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt):
     pred_labels = main(
         model_paths=[minimal_instance_bottomup_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=1.0,
@@ -420,7 +420,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt):
     # unrecognized provider
     with pytest.raises(
         Exception,
-        match="Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider",
+        match="Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider",
     ):
         preds = main(
             model_paths=[minimal_instance_bottomup_ckpt],
@@ -435,7 +435,7 @@ def test_bottomup_predictor(minimal_instance, minimal_instance_bottomup_ckpt):
     pred_labels = main(
         model_paths=[minimal_instance_bottomup_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.03,
diff --git a/tests/inference/test_topdown.py b/tests/inference/test_topdown.py
index 3512a114..fd459c1c 100644
--- a/tests/inference/test_topdown.py
+++ b/tests/inference/test_topdown.py
@@ -4,7 +4,7 @@
 import torch
 from torch.utils.data.dataloader import DataLoader
 import sleap_io as sio
-from sleap_nn.data.providers import process_lf, LabelsReader
+from sleap_nn.data.providers import process_lf, LabelsReaderDP
 from sleap_nn.data.resizing import resize_image
 from sleap_nn.data.instance_centroids import InstanceCentroidFinder, generate_centroids
 from sleap_nn.data.normalization import apply_normalization, Normalizer
@@ -136,7 +136,7 @@ def test_find_instance_peaks_groundtruth(
     # with centroid crop class
     config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml")
-    data_provider = LabelsReader.from_filename(minimal_instance, instances_key=True)
+    data_provider = LabelsReaderDP.from_filename(minimal_instance, instances_key=True)
     pipeline = SizeMatcher(
         data_provider,
         max_height=None,
diff --git a/tests/tracking/candidates/test_fixed_window.py b/tests/tracking/candidates/test_fixed_window.py
index 980ac3b0..8c131714 100644
--- a/tests/tracking/candidates/test_fixed_window.py
+++ b/tests/tracking/candidates/test_fixed_window.py
@@ -10,7 +10,7 @@ def get_pred_instances(minimal_instance_ckpt, n=10):
     result_labels = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.0,
diff --git a/tests/tracking/candidates/test_local_queues.py b/tests/tracking/candidates/test_local_queues.py
index e6b5993c..76efcdef 100644
--- a/tests/tracking/candidates/test_local_queues.py
+++ b/tests/tracking/candidates/test_local_queues.py
@@ -10,7 +10,7 @@ def get_pred_instances(minimal_instance_ckpt, n=10):
     result_labels = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.0,
diff --git a/tests/tracking/test_tracker.py b/tests/tracking/test_tracker.py
index adc43a36..99d95ef4 100644
--- a/tests/tracking/test_tracker.py
+++ b/tests/tracking/test_tracker.py
@@ -15,7 +15,7 @@ def get_pred_instances(minimal_instance_ckpt):
     result_labels = main(
         model_paths=[minimal_instance_ckpt],
         data_path="./tests/assets/minimal_instance.pkg.slp",
-        provider="LabelReader",
+        provider="LabelsReader",
         make_labels=True,
         max_instances=6,
         peak_threshold=0.0,

From b89af6338a7da708ec37a714f0a9d3b46269777b Mon Sep 17 00:00:00 2001
From: DivyaSesh <64513125+gitttt-1234@users.noreply.github.com>
Date: Thu, 3 Oct 2024 15:07:18 -0700
Subject: [PATCH 11/11] Update test_model_trainer.py

---
 tests/training/test_model_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py
index 325e3a10..898ba955 100644
--- a/tests/training/test_model_trainer.py
+++ b/tests/training/test_model_trainer.py
@@ -46,14 +46,14 @@ def test_create_data_loader(config, tmp_path: str):
     shutil.rmtree((Path(model_trainer.dir_path) / "train_chunks").as_posix())
     shutil.rmtree((Path(model_trainer.dir_path) / "val_chunks").as_posix())

-# # test exception
-# config_copy = config.copy()
-# head_config = config_copy.model_config.head_configs.centered_instance
-# del config_copy.model_config.head_configs.centered_instance
-# OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config)
-# model_trainer = ModelTrainer(config_copy)
-# with pytest.raises(Exception):
-#     model_trainer._create_data_loaders()
+    # test exception
+    config_copy = config.copy()
+    head_config = config_copy.model_config.head_configs.centered_instance
+    del config_copy.model_config.head_configs.centered_instance
+    OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config)
+    model_trainer = ModelTrainer(config_copy)
+    with pytest.raises(Exception):
+        model_trainer._create_data_loaders()


 def test_wandb():
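
---

For reviewers trying the series locally: the diffs above rename the old threaded `LabelReader` to `LabelsReader` and the DataPipe-based `LabelsReader` to `LabelsReaderDP`. Below is a minimal sketch, assembled only from the call sites visible in the patches, of how the two providers are driven after the rename. The checkpoint path is a placeholder, and the `sleap_nn.inference.predictors` import path for `main` is an assumption (the tests import `main` but the patches do not show its module).

    from queue import Queue

    import sleap_io as sio

    from sleap_nn.data.providers import LabelsReader, LabelsReaderDP
    from sleap_nn.inference.predictors import main  # assumed module path for `main`

    labels = sio.load_slp("tests/assets/minimal_instance.pkg.slp")

    # Threaded provider (formerly `LabelReader`): pushes frames into a bounded
    # queue that Predictor._predict_generator() consumes after reader.start().
    reader = LabelsReader(labels=labels, frame_buffer=Queue(maxsize=4), instances_key=False)
    reader.start()

    # IterDataPipe provider (formerly `LabelsReader`): composes with the
    # datapipe transforms exercised in tests/data/ (Normalizer, Resizer, ...).
    datapipe = LabelsReaderDP.from_filename("tests/assets/minimal_instance.pkg.slp")
    sample = next(iter(datapipe))  # dict with "image", "instances", ...

    # The high-level inference entry point now takes the new provider name.
    pred_labels = main(
        model_paths=["path/to/ckpt"],  # placeholder checkpoint directory
        data_path="tests/assets/minimal_instance.pkg.slp",
        provider="LabelsReader",
        make_labels=True,
        peak_threshold=0.1,
    )

The split appears intentional: the queue-backed `LabelsReader` serves streaming inference and tracking, while `LabelsReaderDP` stays available to the datapipe-style training transforms, which is also why `ModelTrainer._create_data_loaders` no longer resolves the provider string to a class itself.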