Skip to content

Commit

Permalink
fix small bug
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Jun 10, 2024
1 parent 5ce73d8 commit df26cec
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions dreem/io/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __repr__(self) -> str:
")"
)

def to(self, map_location: str) -> "Frame":
def to(self, map_location: Union[str, torch.device]) -> "Frame":
"""Move frame to different device or dtype (See `torch.to` for more info).
Args:
Expand All @@ -116,7 +116,7 @@ def to(self, map_location: str) -> "Frame":
for instance in self.instances:
instance = instance.to(map_location)

if isinstance(map_location, str):
if isinstance(map_location, (str, torch.device)):
self._device = map_location

return self
Expand Down
2 changes: 1 addition & 1 deletion dreem/io/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def to(self, map_location: Union[str, torch.device]) -> "Instance":
self._bbox = self._bbox.to(map_location)
self._crop = self._crop.to(map_location)
self._features = self._features.to(map_location)
if isinstance(map_location, str, torch.device):
if isinstance(map_location, (str, torch.device)):
self.device = map_location

return self
Expand Down
7 changes: 5 additions & 2 deletions dreem/io/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,15 @@ def main(cfg: DictConfig):
"""Take in a path to a video + labels file, annotates a video and saves it to the specified path."""
labels = pd.read_csv(cfg.labels_path)
video = imageio.get_reader(cfg.vid_path, "ffmpeg")
frames_annotated = annotate_video(video, labels, save_path=cfg.save_path, **cfg.annotate)

frames_annotated = annotate_video(
video, labels, save_path=cfg.save_path, **cfg.annotate
)

if frames_annotated:
print("Video saved to {cfg.save_path}!")
else:
print("Failed to annotate video!")


if __name__ == "__main__":
main()

0 comments on commit df26cec

Please sign in to comment.