From 02a76f957b34e27377dea9ad05734af152243d7e Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 11 Jun 2024 11:39:39 -0700 Subject: [PATCH 1/3] update docs --- dreem/io/visualize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dreem/io/visualize.py b/dreem/io/visualize.py index 1c0095e3..ac20176c 100644 --- a/dreem/io/visualize.py +++ b/dreem/io/visualize.py @@ -4,6 +4,7 @@ from copy import deepcopy from tqdm import tqdm from omegaconf import DictConfig +from typing import Union import seaborn as sns import imageio From 0dfab7856135d9ca1f361f02dcbc0bf1ed44cc63 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 11 Jun 2024 18:54:36 -0700 Subject: [PATCH 2/3] update type hinting to follow python 3.11 syntax by using `|` instead of `typing.Union` and referring to standard collections directly rather than `typing.Collection` --- dreem/io/visualize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dreem/io/visualize.py b/dreem/io/visualize.py index ac20176c..1c0095e3 100644 --- a/dreem/io/visualize.py +++ b/dreem/io/visualize.py @@ -4,7 +4,6 @@ from copy import deepcopy from tqdm import tqdm from omegaconf import DictConfig -from typing import Union import seaborn as sns import imageio From 978942f41df9439413ef9c1d0ab0d6650a23e9bd Mon Sep 17 00:00:00 2001 From: aaprasad Date: Wed, 12 Jun 2024 11:57:45 -0700 Subject: [PATCH 3/3] * set up `logging` * replace print statements and warnings with logger emissions * get rid of unnecessary/commented out exceptions/print statements --- dreem/__init__.py | 16 ++++ dreem/datasets/data_utils.py | 2 - dreem/datasets/microscopy_dataset.py | 1 - dreem/datasets/sleap_dataset.py | 8 +- dreem/inference/metrics.py | 16 ++-- dreem/inference/track.py | 23 +++--- dreem/inference/track_queue.py | 30 ++++---- dreem/inference/tracker.py | 107 ++++++++++----------------- dreem/io/association_matrix.py | 13 +++- dreem/io/config.py | 20 ++--- dreem/io/frame.py | 16 ++-- dreem/io/instance.py | 7 +- dreem/io/visualize.py | 14 ++-- dreem/models/embedding.py | 7 +- dreem/models/gtr_runner.py | 11 ++- dreem/models/model_utils.py | 7 +- dreem/training/__init__.py | 2 - dreem/training/losses.py | 10 +-- dreem/training/train.py | 15 ++-- logging.yaml | 36 +++++++++ tests/test_datasets.py | 4 - 21 files changed, 197 insertions(+), 168 deletions(-) create mode 100644 logging.yaml diff --git a/dreem/__init__.py b/dreem/__init__.py index 5299cbe3..f5e35ccb 100644 --- a/dreem/__init__.py +++ b/dreem/__init__.py @@ -1,5 +1,6 @@ """Top-level package for dreem.""" +import logging.config from dreem.version import __version__ from dreem.models.global_tracking_transformer import GlobalTrackingTransformer @@ -16,3 +17,18 @@ # from .training import run from dreem.inference.tracker import Tracker + + +def setup_logging(): + """Setup logging based on `logging.yaml`.""" + import logging + import yaml + import os + + package_directory = os.path.dirname(os.path.abspath(__file__)) + + with open(os.path.join(package_directory, "..", "logging.yaml"), "r") as stream: + logging_cfg = yaml.load(stream, Loader=yaml.FullLoader) + + logging.config.dictConfig(logging_cfg) + logger = logging.getLogger("dreem") diff --git a/dreem/datasets/data_utils.py b/dreem/datasets/data_utils.py index 2eafb421..2557115d 100644 --- a/dreem/datasets/data_utils.py +++ b/dreem/datasets/data_utils.py @@ -90,7 +90,6 @@ def centroid_bbox(points: ArrayLike, anchors: list, crop_size: int) -> torch.Ten Returns: Bounding box in [y1, x1, y2, x2] format. """ - print(anchors) for anchor in anchors: cx, cy = points[anchor][0], points[anchor][1] if not np.isnan(cx): @@ -120,7 +119,6 @@ def pose_bbox(points: np.ndarray, bbox_size: tuple[int] | int) -> torch.Tensor: """ if isinstance(bbox_size, int): bbox_size = (bbox_size, bbox_size) - # print(points) c = np.nanmean(points, axis=0) bbox = torch.Tensor( diff --git a/dreem/datasets/microscopy_dataset.py b/dreem/datasets/microscopy_dataset.py index cda2e035..bc28f2c5 100644 --- a/dreem/datasets/microscopy_dataset.py +++ b/dreem/datasets/microscopy_dataset.py @@ -133,7 +133,6 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frames = [] for frame_id in frame_idx: - # print(i) instances, gt_track_ids, centroids = [], [], [] img = ( diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index f3ffb6d8..f7297a54 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -6,11 +6,13 @@ import numpy as np import sleap_io as sio import random -import warnings +import logging from dreem.io import Instance, Frame from dreem.datasets import data_utils, BaseDataset from torchvision.transforms import functional as tvf +logger = logging.getLogger("dreem.datasets") + class SleapDataset(BaseDataset): """Dataset for loading animal behavior data from sleap.""" @@ -165,7 +167,9 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram img = np.expand_dims(img, 0) h, w, c = img.shape except IndexError as e: - print(f"Could not read frame {frame_ind} from {video_name} due to {e}") + logger.warning( + f"Could not read frame {frame_ind} from {video_name} due to {e}" + ) continue if len(img.shape) == 2: diff --git a/dreem/inference/metrics.py b/dreem/inference/metrics.py index d5be0d1e..5d96b12c 100644 --- a/dreem/inference/metrics.py +++ b/dreem/inference/metrics.py @@ -3,8 +3,11 @@ import numpy as np import motmetrics as mm import torch -from typing import Iterable import pandas as pd +import logging +from typing import Iterable + +logger = logging.getLogger("dreem.inference") # from dreem.inference.post_processing import _pairwise_iou # from dreem.inference.boxes import Boxes @@ -39,8 +42,8 @@ def get_matches(frames: list["dreem.io.Frame"]) -> tuple[dict, list, int]: matches[match] = np.full(len(frames), 0) matches[match][idx] = 1 - # else: - # warnings.warn("No instances detected!") + else: + logger.debug("No instances detected!") return matches, indices, video_id @@ -191,12 +194,7 @@ def to_track_eval(frames: list["dreem.io.Frame"]) -> dict: data["num_gt_ids"] = len(unique_gt_ids) data["num_tracker_dets"] = num_tracker_dets data["num_gt_dets"] = num_gt_dets - try: - data["gt_ids"] = gt_ids - # print(data['gt_ids']) - except Exception as e: - print(gt_ids) - raise (e) + data["gt_ids"] = gt_ids data["tracker_ids"] = track_ids data["similarity_scores"] = similarity_scores data["num_timesteps"] = len(frames) diff --git a/dreem/inference/track.py b/dreem/inference/track.py index ef4e387f..cc48a59d 100644 --- a/dreem/inference/track.py +++ b/dreem/inference/track.py @@ -4,7 +4,6 @@ from dreem.models import GTRRunner from omegaconf import DictConfig from pathlib import Path -from pprint import pprint import hydra import os @@ -12,6 +11,9 @@ import pytorch_lightning as pl import torch import sleap_io as sio +import logging + +logger = logging.getLogger("dreem.inference") def export_trajectories( @@ -76,16 +78,13 @@ def track( for frame in batch: lf, tracks = frame.to_slp(tracks) if frame.frame_id.item() == 0: - print(f"Video: {lf.video}") + logger.info(f"Video: {lf.video}") vid_trajectories[frame.video_id.item()].append(lf) for vid_id, video in vid_trajectories.items(): if len(video) > 0: - try: - vid_trajectories[vid_id] = sio.Labels(video) - except AttributeError as e: - print(video[0].video) - raise (e) + + vid_trajectories[vid_id] = sio.Labels(video) return vid_trajectories @@ -106,7 +105,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: except KeyError: index = input("Pod Index Not found! Please choose a pod index: ") - print(f"Pod Index: {index}") + logger.info(f"Pod Index: {index}") checkpoints = pd.read_csv(cfg.checkpoints) checkpoint = checkpoints.iloc[index] @@ -115,10 +114,10 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: model = GTRRunner.load_from_checkpoint(checkpoint) tracker_cfg = pred_cfg.get_tracker_cfg() - print("Updating tracker hparams") + logger.info("Updating tracker hparams") model.tracker_cfg = tracker_cfg - print(f"Using the following params for tracker:") - pprint(model.tracker_cfg) + logger.info(f"Using the following params for tracker:") + logger.info(model.tracker_cfg) dataset = pred_cfg.get_dataset(mode="test") dataloader = pred_cfg.get_dataloader(dataset, mode="test") @@ -139,7 +138,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: if os.path.exists(outpath): run_num += 1 outpath = outpath.replace(f".v{run_num-1}", f".v{run_num}") - print(f"Saving {preds} to {outpath}") + logger.info(f"Saving {preds} to {outpath}") pred.save(outpath) return preds diff --git a/dreem/inference/track_queue.py b/dreem/inference/track_queue.py index a3b29225..b215b7e0 100644 --- a/dreem/inference/track_queue.py +++ b/dreem/inference/track_queue.py @@ -1,11 +1,14 @@ """Module handling sliding window tracking.""" -import warnings from dreem.io import Frame from collections import deque -import numpy as np from torch import device +import logging +import numpy as np + +logger = logging.getLogger("dreem.inference") + class TrackQueue: """Class handling track local queue system for sliding window. @@ -175,7 +178,7 @@ def end_tracks(self, track_id: int | None = None) -> bool: self._queues.pop(track_id) self._curr_gap.pop(track_id) except KeyError: - print(f"Track ID {track_id} not found in queue!") + logger.exception(f"Track ID {track_id} not found in queue!") return False return True @@ -211,10 +214,9 @@ def add_frame(self, frame: Frame) -> None: ) # dumb work around to retain `img_shape` self.curr_track = pred_track_id - if self.verbose: - warnings.warn( - f"New track = {pred_track_id} on frame {frame_id}! Current number of tracks = {self.n_tracks}" - ) + logger.debug( + f"New track = {pred_track_id} on frame {frame_id}! Current number of tracks = {self.n_tracks}" + ) else: self._queues[pred_track_id].append((*frame_meta, instance)) @@ -288,10 +290,9 @@ def increment_gaps(self, pred_track_ids: list[int]) -> dict[int, bool]: for track in self._curr_gap: if track not in pred_track_ids: self._curr_gap[track] += 1 - if self.verbose: - warnings.warn( - f"Track {track} has not been seen for {self._curr_gap[track]} frames." - ) + logger.debug( + f"Track {track} has not been seen for {self._curr_gap[track]} frames." + ) else: self._curr_gap[track] = 0 if self._curr_gap[track] >= self.max_gap: @@ -301,10 +302,9 @@ def increment_gaps(self, pred_track_ids: list[int]) -> dict[int, bool]: for track, gap_exceeded in exceeded_gap.items(): if gap_exceeded: - if self.verbose: - warnings.warn( - f"Track {track} has not been seen for {self._curr_gap[track]} frames! Terminating Track...Current number of tracks = {self.n_tracks}." - ) + logger.debug( + f"Track {track} has not been seen for {self._curr_gap[track]} frames! Terminating Track...Current number of tracks = {self.n_tracks}." + ) self._queues.pop(track) self._curr_gap.pop(track) diff --git a/dreem/inference/tracker.py b/dreem/inference/tracker.py index 304713c3..fa1042c3 100644 --- a/dreem/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -2,7 +2,8 @@ import torch import pandas as pd -import warnings +import logging + from dreem.io import Frame from dreem.models import model_utils, GlobalTrackingTransformer from dreem.inference.track_queue import TrackQueue @@ -11,6 +12,8 @@ from scipy.optimize import linear_sum_assignment from math import inf +logger = logging.getLogger("dreem.inference") + class Tracker: """Tracker class used for assignment based on sliding inference from GTR.""" @@ -120,8 +123,7 @@ def track( instances_pred = self.sliding_inference(model, frames) if not self.persistent_tracking: - if self.verbose: - warnings.warn(f"Clearing Queue after tracking") + logger.debug(f"Clearing Queue after tracking") self.track_queue.end_tracks() return instances_pred @@ -148,16 +150,13 @@ def sliding_inference( tracked_frames = self.track_queue.collate_tracks( device=frame_to_track.frame_id.device ) - if self.verbose: - warnings.warn( - f"Current number of tracks is {self.track_queue.n_tracks}" - ) + logger.debug(f"Current number of tracks is {self.track_queue.n_tracks}") if ( self.persistent_tracking and frame_to_track.frame_id == 0 ): # check for new video and clear queue - if self.verbose: - warnings.warn("New Video! Resetting Track Queue.") + + logger.debug("New Video! Resetting Track Queue.") self.track_queue.end_tracks() """ @@ -165,10 +164,10 @@ def sliding_inference( """ if len(self.track_queue) == 0: if frame_to_track.has_instances(): - if self.verbose: - warnings.warn( - f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}" - ) + + logger.debug( + f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}" + ) curr_track_id = 0 for i, instance in enumerate(frames[batch_idx].instances): @@ -241,8 +240,7 @@ def _run_global_tracker( query_instances = query_frame.instances all_instances = [instance for frame in frames for instance in frame.instances] - if self.verbose: - print(f"Frame {query_frame.frame_id.item()}") + logger.debug(f"Frame {query_frame.frame_id.item()}") instances_per_frame = [frame.num_detected for frame in frames] @@ -250,8 +248,7 @@ def _run_global_tracker( instances_per_frame ) # Number of instances in window; length of window. - if self.verbose: - print(f"total_instances: {total_instances}") + logger.debug(f"total_instances: {total_instances}") overlap_thresh = self.overlap_thresh mult_thresh = self.mult_thresh @@ -265,10 +262,7 @@ def _run_global_tracker( # (L=1, n_query, total_instances) with torch.no_grad(): asso_matrix = model(all_instances, query_instances) - # if model.transformer.return_embedding: - # query_frame.embeddings = embed TODO add embedding to Instance Object - # if query_frame == 1: - # print(asso_output) + asso_output = asso_matrix[-1].matrix.split( instances_per_frame, dim=1 ) # (window_size, n_query, N_i) @@ -288,40 +282,27 @@ def _run_global_tracker( query_frame.add_traj_score("asso_output", asso_output_df) query_frame.asso_output = asso_matrix - try: - n_query = ( - query_frame.num_detected - ) # Number of instances in the current/query frame. - except Exception as e: - print(len(frames), query_frame, frames[-1]) - raise (e) + n_query = ( + query_frame.num_detected + ) # Number of instances in the current/query frame. n_nonquery = ( total_instances - n_query ) # Number of instances in the window not including the current/query frame. - if self.verbose: - print(f"n_nonquery: {n_nonquery}") - print(f"n_query: {n_query}") - try: - instance_ids = torch.cat( - [ - x.get_pred_track_ids() - for batch_idx, x in enumerate(frames) - if batch_idx != query_ind - ], - dim=0, - ).view( - n_nonquery - ) # (n_nonquery,) - except Exception as e: - print( - [ - [instance.pred_track_id.device for instance in frame.instances] - for frame in frames - ] - ) - raise (e) + logger.debug(f"n_nonquery: {n_nonquery}") + logger.debug(f"n_query: {n_query}") + + instance_ids = torch.cat( + [ + x.get_pred_track_ids() + for batch_idx, x in enumerate(frames) + if batch_idx != query_ind + ], + dim=0, + ).view( + n_nonquery + ) # (n_nonquery,) query_inds = [ x @@ -350,9 +331,8 @@ def _run_global_tracker( unique_ids = torch.unique(instance_ids) # (n_nonquery,) - if self.verbose: - print(f"Instance IDs: {instance_ids}") - print(f"unique ids: {unique_ids}") + logger.debug(f"Instance IDs: {instance_ids}") + logger.debug(f"unique ids: {unique_ids}") id_inds = ( unique_ids[None, :] == instance_ids[:, None] @@ -451,13 +431,7 @@ def _run_global_tracker( query_frame.add_traj_score("scaled", scaled_traj_score_df) ################################################################################ - try: - match_i, match_j = linear_sum_assignment((-traj_score)) - except ValueError as e: - print(reid_features.isnan().any()) - print(asso_output) - print(traj_score) - raise (e) + match_i, match_j = linear_sum_assignment((-traj_score)) track_ids = instance_ids.new_full((n_query,), -1) for i, j in zip(match_i, match_j): @@ -471,18 +445,15 @@ def _run_global_tracker( overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh ) if n_traj >= self.max_tracks or traj_score[i, j] > thresh: - if self.verbose: - print( - f"Assigning instance {i} to track {j} with id {unique_ids[j]}" - ) + logger.debug( + f"Assigning instance {i} to track {j} with id {unique_ids[j]}" + ) track_ids[i] = unique_ids[j] query_frame.instances[i].track_score = scaled_traj_score[i, j].item() - if self.verbose: - print(f"track_ids: {track_ids}") + logger.debug(f"track_ids: {track_ids}") for i in range(n_query): if track_ids[i] < 0: - if self.verbose: - print(f"Creating new track {n_traj}") + logger.debug(f"Creating new track {curr_track}") curr_track += 1 track_ids[i] = curr_track diff --git a/dreem/io/association_matrix.py b/dreem/io/association_matrix.py index be34291e..9c613c5f 100644 --- a/dreem/io/association_matrix.py +++ b/dreem/io/association_matrix.py @@ -4,8 +4,12 @@ import numpy as np import pandas as pd import attrs +import logging + from dreem.io import Instance +logger = logging.getLogger("dreem.io") + @attrs.define class AssociationMatrix: @@ -240,10 +244,11 @@ def __getitem__( try: return self.numpy()[query_ind[:, None], ref_ind].squeeze() except IndexError as e: - print(f"Query_insts: {type(query_inst)}") - print(f"Query_inds: {query_ind}") - print(f"Ref_insts: {type(ref_inst)}") - print(f"Ref_ind: {ref_ind}") + logger.exception(f"Query_insts: {type(query_inst)}") + logger.exception(f"Query_inds: {query_ind}") + logger.exception(f"Ref_insts: {type(ref_inst)}") + logger.exception(f"Ref_ind: {ref_ind}") + logger.exception(e) raise (e) def __getindices__( diff --git a/dreem/io/config.py b/dreem/io/config.py index a8693d4b..9ba3711f 100644 --- a/dreem/io/config.py +++ b/dreem/io/config.py @@ -3,13 +3,15 @@ from __future__ import annotations from omegaconf import DictConfig, OmegaConf, open_dict -from pprint import pprint from typing import Iterable from pathlib import Path +import logging import glob import pytorch_lightning as pl import torch +logger = logging.getLogger("dreem.io") + class Config: """Class handling loading components based on config params.""" @@ -26,13 +28,13 @@ def __init__(self, cfg: DictConfig, params_cfg: DictConfig | None = None): training/evaluation """ base_cfg = cfg - print(f"Base Config: {cfg}") + logger.info(f"Base Config: {cfg}") if "params_config" in cfg: params_cfg = OmegaConf.load(cfg.params_config) if params_cfg: - pprint(f"Overwriting base config with {params_cfg}") + logger.info(f"Overwriting base config with {params_cfg}") with open_dict(base_cfg): self.cfg = OmegaConf.merge(base_cfg, params_cfg) # merge configs else: @@ -71,13 +73,13 @@ def set_hparams(self, hparams: dict) -> bool: `True` if config is successfully updated, `False` otherwise """ if hparams == {} or hparams is None: - print("Nothing to update!") + logger.warning("Nothing to update!") return False for hparam, val in hparams.items(): try: OmegaConf.update(self.cfg, hparam, val) except Exception as e: - print(f"Failed to update {hparam} to {val} due to {e}") + logger.exception(f"Failed to update {hparam} to {val} due to {e}") return False return True @@ -159,11 +161,11 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]: vid_suff = dir_cfg.vid_suffix labels_path = f"{dir_cfg.path}/*{labels_suff}" vid_path = f"{dir_cfg.path}/*{vid_suff}" - print(f"Searching for labels matching {labels_path}") + logger.debug(f"Searching for labels matching {labels_path}") label_files = glob.glob(labels_path) - print(f"Searching for videos matching {vid_path}") + logger.debug(f"Searching for videos matching {vid_path}") vid_files = glob.glob(vid_path) - print(f"Found {len(label_files)} labels and {len(vid_files)} videos") + logger.debug(f"Found {len(label_files)} labels and {len(vid_files)} videos") return label_files, vid_files return None, None @@ -361,7 +363,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: try: Path(dirpath).mkdir(parents=True, exist_ok=True) except OSError as e: - print( + logger.exception( f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}" ) diff --git a/dreem/io/frame.py b/dreem/io/frame.py index 761515ca..610512d3 100644 --- a/dreem/io/frame.py +++ b/dreem/io/frame.py @@ -1,12 +1,15 @@ """Module containing data classes such as Instances and Frames.""" from __future__ import annotations +from numpy.typing import ArrayLike +from typing import Self import torch import sleap_io as sio import numpy as np import attrs -from numpy.typing import ArrayLike -from typing import Self +import logging + +logger = logging.getLogger("dreem.io") def _to_tensor(data: float | ArrayLike) -> torch.Tensor: @@ -411,7 +414,7 @@ def get_traj_score(self, key: str | None = None) -> dict | ArrayLike | None: try: return self._traj_score[key] except KeyError as e: - print(f"Could not access {key} traj_score due to {e}") + logger.exception(f"Could not access {key} traj_score due to {e}") return None def add_traj_score(self, key: str, traj_score: ArrayLike) -> None: @@ -511,11 +514,8 @@ def get_crops(self) -> torch.Tensor: """ if not self.has_instances(): return torch.tensor([]) - try: - return torch.cat([instance.crop for instance in self.instances], dim=0) - except Exception as e: - print(self) - raise (e) + + return torch.cat([instance.crop for instance in self.instances], dim=0) def has_features(self) -> bool: """Check if any of frames instances has reid features already computed. diff --git a/dreem/io/instance.py b/dreem/io/instance.py index fca3c0c1..81f9b825 100644 --- a/dreem/io/instance.py +++ b/dreem/io/instance.py @@ -4,9 +4,12 @@ import sleap_io as sio import numpy as np import attrs +import logging from numpy.typing import ArrayLike from typing import Self +logger = logging.getLogger("dreem.io") + def _to_tensor(data: float | ArrayLike) -> torch.Tensor: """Convert data to a torch.Tensor object. @@ -237,7 +240,7 @@ def to_slp( track_lookup, ) except Exception as e: - print( + logger.exception( f"Pose: {np.array(list(self.pose.values())).shape}, Pose score shape {self.point_scores.shape}" ) raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") @@ -501,7 +504,7 @@ def get_embedding( try: return self._embeddings[emb_type] except KeyError: - print( + logger.exception( f"{emb_type} not saved! Only {list(self._embeddings.keys())} are available" ) return None diff --git a/dreem/io/visualize.py b/dreem/io/visualize.py index 1c0095e3..88714e84 100644 --- a/dreem/io/visualize.py +++ b/dreem/io/visualize.py @@ -11,6 +11,9 @@ import pandas as pd import numpy as np import cv2 +import logging + +logger = logging.getLogger("dreem.io") palette = sns.color_palette("tab20") @@ -172,9 +175,6 @@ def annotate_video( ) midpt = (int(x), int(y)) - # print(midpt, type(midpt)) - - # assert idx < len(instance[key]) pred_track_id = instance[key] if "Track_score" in instance.index: @@ -195,8 +195,6 @@ def annotate_video( .tolist()[::-1] ) - # print(instance[key]) - # Bbox. if boxes is not None: frame = cv2.rectangle( @@ -266,7 +264,7 @@ def annotate_video( except Exception as e: writer.close() - print(e) + logger.exception(e) return False writer.close() @@ -327,9 +325,9 @@ def main(cfg: DictConfig): ) if frames_annotated: - print("Video saved to {cfg.save_path}!") + logger.info("Video saved to {cfg.save_path}!") else: - print("Failed to annotate video!") + logger.error("Failed to annotate video!") if __name__ == "__main__": diff --git a/dreem/models/embedding.py b/dreem/models/embedding.py index 6aea8f24..8a959c90 100644 --- a/dreem/models/embedding.py +++ b/dreem/models/embedding.py @@ -2,8 +2,10 @@ import math import torch +import logging from dreem.models.mlp import MLP +logger = logging.getLogger("dreem.models") # todo: add named tensors, clean variable names @@ -282,7 +284,10 @@ def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor: self.emb_num, n_anchors, 4, f ) # T x 4 x (D * 4) except RuntimeError as e: - print(f"Hint: `n_points` ({self.n_points}) may be set incorrectly!") + logger.exception( + f"Hint: `n_points` ({self.n_points}) may be set incorrectly!" + ) + logger.exception(e) raise (e) left_emb = pos_emb_table.gather( diff --git a/dreem/models/gtr_runner.py b/dreem/models/gtr_runner.py index c6119d22..3a42d116 100644 --- a/dreem/models/gtr_runner.py +++ b/dreem/models/gtr_runner.py @@ -2,14 +2,16 @@ import torch import gc +import logging from dreem.inference import Tracker from dreem.inference import metrics from dreem.models import GlobalTrackingTransformer from dreem.training.losses import AssoLoss from dreem.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule -from dreem.io.frame import Frame -from dreem.io.instance import Instance + + +logger = logging.getLogger("dreem.models") class GTRRunner(LightningModule): @@ -199,7 +201,10 @@ def _shared_eval_step( return_metrics.update(clearmot.to_dict()) return_metrics["batch_size"] = len(frames) except Exception as e: - print(f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}") + logger.exception( + f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}" + ) + logger.exception(e) raise (e) return return_metrics diff --git a/dreem/models/model_utils.py b/dreem/models/model_utils.py index 4a37d22b..75638e7e 100644 --- a/dreem/models/model_utils.py +++ b/dreem/models/model_utils.py @@ -41,11 +41,8 @@ def get_times( Returns: Tuple of Corresponding frame indices eg [0, 0, 1, 1, ..., T, T] for ref and query instances. """ - try: - ref_inds = torch.concat([instance.frame.frame_id for instance in ref_instances]) - except RuntimeError as e: - print([instance.frame.frame_id.device for instance in ref_instances]) - raise (e) + ref_inds = torch.concat([instance.frame.frame_id for instance in ref_instances]) + if query_instances is not None: query_inds = torch.concat( [instance.frame.frame_id for instance in query_instances] diff --git a/dreem/training/__init__.py b/dreem/training/__init__.py index c4e96012..932d36ac 100644 --- a/dreem/training/__init__.py +++ b/dreem/training/__init__.py @@ -1,3 +1 @@ """Initialize training module.""" - -# from .train import train diff --git a/dreem/training/losses.py b/dreem/training/losses.py index bfda4c3e..185199e5 100644 --- a/dreem/training/losses.py +++ b/dreem/training/losses.py @@ -46,13 +46,9 @@ def forward( """ # get number of detected objects and ground truth ids n_t = [frame.num_detected for frame in frames] - try: - target_inst_id = torch.cat( - [frame.get_gt_track_ids().to(asso_preds[-1].device) for frame in frames] - ) - except RuntimeError as e: - print([frame.get_gt_track_ids().device for frame in frames]) - raise (e) + target_inst_id = torch.cat( + [frame.get_gt_track_ids().to(asso_preds[-1].device) for frame in frames] + ) instances = [instance for frame in frames for instance in frame.instances] # for now set equal since detections are fixed diff --git a/dreem/training/train.py b/dreem/training/train.py index 811cac06..372bfa67 100644 --- a/dreem/training/train.py +++ b/dreem/training/train.py @@ -8,13 +8,16 @@ from dreem.datasets.data_utils import view_training_batch from multiprocessing import cpu_count from omegaconf import DictConfig -from pprint import pprint + import os import hydra import pandas as pd import pytorch_lightning as pl import torch import torch.multiprocessing +import logging + +logger = logging.getLogger("training") @hydra.main(config_path=None, config_name=None, version_base=None) @@ -43,11 +46,11 @@ def run(cfg: DictConfig): hparams = hparams_df.iloc[index].to_dict() if train_cfg.set_hparams(hparams): - print("Updated the following hparams to the following values") - pprint(hparams) + logger.debug("Updated the following hparams to the following values") + logger.debug(hparams) else: hparams = {} - pprint(f"Final train config: {train_cfg}") + logger.info(f"Final train config: {train_cfg}") model = train_cfg.get_model() train_dataset = train_cfg.get_dataset(mode="train") @@ -72,7 +75,7 @@ def run(cfg: DictConfig): model = train_cfg.get_gtr_runner() # TODO see if we can use torch.compile() - logger = train_cfg.get_logger() + run_logger = train_cfg.get_logger() callbacks = [] _ = callbacks.extend(train_cfg.get_checkpointing()) @@ -84,7 +87,7 @@ def run(cfg: DictConfig): trainer = train_cfg.get_trainer( callbacks, - logger, + run_logger, accelerator=accelerator, devices=devices, ) diff --git a/logging.yaml b/logging.yaml new file mode 100644 index 00000000..f71a92d6 --- /dev/null +++ b/logging.yaml @@ -0,0 +1,36 @@ +version: 1 +formatters: + simple: + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +handlers: + empty: + class: logging.NullHandler +loggers: + dreem: + level: INFO + propagate: no + dreem.datasets: + level: INFO + propagate: yes + parent: [loggers.dreem] + dreem.inference: + level: INFO + propagate: yes + parent: [loggers.dreem] + dreem.io: + level: INFO + propagate: yes + parent: [loggers.dreem] + dreem.models: + level: INFO + propagate: yes + parent: [loggers.dreem] + dreem.training: + level: INFO + propagate: yes + parent: [loggers.dreem] +root: + level: [INFO] + handlers: [empty] + + diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 2287c4f9..d5a38f7f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -250,10 +250,6 @@ def test_cell_tracking_dataset(cell_tracking): clip_length = 8 - # print(cell_tracking[0]) - # print(cell_tracking[1]) - # print(cell_tracking[2]) - train_ds = CellTrackingDataset( raw_images=[cell_tracking[0]], gt_images=[cell_tracking[1]],