From df26cecb683d2d9d8cb2cf3299ff77746f8d6da2 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Mon, 10 Jun 2024 14:51:11 -0700 Subject: [PATCH] fix small bug --- dreem/io/frame.py | 4 ++-- dreem/io/instance.py | 2 +- dreem/io/visualize.py | 7 +++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dreem/io/frame.py b/dreem/io/frame.py index d6a4617..f55775a 100644 --- a/dreem/io/frame.py +++ b/dreem/io/frame.py @@ -91,7 +91,7 @@ def __repr__(self) -> str: ")" ) - def to(self, map_location: str) -> "Frame": + def to(self, map_location: Union[str, torch.device]) -> "Frame": """Move frame to different device or dtype (See `torch.to` for more info). Args: @@ -116,7 +116,7 @@ def to(self, map_location: str) -> "Frame": for instance in self.instances: instance = instance.to(map_location) - if isinstance(map_location, str): + if isinstance(map_location, (str, torch.device)): self._device = map_location return self diff --git a/dreem/io/instance.py b/dreem/io/instance.py index 99d9d72..a114136 100644 --- a/dreem/io/instance.py +++ b/dreem/io/instance.py @@ -146,7 +146,7 @@ def to(self, map_location: Union[str, torch.device]) -> "Instance": self._bbox = self._bbox.to(map_location) self._crop = self._crop.to(map_location) self._features = self._features.to(map_location) - if isinstance(map_location, str, torch.device): + if isinstance(map_location, (str, torch.device)): self.device = map_location return self diff --git a/dreem/io/visualize.py b/dreem/io/visualize.py index 753959d..64e7131 100644 --- a/dreem/io/visualize.py +++ b/dreem/io/visualize.py @@ -316,12 +316,15 @@ def main(cfg: DictConfig): """Take in a path to a video + labels file, annotates a video and saves it to the specified path.""" labels = pd.read_csv(cfg.labels_path) video = imageio.get_reader(cfg.vid_path, "ffmpeg") - frames_annotated = annotate_video(video, labels, save_path=cfg.save_path, **cfg.annotate) - + frames_annotated = annotate_video( + video, labels, save_path=cfg.save_path, **cfg.annotate + ) + if frames_annotated: print("Video saved to {cfg.save_path}!") else: print("Failed to annotate video!") + if __name__ == "__main__": main()