Skip to content
This repository has been archived by the owner on Jul 3, 2024. It is now read-only.

Commit

Permalink
refactor: apply pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
Mouwrice committed Apr 13, 2024
1 parent bd9ee9a commit 06b35aa
Show file tree
Hide file tree
Showing 14 changed files with 282 additions and 244 deletions.
54 changes: 29 additions & 25 deletions drumpy/app/camera_display.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Self

import numpy as np
import numpy.typing as npt
import pygame.time
from pygame import Surface

Expand All @@ -14,7 +14,7 @@ class VideoDisplay:
"""

def __init__(
self: Self,
self,
video_source: VideoSource,
media_pipe_pose: MediaPipePose,
window: Surface,
Expand All @@ -28,40 +28,44 @@ def __init__(
self.video_source = video_source
self.prev_surface = None

def update(self: Self) -> None:
frame = self.video_source.get_frame()
assert frame is None or frame.shape[0] == frame.shape[1], "Frame is not square"
def update(self) -> None:
result: npt.NDArray[np.float32] | None = self.video_source.get_frame()
assert (
result is None or result.shape[0] == result.shape[1]
), "Frame is not square"

# There is no new frame to display
if frame is None and self.prev_surface is not None:
if result is None and self.prev_surface is not None:
self.window.blit(self.prev_surface, self.rect.topleft)
return

# There is no new frame to display and no previous frame to display
if frame is None and self.prev_surface is None:
if result is None and self.prev_surface is None:
return

timestamp_ms = self.video_source.get_timestamp_ms()
if result is not None:
frame: npt.NDArray[np.float32] = result
timestamp_ms = self.video_source.get_timestamp_ms()

self.media_pipe_pose.process_image(frame, timestamp_ms)
self.media_pipe_pose.process_image(frame, timestamp_ms)

# Draw the landmarks on the image
if self.media_pipe_pose.visualisation is not None:
frame = self.media_pipe_pose.visualisation
# Draw the landmarks on the image
if self.media_pipe_pose.visualisation is not None:
frame = self.media_pipe_pose.visualisation

if self.source == Source.CAMERA:
frame = frame.swapaxes(0, 1)
if self.source == Source.CAMERA:
frame = frame.swapaxes(0, 1)

image_surface = pygame.image.frombuffer(
frame.tobytes(), (frame.shape[0], frame.shape[0]), "RGB"
)
image_surface = pygame.image.frombuffer(
frame.tobytes(), (frame.shape[0], frame.shape[0]), "RGB"
)

# # Rotate the image 90 degrees
# if self.source == Source.CAMERA:
# image_surface = pygame.transform.rotate(image_surface, -90)
# # Rotate the image 90 degrees
# if self.source == Source.CAMERA:
# image_surface = pygame.transform.rotate(image_surface, -90)

# Scale the image to fit the window
image_surface = pygame.transform.scale(image_surface, self.rect.size)
# Scale the image to fit the window
image_surface = pygame.transform.scale(image_surface, self.rect.size)

self.prev_surface = image_surface
self.window.blit(image_surface, self.rect.topleft)
self.prev_surface = image_surface
self.window.blit(image_surface, self.rect.topleft)
28 changes: 17 additions & 11 deletions drumpy/app/fps_display.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Self

from mediapipe.tasks.python.vision import RunningMode
from mediapipe.tasks.python.vision import RunningMode # type: ignore
from pygame import Rect
from pygame_gui import UIManager
from pygame_gui.elements import UILabel
from pygame_gui import UIManager # type: ignore
from pygame_gui.elements import UILabel # type: ignore

from drumpy.mediapipe_pose.mediapipe_pose import MediaPipePose

Expand All @@ -21,11 +21,14 @@ def __init__(
media_pipe_pose: MediaPipePose,
) -> None:
mode = ""
match media_pipe_pose.options.running_mode:
case RunningMode.LIVE_STREAM:
match media_pipe_pose.options.running_mode: # type: ignore
case RunningMode.LIVE_STREAM: # type: ignore
mode = "Async Mode"
case RunningMode.VIDEO:
case RunningMode.VIDEO: # type: ignore
mode = "Blocking Mode"
case _: # type: ignore
pass

self.model = media_pipe_pose.model
super().__init__(
Rect((0, 0), (900, 50)),
Expand All @@ -34,8 +37,8 @@ def __init__(
anchors={"top": "top", "left": "left"},
)

self.ui_time_deltas = []
self.mediapipe_time_deltas = []
self.ui_time_deltas: list[float] = []
self.mediapipe_time_deltas: list[int] = []

self.media_pipe_pose = media_pipe_pose
self.ui_manager = ui_manager
Expand Down Expand Up @@ -68,11 +71,14 @@ def update(self: Self, time_delta: float) -> None:
)

mode = ""
match self.media_pipe_pose.options.running_mode:
case RunningMode.LIVE_STREAM:
match self.media_pipe_pose.options.running_mode: # type: ignore
case RunningMode.LIVE_STREAM: # type: ignore
mode = "Async Mode"
case RunningMode.VIDEO:
case RunningMode.VIDEO: # type: ignore
mode = "Blocking Mode"
case _: # type: ignore
pass

self.set_text(
f"UI FPS: {ui_fps:.2f} Camera FPS: {camera_fps:.2f} {mode} Model: {self.model}"
)
20 changes: 10 additions & 10 deletions drumpy/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import pygame
import pygame.camera
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python.vision import RunningMode
from pygame_gui import UIManager
from mediapipe.tasks.python import BaseOptions # type: ignore
from mediapipe.tasks.python.vision import RunningMode # type: ignore
from pygame_gui import UIManager # type: ignore

from drumpy.app.camera_display import VideoDisplay
from drumpy.app.fps_display import FPSDisplay
Expand All @@ -24,9 +24,9 @@ def __init__(
self: Self,
source: Source = Source.CAMERA,
file_path: Optional[str] = None,
running_mode: RunningMode = RunningMode.LIVE_STREAM,
running_mode: RunningMode = RunningMode.LIVE_STREAM, # type: ignore
model: LandmarkerModel = LandmarkerModel.FULL,
delegate: BaseOptions.Delegate = BaseOptions.Delegate.GPU,
delegate: BaseOptions.Delegate = BaseOptions.Delegate.GPU, # type: ignore
log_file: Optional[str] = None,
landmark_type: LandmarkType = LandmarkType.WORLD_LANDMARKS,
) -> None:
Expand All @@ -53,10 +53,10 @@ def __init__(
self.drum_trackers = DrumTrackers()

self.media_pipe_pose = MediaPipePose(
running_mode=running_mode,
running_mode=running_mode, # type: ignore
model=model,
log_file=log_file,
delegate=delegate,
delegate=delegate, # type: ignore
landmark_type=landmark_type,
drum_trackers=self.drum_trackers,
)
Expand All @@ -66,11 +66,11 @@ def __init__(
media_pipe_pose=self.media_pipe_pose,
)

self.video_source = None
match source:
case Source.CAMERA:
self.video_source = CameraSource(cameras[0])
case Source.FILE:
assert file_path is not None, "File path must be provided"
self.video_source = VideoFileSource(file_path)

self.fps = self.video_source.get_fps()
Expand Down Expand Up @@ -109,9 +109,9 @@ def start(self: Self) -> None:
def main() -> None:
app = App(
source=Source.FILE,
running_mode=RunningMode.VIDEO,
running_mode=RunningMode.VIDEO, # type: ignore
file_path="../../recordings/multicam_asil_01_front.mkv",
log_file="test.csv",
# log_file="test.csv",
)
app.start()

Expand Down
25 changes: 12 additions & 13 deletions drumpy/app/video_source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
from enum import Enum
from multiprocessing import Pool
from pathlib import Path
from typing import Self, Optional
from typing import Self

import cv2
import numpy as np
import pygame.transform
from numpy import ndarray
from pygame import camera, surfarray
import numpy.typing as npt
from pygame import camera, surfarray, Surface


class Source(Enum):
Expand Down Expand Up @@ -35,7 +35,7 @@ def get_fps(self: Self) -> float:
"""

