diff --git a/drumpy/app/camera_display.py b/drumpy/app/camera_display.py index 59b330f..083c02f 100644 --- a/drumpy/app/camera_display.py +++ b/drumpy/app/camera_display.py @@ -1,5 +1,5 @@ -from typing import Self - +import numpy as np +import numpy.typing as npt import pygame.time from pygame import Surface @@ -14,7 +14,7 @@ class VideoDisplay: """ def __init__( - self: Self, + self, video_source: VideoSource, media_pipe_pose: MediaPipePose, window: Surface, @@ -28,40 +28,44 @@ def __init__( self.video_source = video_source self.prev_surface = None - def update(self: Self) -> None: - frame = self.video_source.get_frame() - assert frame is None or frame.shape[0] == frame.shape[1], "Frame is not square" + def update(self) -> None: + result: npt.NDArray[np.float32] | None = self.video_source.get_frame() + assert ( + result is None or result.shape[0] == result.shape[1] + ), "Frame is not square" # There is no new frame to display - if frame is None and self.prev_surface is not None: + if result is None and self.prev_surface is not None: self.window.blit(self.prev_surface, self.rect.topleft) return # There is no new frame to display and no previous frame to display - if frame is None and self.prev_surface is None: + if result is None and self.prev_surface is None: return - timestamp_ms = self.video_source.get_timestamp_ms() + if result is not None: + frame: npt.NDArray[np.float32] = result + timestamp_ms = self.video_source.get_timestamp_ms() - self.media_pipe_pose.process_image(frame, timestamp_ms) + self.media_pipe_pose.process_image(frame, timestamp_ms) - # Draw the landmarks on the image - if self.media_pipe_pose.visualisation is not None: - frame = self.media_pipe_pose.visualisation + # Draw the landmarks on the image + if self.media_pipe_pose.visualisation is not None: + frame = self.media_pipe_pose.visualisation - if self.source == Source.CAMERA: - frame = frame.swapaxes(0, 1) + if self.source == Source.CAMERA: + frame = frame.swapaxes(0, 1) - image_surface = pygame.image.frombuffer( - frame.tobytes(), (frame.shape[0], frame.shape[0]), "RGB" - ) + image_surface = pygame.image.frombuffer( + frame.tobytes(), (frame.shape[0], frame.shape[0]), "RGB" + ) - # # Rotate the image 90 degrees - # if self.source == Source.CAMERA: - # image_surface = pygame.transform.rotate(image_surface, -90) + # # Rotate the image 90 degrees + # if self.source == Source.CAMERA: + # image_surface = pygame.transform.rotate(image_surface, -90) - # Scale the image to fit the window - image_surface = pygame.transform.scale(image_surface, self.rect.size) + # Scale the image to fit the window + image_surface = pygame.transform.scale(image_surface, self.rect.size) - self.prev_surface = image_surface - self.window.blit(image_surface, self.rect.topleft) + self.prev_surface = image_surface + self.window.blit(image_surface, self.rect.topleft) diff --git a/drumpy/app/fps_display.py b/drumpy/app/fps_display.py index 960e1a5..4719cb0 100644 --- a/drumpy/app/fps_display.py +++ b/drumpy/app/fps_display.py @@ -1,9 +1,9 @@ from typing import Self -from mediapipe.tasks.python.vision import RunningMode +from mediapipe.tasks.python.vision import RunningMode # type: ignore from pygame import Rect -from pygame_gui import UIManager -from pygame_gui.elements import UILabel +from pygame_gui import UIManager # type: ignore +from pygame_gui.elements import UILabel # type: ignore from drumpy.mediapipe_pose.mediapipe_pose import MediaPipePose @@ -21,11 +21,14 @@ def __init__( media_pipe_pose: MediaPipePose, ) -> None: mode = "" - match media_pipe_pose.options.running_mode: - case RunningMode.LIVE_STREAM: + match media_pipe_pose.options.running_mode: # type: ignore + case RunningMode.LIVE_STREAM: # type: ignore mode = "Async Mode" - case RunningMode.VIDEO: + case RunningMode.VIDEO: # type: ignore mode = "Blocking Mode" + case _: # type: ignore + pass + self.model = media_pipe_pose.model super().__init__( Rect((0, 0), (900, 50)), @@ -34,8 +37,8 @@ def __init__( anchors={"top": "top", "left": "left"}, ) - self.ui_time_deltas = [] - self.mediapipe_time_deltas = [] + self.ui_time_deltas: list[float] = [] + self.mediapipe_time_deltas: list[int] = [] self.media_pipe_pose = media_pipe_pose self.ui_manager = ui_manager @@ -68,11 +71,14 @@ def update(self: Self, time_delta: float) -> None: ) mode = "" - match self.media_pipe_pose.options.running_mode: - case RunningMode.LIVE_STREAM: + match self.media_pipe_pose.options.running_mode: # type: ignore + case RunningMode.LIVE_STREAM: # type: ignore mode = "Async Mode" - case RunningMode.VIDEO: + case RunningMode.VIDEO: # type: ignore mode = "Blocking Mode" + case _: # type: ignore + pass + self.set_text( f"UI FPS: {ui_fps:.2f} Camera FPS: {camera_fps:.2f} {mode} Model: {self.model}" ) diff --git a/drumpy/app/main.py b/drumpy/app/main.py index a17e36b..78cde1e 100644 --- a/drumpy/app/main.py +++ b/drumpy/app/main.py @@ -2,9 +2,9 @@ import pygame import pygame.camera -from mediapipe.tasks.python import BaseOptions -from mediapipe.tasks.python.vision import RunningMode -from pygame_gui import UIManager +from mediapipe.tasks.python import BaseOptions # type: ignore +from mediapipe.tasks.python.vision import RunningMode # type: ignore +from pygame_gui import UIManager # type: ignore from drumpy.app.camera_display import VideoDisplay from drumpy.app.fps_display import FPSDisplay @@ -24,9 +24,9 @@ def __init__( self: Self, source: Source = Source.CAMERA, file_path: Optional[str] = None, - running_mode: RunningMode = RunningMode.LIVE_STREAM, + running_mode: RunningMode = RunningMode.LIVE_STREAM, # type: ignore model: LandmarkerModel = LandmarkerModel.FULL, - delegate: BaseOptions.Delegate = BaseOptions.Delegate.GPU, + delegate: BaseOptions.Delegate = BaseOptions.Delegate.GPU, # type: ignore log_file: Optional[str] = None, landmark_type: LandmarkType = LandmarkType.WORLD_LANDMARKS, ) -> None: @@ -53,10 +53,10 @@ def __init__( self.drum_trackers = DrumTrackers() self.media_pipe_pose = MediaPipePose( - running_mode=running_mode, + running_mode=running_mode, # type: ignore model=model, log_file=log_file, - delegate=delegate, + delegate=delegate, # type: ignore landmark_type=landmark_type, drum_trackers=self.drum_trackers, ) @@ -66,11 +66,11 @@ def __init__( media_pipe_pose=self.media_pipe_pose, ) - self.video_source = None match source: case Source.CAMERA: self.video_source = CameraSource(cameras[0]) case Source.FILE: + assert file_path is not None, "File path must be provided" self.video_source = VideoFileSource(file_path) self.fps = self.video_source.get_fps() @@ -109,9 +109,9 @@ def start(self: Self) -> None: def main() -> None: app = App( source=Source.FILE, - running_mode=RunningMode.VIDEO, + running_mode=RunningMode.VIDEO, # type: ignore file_path="../../recordings/multicam_asil_01_front.mkv", - log_file="test.csv", + # log_file="test.csv", ) app.start() diff --git a/drumpy/app/video_source.py b/drumpy/app/video_source.py index 164f811..b22c5c2 100644 --- a/drumpy/app/video_source.py +++ b/drumpy/app/video_source.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from enum import Enum -from multiprocessing import Pool from pathlib import Path -from typing import Self, Optional +from typing import Self import cv2 +import numpy as np import pygame.transform -from numpy import ndarray -from pygame import camera, surfarray +import numpy.typing as npt +from pygame import camera, surfarray, Surface class Source(Enum): @@ -35,7 +35,7 @@ def get_fps(self: Self) -> float: """ @abstractmethod - def get_frame(self: Self) -> Optional[ndarray]: + def get_frame(self: Self) -> npt.NDArray[np.float32] | None: """ Get the next frame from the video :return: The frame and the timestamp @@ -68,7 +68,7 @@ class VideoFileSource(VideoSource): Class to handle a video source from a file """ - def __init__(self: Self, file_path: str) -> None: + def __init__(self, file_path: str) -> None: super().__init__() assert Path(file_path).exists(), f"File {file_path} does not exist" @@ -82,7 +82,6 @@ def __init__(self: Self, file_path: str) -> None: self.size = (smallest, smallest) self.left_offset = (source_width - smallest) // 2 self.top_offset = (source_height - smallest) // 2 - self.pool: Pool = None def get_fps(self: Self) -> float: """ @@ -91,19 +90,19 @@ def get_fps(self: Self) -> float: """ return self.source_fps - def get_frame(self: Self) -> Optional[ndarray]: + def get_frame(self: Self) -> npt.NDArray[np.float32] | None: """ Get the next frame from the video :return: The frame and the timestamp """ ret, frame = self.cap.read() - if not ret or frame is None: + if not ret: self.stopped = True return None frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Crop the image to a square aspect ratio - return frame[ + return frame[ # type: ignore self.top_offset : self.top_offset + self.size[1], self.left_offset : self.left_offset + self.size[0], ].copy() @@ -153,16 +152,16 @@ def get_fps(self: Self) -> float: """ return 60 - def get_frame(self: Self) -> Optional[ndarray]: + def get_frame(self: Self) -> npt.NDArray[np.float32] | None: """ Get the next frame from the video :return: The frame and the timestamp """ if self.camera.query_image(): - frame = pygame.Surface(self.size) + frame = Surface(self.size) image = self.camera.get_image() frame.blit(image, (0, 0), ((self.left_offset, self.top_offset), self.size)) - return surfarray.array3d(frame) + return surfarray.array3d(frame) # type: ignore return None diff --git a/drumpy/drum/drum.py b/drumpy/drum/drum.py index 8fddf44..478073a 100644 --- a/drumpy/drum/drum.py +++ b/drumpy/drum/drum.py @@ -2,9 +2,9 @@ from time import sleep from typing import Self, Optional - from drumpy.drum.sound import Sound, SoundState -from drumpy.util import print_float_array, Position +from drumpy.util import position_str, Position +from drumpy.mediapipe_pose.mediapipe_markers import MarkerEnum class DrumPresets: @@ -49,12 +49,12 @@ def __init__( sounds: list[Sound], sleep_option: SleepOption = SleepOption.SLEEP, ) -> None: - self.sounds = sounds + self.sounds: list[Sound] = sounds # Queue to keep track of sounds that need to be calibrated - self.auto_calibrations = [] + self.auto_calibrations: list[Sound] = [] - self.sleep_option = sleep_option + self.sleep_option: SleepOption = sleep_option def __str__(self: Self) -> str: return "\n".join([str(sound) for sound in self.sounds]) @@ -62,13 +62,13 @@ def __str__(self: Self) -> str: def find_and_play_sound( self: Self, position: Position, - marker_label: str, - sounds: Optional[list[Sound]], + marker: MarkerEnum, + sounds: Optional[list[Sound]] = None, ) -> None: """ Find the closest sound to the given position and play it If the drum is calibrating sounds, the to be calibrated sound will be played - :param marker_label: The label of the marker that is hitting the sound + :param marker: The marker that hit the sound :param sounds: List of sounds to consider, if None, all sounds will be considered :param position: A 3D position as a numpy array :return: @@ -83,7 +83,7 @@ def find_and_play_sound( if sound.state == SoundState.CALIBRATING: sound.hit(position) print( - f"\t{marker_label}: {sound.name} with distance {distance:.3f} at {print_float_array(position)}" + f"\t{marker}: {sound.name} with distance {distance:.3f} at {position_str(position)}" ) return @@ -94,23 +94,20 @@ def find_and_play_sound( if closest_sound is not None: closest_sound.hit(position) print( - f"{marker_label}: {closest_sound.name} with distance {closest_distance:.3f} " - f"at {print_float_array(position)}" + f"{marker}: {closest_sound.name} with distance {closest_distance:.3f} " + f"at {position_str(position)}" ) else: - print( - f"{marker_label}: No sound found for position {print_float_array(position)} " - f"with distance {closest_distance:.3f}" - ) + print(f"{marker}: No sound found for position {position_str(position)}") - def auto_calibrate(self: Self, sounds: Optional[list[int]] = None) -> None: + def auto_calibrate(self: Self, sounds: list[Sound] | None = None) -> None: """ Automatically calibrate all sounds :param sounds: List of sounds to calibrate, if None, all sounds will be calibrated in order :return: """ if sounds is None: - sounds = list(range(len(self.sounds))) + sounds = self.sounds self.auto_calibrations = sounds @@ -122,7 +119,7 @@ def check_calibrations(self: Self) -> None: if len(self.auto_calibrations) == 0: return - sound = self.sounds[self.auto_calibrations[0]] + sound = self.auto_calibrations[0] match sound.state.value: case SoundState.UNINITIALIZED.value: @@ -133,3 +130,6 @@ def check_calibrations(self: Self) -> None: case SoundState.READY.value: self.auto_calibrations.pop(0) return + + case _: + pass diff --git a/drumpy/drum/sound.py b/drumpy/drum/sound.py index 9a39264..8e34b49 100644 --- a/drumpy/drum/sound.py +++ b/drumpy/drum/sound.py @@ -2,13 +2,12 @@ from typing import Self, Optional import numpy as np -import numpy.typing as npt import pygame from termcolor import cprint -from drumpy.util import print_float_array, Position +from drumpy.util import position_str, Position, distance_no_depth -MIN_MARGIN = 0.01 +MARGIN = 0.1 # 10 cm, the margin that the sound can be hit with MIN_HIT_COUNT = 10 @@ -21,37 +20,44 @@ class SoundState(Enum): class Sound: """ Represents a part of the drum kit. - When a hit is registered, the marker has to look at the possible sounds that can be played, and find the - closest one to hit impact. + The Sound can be in one of three states: + UNINITIALIZED: The sound has not been hit yet + READY: The sound is ready to be hit + CALIBRATING: The sound is being calibrated + + The sound can be hit with a position, which will update the position of the sound towards the given position """ def __init__( self: Self, name: str, path: str, - min_margin: float, margin: float, - position: tuple[float, float, float] | None = None, + position: Optional[Position] = None, ) -> None: + """ + Initialize the sound + :param name: The name of the sound + :param path: The path to the sound file + :param margin: The accepted distance to the sound, aka the sound radius + :param position: The initial position of the sound, if None the sound will be uninitialized + """ self.name = name self.sound = pygame.mixer.Sound(path) - self.position: np.ndarray = np.array([0, 0, 0]) - self.state: SoundState = SoundState.UNINITIALIZED - - if position is not None: - self.position = np.array(position) - self.state = SoundState.READY + self.position: Position = ( + position if position is not None else np.array([0, 0, 0]) + ) + self.state: SoundState = ( + SoundState.READY if position is not None else SoundState.UNINITIALIZED + ) # the number of hits that have been registered self.hit_count = 0 - self.hits = [] + self.hits: list[Position] = [] - # the maximum and minimum distance from the sound to the hit that we allow - self.min_margin: float = min_margin - self.margin: float = ( - margin # the current margin will move towards the minimum margin over time - ) + # the margin that the sound can be hit with, aka the distance to the sound, or the size of the sound area + self.margin: float = margin def calibrate(self: Self) -> None: """ @@ -67,37 +73,40 @@ def is_hit(self: Self, position: Position) -> Optional[float]: """ Returns whether the given position is close enough to the sound to be considered a hit. If the sound position is being calibrated, the position is automatically set to running average of the hits - :param position: + :param position: The position of the hit :return: None if the position is not close enough, otherwise the distance to the sound """ - if self.state == SoundState.UNINITIALIZED: - return None - - if self.state == SoundState.CALIBRATING: - self.hits.append(position) - prev_position = self.position - self.position = np.mean(self.hits, axis=0) - - # the sound is calibrated when the position is stable and the hit count is high enough - if ( - np.linalg.norm(self.position - prev_position) < MIN_MARGIN - and self.hit_count > MIN_HIT_COUNT - ): - self.state = SoundState.READY - cprint(f"\n{self.name} calibration done", color="green", attrs=["bold"]) - else: - cprint(f"\nCalibrating {self.name}", color="blue") - - print(f"\tPosition: {print_float_array(self.position)}") - print(f"\tHit count: {self.hit_count}") - - distance = np.linalg.norm(self.position - position) - if distance < self.margin or self.state == SoundState.CALIBRATING: - return distance - - return None - - def hit(self: Self, position: npt.NDArray[np.float64]) -> None: + match self.state: + case SoundState.UNINITIALIZED: + return None + + case SoundState.CALIBRATING: + self.hits.append(position) + mean_position = np.mean(self.hits, axis=0) + self.position = mean_position + distance = distance_no_depth(mean_position, position) + + # the sound is calibrated when the position is stable and the hit count is high enough + if self.hit_count >= MIN_HIT_COUNT and distance <= self.margin: + self.state = SoundState.READY + cprint( + f"\n{self.name} calibration done", color="green", attrs=["bold"] + ) + else: + cprint(f"\nCalibrating {self.name}", color="blue") + + print(f"\tPosition: {position_str(self.position)}") + print(f"\tHit count: {self.hit_count}") + + return distance + + case SoundState.READY: + distance = distance_no_depth(self.position, position) + if distance < self.margin: + return distance + return None + + def hit(self: Self, position: Position) -> None: """ Update the position of the sound slowly to the given position and play it :param position: @@ -105,10 +114,6 @@ def hit(self: Self, position: npt.NDArray[np.float64]) -> None: self.sound.play() self.hit_count += 1 self.position = 0.99 * self.position + 0.01 * position - self.margin = max(self.min_margin, 0.99 * self.margin) - - -MARGIN = 0.1 class SnareDrum(Sound): @@ -117,7 +122,6 @@ def __init__(self: Self) -> None: "Snare Drum", "./DrumSamples/Snare/CKV1_Snare Loud.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -127,7 +131,6 @@ def __init__(self: Self) -> None: "High Hat", "./DrumSamples/HiHat/CKV1_HH Closed Loud.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -137,7 +140,6 @@ def __init__(self: Self) -> None: "Kick Drum", "./DrumSamples/Kick/CKV1_Kick Loud.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -147,7 +149,6 @@ def __init__(self: Self) -> None: "High Hat Foot", "./DrumSamples/HiHat/CKV1_HH Foot.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -157,7 +158,6 @@ def __init__(self: Self) -> None: "Tom 1", "./DrumSamples/Perc/Tom1.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -167,7 +167,6 @@ def __init__(self: Self) -> None: "Tom 2", "./DrumSamples/Perc/Tom2.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) @@ -177,5 +176,4 @@ def __init__(self: Self) -> None: "Cymbal", "./DrumSamples/cymbals/Hop_Crs.wav", margin=MARGIN, - min_margin=MIN_MARGIN, ) diff --git a/drumpy/mediapipe_pose/mediapipe_markers.py b/drumpy/mediapipe_pose/mediapipe_markers.py index 869ac7f..83dccb5 100644 --- a/drumpy/mediapipe_pose/mediapipe_markers.py +++ b/drumpy/mediapipe_pose/mediapipe_markers.py @@ -74,5 +74,11 @@ class MarkerEnum(Enum): LEFT_FOOT_INDEX = 31 RIGHT_FOOT_INDEX = 32 + # Some own additions + LEFT_DRUM_STICK = 33 + RIGHT_DRUM_STICK = 34 + LEFT_FOOT = 35 + RIGHT_FOOT = 36 + def __str__(self: Self) -> str: return self.name.replace("_", " ").title() diff --git a/drumpy/mediapipe_pose/mediapipe_pose.py b/drumpy/mediapipe_pose/mediapipe_pose.py index 61ef103..0232614 100644 --- a/drumpy/mediapipe_pose/mediapipe_pose.py +++ b/drumpy/mediapipe_pose/mediapipe_pose.py @@ -1,6 +1,7 @@ -from typing import Any, Self, Optional +from typing import Self, Optional import numpy as np +import numpy.typing as npt from mediapipe import Image, ImageFormat from mediapipe.framework.formats import landmark_pb2 from mediapipe.python import solutions @@ -11,17 +12,16 @@ PoseLandmarkerResult, RunningMode, ) -from numpy import ndarray, dtype -from drumpy.trajectory_file import TrajectoryFile from drumpy.mediapipe_pose.landmark_type import LandmarkType from drumpy.mediapipe_pose.landmarker_model import LandmarkerModel from drumpy.tracking.drum_trackers import DrumTrackers +from drumpy.trajectory_file import TrajectoryFile def visualize_landmarks( - rgb_image: ndarray, detection_result: PoseLandmarkerResult -) -> ndarray[Any, dtype[Any]]: + rgb_image: npt.NDArray[np.float64], detection_result: PoseLandmarkerResult +) -> npt.NDArray[np.float64]: """ Visualize the landmarks on the image given the landmarks and the image """ @@ -94,7 +94,7 @@ def __init__( 0 # The timestamp of the latest frame that was processed ) self.latency: int = 0 # The latency of the pose estimation, in milliseconds - self.visualisation: ndarray | None = None + self.visualisation: npt.NDArray[np.float32] | None = None self.landmark_type = landmark_type @@ -119,7 +119,10 @@ def result_callback( self.latency = timestamp_ms - self.latest_timestamp self.latest_timestamp = timestamp_ms self.drum_trackers.drum.check_calibrations() - if result.pose_landmarks is not None and len(result.pose_landmarks) > 0: + if ( + result.pose_world_landmarks is not None + and len(result.pose_world_landmarks) > 0 + ): self.drum_trackers.update(result.pose_world_landmarks[0]) self.visualisation = visualize_landmarks( image.numpy_view(), self.detection_result @@ -160,7 +163,9 @@ def write_landmarks( landmark.presence, ) - def process_image(self: Self, image_array: ndarray, timestamp_ms: int) -> None: + def process_image( + self: Self, image_array: npt.NDArray[np.float32], timestamp_ms: int + ) -> None: """ Process the image :param timestamp_ms: The timestamp of the frame @@ -174,3 +179,5 @@ def process_image(self: Self, image_array: ndarray, timestamp_ms: int) -> None: case RunningMode.VIDEO: result = self.landmarker.detect_for_video(image, timestamp_ms) self.result_callback(result, image, timestamp_ms) + case _: + pass diff --git a/drumpy/tracking/drum_trackers.py b/drumpy/tracking/drum_trackers.py index 2d3c3ae..3913566 100644 --- a/drumpy/tracking/drum_trackers.py +++ b/drumpy/tracking/drum_trackers.py @@ -1,10 +1,10 @@ from typing import Self -from mediapipe.tasks.python.components.containers.landmark import Landmark +from mediapipe.tasks.python.components.containers.landmark import Landmark # pyright: ignore from drumpy.drum.drum import Drum from drumpy.drum.sound import SnareDrum, HiHat, KickDrum -from drumpy.tracking.marker_tracker_wrapper import MarkerTrackerWrapper, Hand, Foot +from drumpy.tracking.marker_tracker_wrapper import MarkerTrackerWrapper, DrumStick, Foot class DrumTrackers: @@ -22,8 +22,8 @@ def __init__(self: Self) -> None: self.drum.auto_calibrate() self.trackers: list[MarkerTrackerWrapper] = [ - Hand.left_hand(self.drum, [snare_drum, hi_hat]), - Hand.right_hand(self.drum, [snare_drum, hi_hat]), + DrumStick.left_hand(self.drum, [snare_drum, hi_hat]), + DrumStick.right_hand(self.drum, [snare_drum, hi_hat]), Foot.left_foot(self.drum, [kick_drum]), Foot.right_foot(self.drum, [kick_drum]), ] diff --git a/drumpy/tracking/marker_tracker.py b/drumpy/tracking/marker_tracker.py index ad72fb8..cb0d9ea 100644 --- a/drumpy/tracking/marker_tracker.py +++ b/drumpy/tracking/marker_tracker.py @@ -1,9 +1,10 @@ from statistics import mean from typing import Self -from drumpy.drum.sound import Sound from drumpy.drum.drum import Drum +from drumpy.drum.sound import Sound from drumpy.util import Position +from drumpy.mediapipe_pose.mediapipe_markers import MarkerEnum MAX_DISTANCE = 100 @@ -15,35 +16,40 @@ class MarkerTracker: def __init__( self: Self, - label: str, + marker: MarkerEnum, drum: Drum, sounds: list[Sound], memory: int = 15, downward_trend: float = -2, upward_trend: float = 1, ) -> None: - self.label = label + """ + Initialize the marker tracker + :param marker: The marker to track + :param drum: The drum to play the sounds on + :param sounds: The sounds that can be played by this marker + :param memory: How many positions to keep track of + :param downward_trend: The threshold for a downward trend on the z-axis + :param upward_trend: The threshold for an upward trend on the z-axis + """ + self.marker: MarkerEnum = marker # the sounds that can be played by this marker self.sounds: list[Sound] = sounds - # keep track of the last 10 velocities on the z-axis - self.velocities = [] - - # keep track of the last 10 positions + self.velocities: list[float] = [] self.positions: list[Position] = [] # time until next hit can be registered self.time_until_next_hit = 0 - # how many positions to keep track of self.memory = memory + # how many positions to look ahead to determine if a hit is registered # should be smaller than memory self.look_ahead = 5 assert self.look_ahead < self.memory - # thresholds for registering a hit self.downward_trend = downward_trend self.upward_trend = upward_trend @@ -64,11 +70,9 @@ def update(self: Self, position: Position) -> None: self.velocities.pop(0) if self.is_hit(): - position = self.positions[-self.look_ahead] - self.time_until_next_hit = self.memory self.drum.find_and_play_sound( - self.positions[-self.look_ahead], self.label, self.sounds + self.positions[-self.look_ahead], self.marker, self.sounds ) def get_velocity(self: Self) -> float: @@ -96,8 +100,3 @@ def is_hit(self: Self) -> bool: and avg_z_look_ahead > self.upward_trend and self.time_until_next_hit == 0 ) - - def __str__(self: Self) -> str: - return "{}: \n{} {}".format( - self.label, self.positions[-1], self.velocities[-1] - ) diff --git a/drumpy/tracking/marker_tracker_wrapper.py b/drumpy/tracking/marker_tracker_wrapper.py index 1a6a1f5..89dc5d5 100644 --- a/drumpy/tracking/marker_tracker_wrapper.py +++ b/drumpy/tracking/marker_tracker_wrapper.py @@ -2,13 +2,13 @@ from typing import Self import numpy as np -from mediapipe.tasks.python.components.containers.landmark import Landmark +from mediapipe.tasks.python.components.containers.landmark import Landmark # pyright: ignore -from drumpy.drum.sound import Sound from drumpy.drum.drum import Drum +from drumpy.drum.sound import Sound from drumpy.mediapipe_pose.mediapipe_markers import MarkerEnum from drumpy.tracking.marker_tracker import MarkerTracker -from drumpy.util import landmark_to_numpy +from drumpy.util import landmark_to_position, Position class MarkerTrackerWrapper(ABC): @@ -24,7 +24,7 @@ def update(self: Self, markers: list[Landmark]) -> None: """ -class Hand(MarkerTrackerWrapper): +class DrumStick(MarkerTrackerWrapper): def __init__( self: Self, wrist: MarkerEnum, @@ -45,18 +45,14 @@ def update(self: Self, markers: list[Landmark]) -> None: pinky_landmark = markers[self.pinky.value] index_landmark = markers[self.index.value] - self.wrist.pos = landmark_to_numpy(wrist_landmark) - self.pinky.pos = landmark_to_numpy(pinky_landmark) - self.index.pos = landmark_to_numpy(index_landmark) + wrist_pos = landmark_to_position(wrist_landmark) + pinky_pos = landmark_to_position(pinky_landmark) + index_pos = landmark_to_position(index_landmark) - direction = ( - self.wrist.pos - + (self.index.pos - self.wrist.pos) - + (self.pinky.pos - self.wrist.pos) - ) + direction = wrist_pos + (index_pos - pinky_pos) + (pinky_pos - wrist_pos) # increase the length of the direction vector by 50 - self.position = self.wrist.pos + 50 * direction / np.linalg.norm(direction) + self.position = wrist_pos + 50 * direction / np.linalg.norm(direction) self.tracker.update(self.position) @@ -66,8 +62,11 @@ def left_hand(drum: Drum, sounds: list[Sound]) -> MarkerTrackerWrapper: pinky = MarkerEnum.LEFT_PINKY index = MarkerEnum.LEFT_INDEX - return Hand( - wrist, pinky, index, MarkerTracker("Left Hand", drum=drum, sounds=sounds) + return DrumStick( + wrist, + pinky, + index, + MarkerTracker(MarkerEnum.LEFT_DRUM_STICK, drum=drum, sounds=sounds), ) @staticmethod @@ -76,29 +75,35 @@ def right_hand(drum: Drum, sounds: list[Sound]) -> MarkerTrackerWrapper: pinky = MarkerEnum.RIGHT_PINKY index = MarkerEnum.RIGHT_INDEX - return Hand( - wrist, pinky, index, MarkerTracker("Right Hand", drum=drum, sounds=sounds) + return DrumStick( + wrist, + pinky, + index, + MarkerTracker(MarkerEnum.RIGHT_DRUM_STICK, drum=drum, sounds=sounds), ) class Foot(MarkerTrackerWrapper): def __init__(self: Self, toe_tip: MarkerEnum, tracker: MarkerTracker) -> None: self.toe_tip = toe_tip - self.pos: np.array = np.array([0, 0, 0]) + self.position: Position = np.array([0, 0, 0]) self.tracker = tracker def update(self: Self, markers: list[Landmark]) -> None: - self.toe_tip.pos = landmark_to_numpy(markers[self.toe_tip.value]) - self.pos = self.toe_tip.pos + self.position = landmark_to_position(markers[self.toe_tip.value]) - self.tracker.update(self.pos) + self.tracker.update(self.position) @staticmethod def left_foot(drum: Drum, sounds: list[Sound]) -> MarkerTrackerWrapper: toe_tip = MarkerEnum.LEFT_FOOT_INDEX - return Foot(toe_tip, MarkerTracker("Left Foot", drum=drum, sounds=sounds)) + return Foot( + toe_tip, MarkerTracker(MarkerEnum.LEFT_FOOT, drum=drum, sounds=sounds) + ) @staticmethod def right_foot(drum: Drum, sounds: list[Sound]) -> MarkerTrackerWrapper: toe_tip = MarkerEnum.RIGHT_FOOT_INDEX - return Foot(toe_tip, MarkerTracker("Right Foot", drum=drum, sounds=sounds)) + return Foot( + toe_tip, MarkerTracker(MarkerEnum.RIGHT_FOOT, drum=drum, sounds=sounds) + ) diff --git a/drumpy/util.py b/drumpy/util.py index 166b5dc..94ccde8 100644 --- a/drumpy/util.py +++ b/drumpy/util.py @@ -1,19 +1,21 @@ -import numpy.typing as npt +from typing import TypeAlias + import numpy as np -from mediapipe.tasks.python.components.containers.landmark import Landmark -from nptyping import NDArray, Shape, Float64 +import numpy.typing as npt +from mediapipe.tasks.python.components.containers.landmark import Landmark # type: ignore + +# Type for a 3D position, x, y, z +Position: TypeAlias = npt.NDArray[np.float64] -def print_float_array(array: npt.NDArray[np.float64]) -> str: +def position_str(position: Position) -> str: """ - Print a float array with 3 decimal places - :param array: - :return: + Print a position with 3 decimal places """ - return f"[{', '.join([f'{x:.3f}' for x in array])}]" + return f"[{', '.join([f'{x:.3f}' for x in position])}]" -def landmark_to_numpy(landmark: Landmark) -> np.array: +def landmark_to_position(landmark: Landmark) -> Position: """ Convert a mediapipe landmark to a numpy array Also switches some axes around: @@ -21,10 +23,20 @@ def landmark_to_numpy(landmark: Landmark) -> np.array: y -> z, the vertical axis z -> x, the depth axis """ - return np.array([landmark.y, landmark.z, landmark.x]) + assert landmark.x is not None + assert landmark.y is not None + assert landmark.z is not None + x = float(landmark.y) + y = float(landmark.z) + z = float(landmark.x) + return np.array([x, y, z]) -# Type for a 3D position, x, y, z -Position = NDArray[Shape["3"], Float64] -# Type for a 3D velocity, x, y, z -Velocity = NDArray[Shape["3"], Float64] +def distance_no_depth(a: Position, b: Position) -> float: + """ + Calculate the distance between two 3D positions without considering the depth, the x-axis + :param a: + :param b: + :return: + """ + return float(np.linalg.norm(a[1:] - b[1:])) diff --git a/poetry.lock b/poetry.lock index eaa001f..1506622 100644 --- a/poetry.lock +++ b/poetry.lock @@ -853,27 +853,6 @@ files = [ [package.dependencies] setuptools = "*" -[[package]] -name = "nptyping" -version = "2.5.0" -description = "Type hints for NumPy." -optional = false -python-versions = ">=3.7" -files = [ - {file = "nptyping-2.5.0-py3-none-any.whl", hash = "sha256:764e51836faae33a7ae2e928af574cfb701355647accadcc89f2ad793630b7c8"}, - {file = "nptyping-2.5.0.tar.gz", hash = "sha256:e3d35b53af967e6fb407c3016ff9abae954d3a0568f7cc13a461084224e8e20a"}, -] - -[package.dependencies] -numpy = {version = ">=1.20.0,<2.0.0", markers = "python_version >= \"3.8\""} - -[package.extras] -build = ["invoke (>=1.6.0)", "pip-tools (>=6.5.0)"] -complete = ["pandas", "pandas-stubs-fork"] -dev = ["autoflake", "beartype (<0.10.0)", "beartype (>=0.10.0)", "black", "codecov (>=2.1.0)", "coverage", "feedparser", "invoke (>=1.6.0)", "isort", "mypy", "pandas", "pandas-stubs-fork", "pip-tools (>=6.5.0)", "pylint", "pyright", "setuptools", "typeguard", "wheel"] -pandas = ["pandas", "pandas-stubs-fork"] -qa = ["autoflake", "beartype (<0.10.0)", "beartype (>=0.10.0)", "black", "codecov (>=2.1.0)", "coverage", "feedparser", "isort", "mypy", "pylint", "pyright", "setuptools", "typeguard", "wheel"] - [[package]] name = "nuitka" version = "2.1.5" @@ -1488,6 +1467,24 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pyright" +version = "1.1.358" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.358-py3-none-any.whl", hash = "sha256:0995b6a95eb11bd26f093cd5dee3d5e7258441b1b94d4a171b5dc5b79a1d4f4e"}, + {file = "pyright-1.1.358.tar.gz", hash = "sha256:185524a8d52f6f14bbd3b290b92ad905f25b964dddc9e7148aad760bd35c9f60"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" + +[package.extras] +all = ["twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1578,28 +1575,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.5" +version = "0.3.7" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"}, - {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"}, - {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"}, - {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"}, - {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"}, - {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"}, - {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"}, - {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"}, - {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"}, - {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"}, - {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"}, + {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0e8377cccb2f07abd25e84fc5b2cbe48eeb0fea9f1719cad7caedb061d70e5ce"}, + {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:15a4d1cc1e64e556fa0d67bfd388fed416b7f3b26d5d1c3e7d192c897e39ba4b"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d28bdf3d7dc71dd46929fafeec98ba89b7c3550c3f0978e36389b5631b793663"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:379b67d4f49774ba679593b232dcd90d9e10f04d96e3c8ce4a28037ae473f7bb"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c060aea8ad5ef21cdfbbe05475ab5104ce7827b639a78dd55383a6e9895b7c51"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ebf8f615dde968272d70502c083ebf963b6781aacd3079081e03b32adfe4d58a"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48098bd8f5c38897b03604f5428901b65e3c97d40b3952e38637b5404b739a2"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da8a4fda219bf9024692b1bc68c9cff4b80507879ada8769dc7e985755d662ea"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c44e0149f1d8b48c4d5c33d88c677a4aa22fd09b1683d6a7ff55b816b5d074f"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3050ec0af72b709a62ecc2aca941b9cd479a7bf2b36cc4562f0033d688e44fa1"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a29cc38e4c1ab00da18a3f6777f8b50099d73326981bb7d182e54a9a21bb4ff7"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5b15cc59c19edca917f51b1956637db47e200b0fc5e6e1878233d3a938384b0b"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e491045781b1e38b72c91247cf4634f040f8d0cb3e6d3d64d38dcf43616650b4"}, + {file = "ruff-0.3.7-py3-none-win32.whl", hash = "sha256:bc931de87593d64fad3a22e201e55ad76271f1d5bfc44e1a1887edd0903c7d9f"}, + {file = "ruff-0.3.7-py3-none-win_amd64.whl", hash = "sha256:5ef0e501e1e39f35e03c2acb1d1238c595b8bb36cf7a170e7c1df1b73da00e74"}, + {file = "ruff-0.3.7-py3-none-win_arm64.whl", hash = "sha256:789e144f6dc7019d1f92a812891c645274ed08af6037d11fc65fcbc183b7d59f"}, + {file = "ruff-0.3.7.tar.gz", hash = "sha256:d5c1aebee5162c2226784800ae031f660c350e7a3402c4d1f8ea4e97e232e3ba"}, ] [[package]] @@ -1916,4 +1913,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "f47e90dc927bfc078196a9f59cf3dd7e07684f73aef0ad655480fbdeb5e0a133" +content-hash = "3e643ac2941a51cdd3d08d06804a3ab14ab6dbde126977cfd0fdfe5b69920630" diff --git a/pyproject.toml b/pyproject.toml index b43eaa1..9778232 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ mediapipe = "0.10.11" opencv-python = "4.9.0.80" termcolor = "2.4.0" pygame-gui = "0.6.9" -nptyping = "^2.5.0" [tool.poetry.group.dev.dependencies] pre-commit = "3.7.0" @@ -24,11 +23,17 @@ deptry = "0.16.1" nuitka = "2.1.5" ruff-lsp = "^0.0.53" ruff = "^0.3.5" +pyright = "^1.1.358" [tool.ruff.lint] -select = [ "E", "F", "N", "ANN", "FBT", "B", "A", "C4", "PIE", "Q", "RET", "SLF", "SIM", "ARG", "PL", "PERF", "RUF"] +select = ["E", "F", "N", "FBT", "B", "A", "C4", "PIE", "Q", "RET", "SLF", "SIM", "ARG", "PL", "PERF", "RUF"] fixable = ["ALL"] pylint.max-args = 10 [tool.ruff.lint.pycodestyle] max-line-length = 120 + +[tool.pyright] +include = ["drumpy"] +typeCheckingMode = "strict" +exclude = ["**/*mediapipe_pose.py"]