Skip to content

Commit

Permalink
Video preview.
Browse files Browse the repository at this point in the history
  • Loading branch information
iwatkot committed Feb 7, 2024
1 parent ca269d4 commit f639ca8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
2 changes: 2 additions & 0 deletions project_dataset/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

ABSOLUTE_PATH = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(ABSOLUTE_PATH)
TEMP_DIR = os.path.join(PARENT_DIR, "temp")
sly.fs.mkdir(TEMP_DIR)

if sly.is_development():
load_dotenv(os.path.join(PARENT_DIR, "local.env"))
Expand Down
101 changes: 101 additions & 0 deletions project_dataset/src/ui/inference_preview.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
from random import choice
from time import sleep
from typing import Any, Dict, List, Literal, Tuple, Union

import cv2
import numpy as np
import supervisely as sly
import yaml
from supervisely.app.widgets import (
Expand Down Expand Up @@ -90,6 +93,13 @@

@preview_button.click
def create_preview() -> None:
if settings.inference_mode.get_value() == "sliding_window":
window_preview()
else:
full_preview()


def full_preview() -> None:
"""Create a preview of the model inference and show it in the gallery."""
try:
inference_settings = yaml.safe_load(settings.additional_settings.get_value())
Expand Down Expand Up @@ -127,11 +137,102 @@ def create_preview() -> None:
)


def window_preview():
if random_image_checkbox.is_checked():
image_info: sly.ImageInfo = choice(g.input_images)
else:
# TODO: implement image selection
pass

check_sliding_sizes_by_image(image_info)
inference_setting = get_sliding_window_params()

ann_pred_res = g.api.task.send_request(
g.model_session_id,
"inference_image_id",
data={"image_id": image_info.id, "settings": inference_setting},
timeout=200,
)

try:
predictions = ann_pred_res["data"]["slides"]
except Exception as ex:
raise ValueError("Cannot parse slides predictions, reason: {}".format(repr(ex)))

image_np = g.api.image.download_np(image_info.id)
file_info = write_video(image_np, predictions)

# TODO: Show video in the gallery.
# ? Change UI widget?


# region legacy
# This code comes from the legacy version of the app mostly as is.
# Functions were modified to use the global variables with less arguments
# and UI state was removed (as it is not used in the new version of the app).
# Consider refactoring logic of functions later.
def check_sliding_sizes_by_image(image_info: sly.ImageInfo) -> None:
"""Checks sliding window sizes by the image and updates them if necessary.
:param image_info: Image to check sliding window sizes by.
:type image_info: sly.ImageInfo
"""
if window_height.get_value() > image_info.height:
window_height.value = image_info.height

if window_width.get_value() > image_info.width:
window_width.value = image_info.width


def write_video(image_np: np.ndarray, predictions, last_two_frames_copies=8, max_video_size=1080):
scale_ratio = None
if image_np.shape[1] > max_video_size:
scale_ratio = max_video_size / image_np.shape[1]
image_np = cv2.resize(
image_np, (int(image_np.shape[1] * scale_ratio), int(image_np.shape[0] * scale_ratio))
)

video_path = os.path.join(g.my_app.data_dir, "preview.mp4")
video = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"VP90"),
vizulazation_fps.get_value(),
(image_np.shape[1], image_np.shape[0]),
)

for i, pred in enumerate(predictions):
rect = pred["rectangle"]
rect = sly.Rectangle.from_json(rect)
if scale_ratio is not None:
rect = rect.scale(scale_ratio)
labels = pred["labels"]
for label_ind, label in enumerate(labels):
labels[label_ind] = sly.Label.from_json(label, g.model_meta)

frame = image_np.copy()
rect.draw_contour(frame, [255, 0, 0], thickness=5)
for label in labels:
if scale_ratio is not None:
label = sly.Label(label.geometry.scale(scale_ratio), label.obj_class)
label.draw_contour(frame, thickness=3)
sly.image.write(os.path.join(g.TEMP_DIR, f"{i:05d}.jpg"), frame)
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

if i >= len(predictions) - 2:
for n in range(last_two_frames_copies):
video.write(frame_bgr)
else:
video.write(frame_bgr)

video.release()

remote_video_path = os.path.join(g.TEMP_DIR, "preview.mp4")
if g.api.file.exists(g.team_id, remote_video_path):
g.api.file.remove(g.team_id, remote_video_path)
file_info = g.api.file.upload(g.team_id, video_path, remote_video_path)
return file_info


def apply_model_to_image(
image_info: sly.ImageInfo, inference_setting: Dict[str, Union[float, bool]]
) -> Tuple[sly.Annotation, sly.Annotation, sly.ProjectMeta]:
Expand Down

0 comments on commit f639ca8

Please sign in to comment.