@abstractmethod
def get_frame(self: Self) -> Optional[ndarray]:
def get_frame(self: Self) -> npt.NDArray[np.float32] | None:
"""
Get the next frame from the video
:return: The frame and the timestamp
Expand Down Expand Up @@ -68,7 +68,7 @@ class VideoFileSource(VideoSource):
Class to handle a video source from a file
"""

def __init__(self: Self, file_path: str) -> None:
def __init__(self, file_path: str) -> None:
super().__init__()

assert Path(file_path).exists(), f"File {file_path} does not exist"
Expand All @@ -82,7 +82,6 @@ def __init__(self: Self, file_path: str) -> None:
self.size = (smallest, smallest)
self.left_offset = (source_width - smallest) // 2
self.top_offset = (source_height - smallest) // 2
self.pool: Pool = None

def get_fps(self: Self) -> float:
"""
Expand All @@ -91,19 +90,19 @@ def get_fps(self: Self) -> float:
"""
return self.source_fps

def get_frame(self: Self) -> Optional[ndarray]:
def get_frame(self: Self) -> npt.NDArray[np.float32] | None:
"""
Get the next frame from the video
:return: The frame and the timestamp
"""
ret, frame = self.cap.read()
if not ret or frame is None:
if not ret:
self.stopped = True
return None

frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Crop the image to a square aspect ratio
return frame[
return frame[ # type: ignore
self.top_offset : self.top_offset + self.size[1],
self.left_offset : self.left_offset + self.size[0],
].copy()
Expand Down Expand Up @@ -153,16 +152,16 @@ def get_fps(self: Self) -> float:
"""
return 60

def get_frame(self: Self) -> Optional[ndarray]:
def get_frame(self: Self) -> npt.NDArray[np.float32] | None:
"""
Get the next frame from the video
:return: The frame and the timestamp
"""
if self.camera.query_image():
frame = pygame.Surface(self.size)
frame = Surface(self.size)
image = self.camera.get_image()
frame.blit(image, (0, 0), ((self.left_offset, self.top_offset), self.size))
return surfarray.array3d(frame)
return surfarray.array3d(frame) # type: ignore

return None

Expand Down
36 changes: 18 additions & 18 deletions drumpy/drum/drum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from time import sleep
from typing import Self, Optional


from drumpy.drum.sound import Sound, SoundState
from drumpy.util import print_float_array, Position
from drumpy.util import position_str, Position
from drumpy.mediapipe_pose.mediapipe_markers import MarkerEnum


class DrumPresets:
Expand Down Expand Up @@ -49,26 +49,26 @@ def __init__(
sounds: list[Sound],
sleep_option: SleepOption = SleepOption.SLEEP,
) -> None:
self.sounds = sounds
self.sounds: list[Sound] = sounds

# Queue to keep track of sounds that need to be calibrated
self.auto_calibrations = []
self.auto_calibrations: list[Sound] = []

self.sleep_option = sleep_option
self.sleep_option: SleepOption = sleep_option

def __str__(self: Self) -> str:
return "\n".join([str(sound) for sound in self.sounds])

def find_and_play_sound(
self: Self,
position: Position,
marker_label: str,
sounds: Optional[list[Sound]],
marker: MarkerEnum,
sounds: Optional[list[Sound]] = None,
) -> None:
"""
Find the closest sound to the given position and play it
If the drum is calibrating sounds, the to be calibrated sound will be played
:param marker_label: The label of the marker that is hitting the sound
:param marker: The marker that hit the sound
:param sounds: List of sounds to consider, if None, all sounds will be considered
:param position: A 3D position as a numpy array
:return:
Expand All @@ -83,7 +83,7 @@ def find_and_play_sound(
if sound.state == SoundState.CALIBRATING:
sound.hit(position)
print(
f"\t{marker_label}: {sound.name} with distance {distance:.3f} at {print_float_array(position)}"
f"\t{marker}: {sound.name} with distance {distance:.3f} at {position_str(position)}"
)
return

Expand All @@ -94,23 +94,20 @@ def find_and_play_sound(
if closest_sound is not None:
closest_sound.hit(position)
print(
f"{marker_label}: {closest_sound.name} with distance {closest_distance:.3f} "
f"at {print_float_array(position)}"
f"{marker}: {closest_sound.name} with distance {closest_distance:.3f} "
f"at {position_str(position)}"
)
else:
print(
f"{marker_label}: No sound found for position {print_float_array(position)} "
f"with distance {closest_distance:.3f}"
)
print(f"{marker}: No sound found for position {position_str(position)}")

def auto_calibrate(self: Self, sounds: Optional[list[int]] = None) -> None:
def auto_calibrate(self: Self, sounds: list[Sound] | None = None) -> None:
"""
Automatically calibrate all sounds
:param sounds: List of sounds to calibrate, if None, all sounds will be calibrated in order
:return:
"""
if sounds is None:
sounds = list(range(len(self.sounds)))
sounds = self.sounds

self.auto_calibrations = sounds

Expand All @@ -122,7 +119,7 @@ def check_calibrations(self: Self) -> None:
if len(self.auto_calibrations) == 0:
return

sound = self.sounds[self.auto_calibrations[0]]
sound = self.auto_calibrations[0]

match sound.state.value:
case SoundState.UNINITIALIZED.value:
Expand All @@ -133,3 +130,6 @@ def check_calibrations(self: Self) -> None:
case SoundState.READY.value:
self.auto_calibrations.pop(0)
return

case _:
pass
Loading

0 comments on commit 06b35aa

Please sign in to comment.